KMP library for tensors #300
@ -151,14 +151,6 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
|||||||
TODO("Alya")
|
TODO("Alya")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.dotAssign(other: RealTensor) {
|
|
||||||
TODO("Alya")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun RealTensor.dotRightAssign(other: RealTensor) {
|
|
||||||
TODO("Alya")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun diagonalEmbedding(diagonalEntries: RealTensor, offset: Int, dim1: Int, dim2: Int): RealTensor {
|
override fun diagonalEmbedding(diagonalEntries: RealTensor, offset: Int, dim1: Int, dim2: Int): RealTensor {
|
||||||
TODO("Alya")
|
TODO("Alya")
|
||||||
}
|
}
|
||||||
@ -183,15 +175,6 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
|||||||
return resTensor
|
return resTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.transposeAssign(i: Int, j: Int) {
|
|
||||||
val transposedTensor = this.transpose(i, j)
|
|
||||||
for (i in transposedTensor.shape.indices) {
|
|
||||||
this.shape[i] = transposedTensor.shape[i]
|
|
||||||
}
|
|
||||||
for (i in transposedTensor.buffer.array.indices) {
|
|
||||||
this.buffer.array[i] = transposedTensor.buffer.array[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun RealTensor.view(shape: IntArray): RealTensor {
|
override fun RealTensor.view(shape: IntArray): RealTensor {
|
||||||
return RealTensor(shape, this.buffer.array)
|
return RealTensor(shape, this.buffer.array)
|
||||||
@ -205,17 +188,12 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.absAssign() {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun RealTensor.sum(): RealTensor {
|
override fun RealTensor.sum(): RealTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.sumAssign() {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun RealTensor.div(value: Double): RealTensor {
|
override fun RealTensor.div(value: Double): RealTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
@ -237,17 +215,10 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.expAssign() {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun RealTensor.log(): RealTensor {
|
override fun RealTensor.log(): RealTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.logAssign() {
|
|
||||||
TODO("Not yet implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun RealTensor.lu(): Pair<RealTensor, RealTensor> {
|
override fun RealTensor.lu(): Pair<RealTensor, RealTensor> {
|
||||||
TODO()
|
TODO()
|
||||||
|
@ -35,8 +35,6 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.matmul.html
|
//https://pytorch.org/docs/stable/generated/torch.matmul.html
|
||||||
public infix fun TensorType.dot(other: TensorType): TensorType
|
public infix fun TensorType.dot(other: TensorType): TensorType
|
||||||
public infix fun TensorType.dotAssign(other: TensorType): Unit
|
|
||||||
public infix fun TensorType.dotRightAssign(other: TensorType): Unit
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
||||||
public fun diagonalEmbedding(
|
public fun diagonalEmbedding(
|
||||||
@ -46,7 +44,6 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.transpose.html
|
//https://pytorch.org/docs/stable/generated/torch.transpose.html
|
||||||
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
||||||
public fun TensorType.transposeAssign(i: Int, j: Int): Unit
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/tensor_view.html
|
//https://pytorch.org/docs/stable/tensor_view.html
|
||||||
public fun TensorType.view(shape: IntArray): TensorType
|
public fun TensorType.view(shape: IntArray): TensorType
|
||||||
@ -54,11 +51,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.abs.html
|
//https://pytorch.org/docs/stable/generated/torch.abs.html
|
||||||
public fun TensorType.abs(): TensorType
|
public fun TensorType.abs(): TensorType
|
||||||
public fun TensorType.absAssign(): Unit
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.sum.html
|
//https://pytorch.org/docs/stable/generated/torch.sum.html
|
||||||
public fun TensorType.sum(): TensorType
|
public fun TensorType.sum(): TensorType
|
||||||
public fun TensorType.sumAssign(): Unit
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://proofwiki.org/wiki/Definition:Division_Algebra
|
// https://proofwiki.org/wiki/Definition:Division_Algebra
|
||||||
@ -72,11 +67,9 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
|
|||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.exp.html
|
//https://pytorch.org/docs/stable/generated/torch.exp.html
|
||||||
public fun TensorType.exp(): TensorType
|
public fun TensorType.exp(): TensorType
|
||||||
public fun TensorType.expAssign(): Unit
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.log.html
|
//https://pytorch.org/docs/stable/generated/torch.log.html
|
||||||
public fun TensorType.log(): TensorType
|
public fun TensorType.log(): TensorType
|
||||||
public fun TensorType.logAssign(): Unit
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
||||||
public fun TensorType.lu(): Pair<TensorType, TensorType>
|
public fun TensorType.lu(): Pair<TensorType, TensorType>
|
||||||
|
@ -46,22 +46,4 @@ class TestRealTensorAlgebra {
|
|||||||
assertTrue(res02.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
assertTrue(res02.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
assertTrue(res12.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
assertTrue(res12.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
fun transposeAssign1x2() = RealTensorAlgebra {
|
|
||||||
val tensor = RealTensor(intArrayOf(1,2), doubleArrayOf(1.0, 2.0))
|
|
||||||
tensor.transposeAssign(0, 1)
|
|
||||||
|
|
||||||
assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 2.0))
|
|
||||||
assertTrue(tensor.shape contentEquals intArrayOf(2, 1))
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun transposeAssign2x3() = RealTensorAlgebra {
|
|
||||||
val tensor = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
|
||||||
tensor.transposeAssign(1, 0)
|
|
||||||
|
|
||||||
assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
|
||||||
assertTrue(tensor.shape contentEquals intArrayOf(3, 2))
|
|
||||||
}
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user