KMP library for tensors #300

Merged
grinisrit merged 215 commits from feature/tensor-algebra into dev 2021-05-08 09:48:04 +03:00
3 changed files with 1 additions and 55 deletions
Showing only changes of commit bb4894b87e - Show all commits

View File

@ -151,14 +151,6 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
TODO("Alya") TODO("Alya")
} }
override fun RealTensor.dotAssign(other: RealTensor) {
TODO("Alya")
}
override fun RealTensor.dotRightAssign(other: RealTensor) {
TODO("Alya")
}
override fun diagonalEmbedding(diagonalEntries: RealTensor, offset: Int, dim1: Int, dim2: Int): RealTensor { override fun diagonalEmbedding(diagonalEntries: RealTensor, offset: Int, dim1: Int, dim2: Int): RealTensor {
TODO("Alya") TODO("Alya")
} }
@ -183,15 +175,6 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
return resTensor return resTensor
} }
override fun RealTensor.transposeAssign(i: Int, j: Int) {
val transposedTensor = this.transpose(i, j)
for (i in transposedTensor.shape.indices) {
this.shape[i] = transposedTensor.shape[i]
}
for (i in transposedTensor.buffer.array.indices) {
this.buffer.array[i] = transposedTensor.buffer.array[i]
}
}
override fun RealTensor.view(shape: IntArray): RealTensor { override fun RealTensor.view(shape: IntArray): RealTensor {
return RealTensor(shape, this.buffer.array) return RealTensor(shape, this.buffer.array)
@ -205,17 +188,12 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun RealTensor.absAssign() {
TODO("Not yet implemented")
}
override fun RealTensor.sum(): RealTensor { override fun RealTensor.sum(): RealTensor {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun RealTensor.sumAssign() {
TODO("Not yet implemented")
}
override fun RealTensor.div(value: Double): RealTensor { override fun RealTensor.div(value: Double): RealTensor {
TODO("Not yet implemented") TODO("Not yet implemented")
@ -237,17 +215,10 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun RealTensor.expAssign() {
TODO("Not yet implemented")
}
override fun RealTensor.log(): RealTensor { override fun RealTensor.log(): RealTensor {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun RealTensor.logAssign() {
TODO("Not yet implemented")
}
override fun RealTensor.lu(): Pair<RealTensor, RealTensor> { override fun RealTensor.lu(): Pair<RealTensor, RealTensor> {
TODO() TODO()

View File

@ -35,8 +35,6 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
//https://pytorch.org/docs/stable/generated/torch.matmul.html //https://pytorch.org/docs/stable/generated/torch.matmul.html
public infix fun TensorType.dot(other: TensorType): TensorType public infix fun TensorType.dot(other: TensorType): TensorType
public infix fun TensorType.dotAssign(other: TensorType): Unit
public infix fun TensorType.dotRightAssign(other: TensorType): Unit
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html //https://pytorch.org/docs/stable/generated/torch.diag_embed.html
public fun diagonalEmbedding( public fun diagonalEmbedding(
@ -46,7 +44,6 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
//https://pytorch.org/docs/stable/generated/torch.transpose.html //https://pytorch.org/docs/stable/generated/torch.transpose.html
public fun TensorType.transpose(i: Int, j: Int): TensorType public fun TensorType.transpose(i: Int, j: Int): TensorType
public fun TensorType.transposeAssign(i: Int, j: Int): Unit
//https://pytorch.org/docs/stable/tensor_view.html //https://pytorch.org/docs/stable/tensor_view.html
public fun TensorType.view(shape: IntArray): TensorType public fun TensorType.view(shape: IntArray): TensorType
@ -54,11 +51,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
//https://pytorch.org/docs/stable/generated/torch.abs.html //https://pytorch.org/docs/stable/generated/torch.abs.html
public fun TensorType.abs(): TensorType public fun TensorType.abs(): TensorType
public fun TensorType.absAssign(): Unit
//https://pytorch.org/docs/stable/generated/torch.sum.html //https://pytorch.org/docs/stable/generated/torch.sum.html
public fun TensorType.sum(): TensorType public fun TensorType.sum(): TensorType
public fun TensorType.sumAssign(): Unit
} }
// https://proofwiki.org/wiki/Definition:Division_Algebra // https://proofwiki.org/wiki/Definition:Division_Algebra
@ -72,11 +67,9 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
//https://pytorch.org/docs/stable/generated/torch.exp.html //https://pytorch.org/docs/stable/generated/torch.exp.html
public fun TensorType.exp(): TensorType public fun TensorType.exp(): TensorType
public fun TensorType.expAssign(): Unit
//https://pytorch.org/docs/stable/generated/torch.log.html //https://pytorch.org/docs/stable/generated/torch.log.html
public fun TensorType.log(): TensorType public fun TensorType.log(): TensorType
public fun TensorType.logAssign(): Unit
//https://pytorch.org/docs/stable/generated/torch.lu.html //https://pytorch.org/docs/stable/generated/torch.lu.html
public fun TensorType.lu(): Pair<TensorType, TensorType> public fun TensorType.lu(): Pair<TensorType, TensorType>

View File

@ -46,22 +46,4 @@ class TestRealTensorAlgebra {
assertTrue(res02.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res02.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res12.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res12.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
} }
@Test
fun transposeAssign1x2() = RealTensorAlgebra {
val tensor = RealTensor(intArrayOf(1,2), doubleArrayOf(1.0, 2.0))
tensor.transposeAssign(0, 1)
assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 2.0))
assertTrue(tensor.shape contentEquals intArrayOf(2, 1))
}
@Test
fun transposeAssign2x3() = RealTensorAlgebra {
val tensor = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
tensor.transposeAssign(1, 0)
assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(tensor.shape contentEquals intArrayOf(3, 2))
}
} }