KMP library for tensors #300

Merged
grinisrit merged 215 commits from feature/tensor-algebra into dev 2021-05-08 09:48:04 +03:00
7 changed files with 17 additions and 17 deletions
Showing only changes of commit 70bebbe848 - Show all commits

View File

@ -52,7 +52,7 @@ private inline class MutableStructure1DWrapper<T>(val structure: MutableNDStruct
override fun get(index: Int): T = structure[index] override fun get(index: Int): T = structure[index]
override fun set(index: Int, value: T) { override fun set(index: Int, value: T) {
set(index, value) structure[intArrayOf(index)] = value
} }
override fun copy(): MutableBuffer<T> = override fun copy(): MutableBuffer<T> =

View File

@ -11,9 +11,6 @@ public interface AnalyticTensorAlgebra<T, TensorType : TensorStructure<T>> :
//https://pytorch.org/docs/stable/generated/torch.std.html#torch.std //https://pytorch.org/docs/stable/generated/torch.std.html#torch.std
public fun TensorType.std(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType public fun TensorType.std(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
public fun TensorType.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.exp.html //https://pytorch.org/docs/stable/generated/torch.exp.html
public fun TensorType.exp(): TensorType public fun TensorType.exp(): TensorType
@ -23,9 +20,6 @@ public interface AnalyticTensorAlgebra<T, TensorType : TensorStructure<T>> :
//https://pytorch.org/docs/stable/generated/torch.sqrt.html //https://pytorch.org/docs/stable/generated/torch.sqrt.html
public fun TensorType.sqrt(): TensorType public fun TensorType.sqrt(): TensorType
//https://pytorch.org/docs/stable/generated/torch.square.html
public fun TensorType.square(): TensorType
//https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos //https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
public fun TensorType.cos(): TensorType public fun TensorType.cos(): TensorType

View File

@ -16,10 +16,6 @@ public class DoubleAnalyticTensorAlgebra:
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun DoubleTensor.square(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.cos(): DoubleTensor { override fun DoubleTensor.cos(): DoubleTensor {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
@ -152,10 +148,6 @@ public class DoubleAnalyticTensorAlgebra:
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun DoubleTensor.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.histc(bins: Int, min: Double, max: Double): DoubleTensor { override fun DoubleTensor.histc(bins: Int, min: Double, max: Double): DoubleTensor {
TODO("Not yet implemented") TODO("Not yet implemented")
} }

View File

@ -254,6 +254,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun DoubleTensor.square(): DoubleTensor {
TODO("Not yet implemented")
}
override fun DoubleTensor.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): DoubleTensor {
TODO("Not yet implemented")
}
} }
public inline fun <R> DoubleTensorAlgebra(block: DoubleTensorAlgebra.() -> R): R = public inline fun <R> DoubleTensorAlgebra(block: DoubleTensorAlgebra.() -> R): R =

View File

@ -40,6 +40,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public operator fun TensorType.timesAssign(other: TensorType): Unit public operator fun TensorType.timesAssign(other: TensorType): Unit
public operator fun TensorType.unaryMinus(): TensorType public operator fun TensorType.unaryMinus(): TensorType
//https://pytorch.org/docs/stable/generated/torch.square.html
public fun TensorType.square(): TensorType
//https://pytorch.org/cppdocs/notes/tensor_indexing.html //https://pytorch.org/cppdocs/notes/tensor_indexing.html
public operator fun TensorType.get(i: Int): TensorType public operator fun TensorType.get(i: Int): TensorType

View File

@ -10,4 +10,7 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
//https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean //https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean
public fun TensorType.mean(dim: Int, keepDim: Boolean): TensorType public fun TensorType.mean(dim: Int, keepDim: Boolean): TensorType
//https://pytorch.org/docs/stable/generated/torch.var.html#torch.var
public fun TensorType.variance(dim: Int, unbiased: Boolean, keepDim: Boolean): TensorType
} }

View File

@ -36,7 +36,7 @@ class TestDoubleTensor {
matrix[0,1] = 77.89 matrix[0,1] = 77.89
assertEquals(tensor[intArrayOf(0,0,1)], 77.89) assertEquals(tensor[intArrayOf(0,0,1)], 77.89)
//vector[0] = 109.56 vector[0] = 109.56
//println(tensor[intArrayOf(0,1,0)]) assertEquals(tensor[intArrayOf(0,1,0)], 109.56)
} }
} }