1D mutable structure setter fixed
This commit is contained in:
parent
0553a28ee8
commit
70bebbe848
@ -52,7 +52,7 @@ private inline class MutableStructure1DWrapper<T>(val structure: MutableNDStruct
|
||||
|
||||
override fun get(index: Int): T = structure[index]
|
||||
override fun set(index: Int, value: T) {
|
||||
set(index, value)
|
||||
structure[intArrayOf(index)] = value
|
||||
}
|
||||
|
||||
override fun copy(): MutableBuffer<T> =
|
||||
|
@ -11,9 +11,6 @@ public interface AnalyticTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
||||
//https://pytorch.org/docs/stable/generated/torch.std.html#torch.std
|
||||
public fun TensorType.std(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
|
||||
public fun TensorType.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.exp.html
|
||||
public fun TensorType.exp(): TensorType
|
||||
|
||||
@ -23,9 +20,6 @@ public interface AnalyticTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
||||
//https://pytorch.org/docs/stable/generated/torch.sqrt.html
|
||||
public fun TensorType.sqrt(): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.square.html
|
||||
public fun TensorType.square(): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
|
||||
public fun TensorType.cos(): TensorType
|
||||
|
||||
|
@ -16,10 +16,6 @@ public class DoubleAnalyticTensorAlgebra:
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.square(): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.cos(): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
@ -152,10 +148,6 @@ public class DoubleAnalyticTensorAlgebra:
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.histc(bins: Int, min: Double, max: Double): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
@ -254,6 +254,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.square(): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun DoubleTensor.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public inline fun <R> DoubleTensorAlgebra(block: DoubleTensorAlgebra.() -> R): R =
|
||||
|
@ -40,6 +40,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
public operator fun TensorType.timesAssign(other: TensorType): Unit
|
||||
public operator fun TensorType.unaryMinus(): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.square.html
|
||||
public fun TensorType.square(): TensorType
|
||||
|
||||
//https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
||||
public operator fun TensorType.get(i: Int): TensorType
|
||||
|
||||
|
@ -10,4 +10,7 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean
|
||||
public fun TensorType.mean(dim: Int, keepDim: Boolean): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
|
||||
public fun TensorType.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
|
||||
}
|
||||
|
@ -36,7 +36,7 @@ class TestDoubleTensor {
|
||||
matrix[0,1] = 77.89
|
||||
assertEquals(tensor[intArrayOf(0,0,1)], 77.89)
|
||||
|
||||
//vector[0] = 109.56
|
||||
//println(tensor[intArrayOf(0,1,0)])
|
||||
vector[0] = 109.56
|
||||
assertEquals(tensor[intArrayOf(0,1,0)], 109.56)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user