KMP library for tensors #300
@ -52,7 +52,7 @@ private inline class MutableStructure1DWrapper<T>(val structure: MutableNDStruct
|
|||||||
|
|
||||||
override fun get(index: Int): T = structure[index]
|
override fun get(index: Int): T = structure[index]
|
||||||
override fun set(index: Int, value: T) {
|
override fun set(index: Int, value: T) {
|
||||||
set(index, value)
|
structure[intArrayOf(index)] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<T> =
|
override fun copy(): MutableBuffer<T> =
|
||||||
|
@ -11,9 +11,6 @@ public interface AnalyticTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
|||||||
//https://pytorch.org/docs/stable/generated/torch.std.html#torch.std
|
//https://pytorch.org/docs/stable/generated/torch.std.html#torch.std
|
||||||
public fun TensorType.std(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
|
public fun TensorType.std(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
|
|
||||||
public fun TensorType.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.exp.html
|
//https://pytorch.org/docs/stable/generated/torch.exp.html
|
||||||
public fun TensorType.exp(): TensorType
|
public fun TensorType.exp(): TensorType
|
||||||
|
|
||||||
@ -23,9 +20,6 @@ public interface AnalyticTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
|||||||
//https://pytorch.org/docs/stable/generated/torch.sqrt.html
|
//https://pytorch.org/docs/stable/generated/torch.sqrt.html
|
||||||
public fun TensorType.sqrt(): TensorType
|
public fun TensorType.sqrt(): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.square.html
|
|
||||||
public fun TensorType.square(): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
|
//https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
|
||||||
public fun TensorType.cos(): TensorType
|
public fun TensorType.cos(): TensorType
|
||||||
|
|
||||||
|
@ -16,10 +16,6 @@ public class DoubleAnalyticTensorAlgebra:
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.square(): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.cos(): DoubleTensor {
|
override fun DoubleTensor.cos(): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
@ -152,10 +148,6 @@ public class DoubleAnalyticTensorAlgebra:
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.histc(bins: Int, min: Double, max: Double): DoubleTensor {
|
override fun DoubleTensor.histc(bins: Int, min: Double, max: Double): DoubleTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
@ -254,6 +254,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.square(): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <R> DoubleTensorAlgebra(block: DoubleTensorAlgebra.() -> R): R =
|
public inline fun <R> DoubleTensorAlgebra(block: DoubleTensorAlgebra.() -> R): R =
|
||||||
|
@ -40,6 +40,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
public operator fun TensorType.timesAssign(other: TensorType): Unit
|
public operator fun TensorType.timesAssign(other: TensorType): Unit
|
||||||
public operator fun TensorType.unaryMinus(): TensorType
|
public operator fun TensorType.unaryMinus(): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.square.html
|
||||||
|
public fun TensorType.square(): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
//https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
||||||
public operator fun TensorType.get(i: Int): TensorType
|
public operator fun TensorType.get(i: Int): TensorType
|
||||||
|
|
||||||
|
@ -10,4 +10,7 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
|
|||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean
|
//https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean
|
||||||
public fun TensorType.mean(dim: Int, keepDim: Boolean): TensorType
|
public fun TensorType.mean(dim: Int, keepDim: Boolean): TensorType
|
||||||
|
|
||||||
|
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
|
||||||
|
public fun TensorType.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
|
||||||
}
|
}
|
||||||
|
@ -36,7 +36,7 @@ class TestDoubleTensor {
|
|||||||
matrix[0,1] = 77.89
|
matrix[0,1] = 77.89
|
||||||
assertEquals(tensor[intArrayOf(0,0,1)], 77.89)
|
assertEquals(tensor[intArrayOf(0,0,1)], 77.89)
|
||||||
|
|
||||||
//vector[0] = 109.56
|
vector[0] = 109.56
|
||||||
//println(tensor[intArrayOf(0,1,0)])
|
assertEquals(tensor[intArrayOf(0,1,0)], 109.56)
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user