diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt index 3e21a3a98..9e040fcd8 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt @@ -3,6 +3,7 @@ package space.kscience.kmath.tensors import space.kscience.kmath.structures.RealBuffer import space.kscience.kmath.structures.array import kotlin.math.abs +import kotlin.math.max public class RealTensor( @@ -50,27 +51,85 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra { + val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray()) + val n = totalShape.reduce{ acc, i -> acc * i } + + val res = ArrayList(0) + for (tensor in tensors) { + val resTensor = RealTensor(totalShape, DoubleArray(n)) + + for (linearIndex in 0 until n) { + val totalMultiIndex = resTensor.strides.index(linearIndex) + val curMultiIndex = tensor.shape.copyOf() + + val offset = totalMultiIndex.size - curMultiIndex.size + + for (i in curMultiIndex.indices) { + if (curMultiIndex[i] != 1) { + curMultiIndex[i] = totalMultiIndex[i + offset] + } else { + curMultiIndex[i] = 0 + } + } + + val curLinearIndex = tensor.strides.offset(curMultiIndex) + resTensor.buffer.array[linearIndex] = tensor.buffer.array[curLinearIndex] + } + res.add(resTensor) + } + + return res + } + override fun Double.plus(other: RealTensor): RealTensor { - //todo should be change with broadcasting val resBuffer = DoubleArray(other.buffer.size) { i -> other.buffer.array[i] + this } return RealTensor(other.shape, resBuffer) } - //todo should be change with broadcasting override fun RealTensor.plus(value: Double): RealTensor = value + this override fun RealTensor.plus(other: RealTensor): RealTensor { - //todo should be change with broadcasting - val resBuffer = DoubleArray(this.buffer.size) { i -> - this.buffer.array[i] + other.buffer.array[i] + val broadcast = broadcastTensors(this, other) + val newThis = broadcast[0] + val newOther = broadcast[1] + val resBuffer = DoubleArray(newThis.buffer.size) { i -> + newThis.buffer.array[i] + newOther.buffer.array[i] } - return RealTensor(this.shape, resBuffer) + return RealTensor(newThis.shape, resBuffer) } override fun RealTensor.plusAssign(value: Double) { - //todo should be change with broadcasting for (i in this.buffer.array.indices) { this.buffer.array[i] += value } @@ -84,19 +143,33 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra + this - other.buffer.array[i] + } + return RealTensor(other.shape, resBuffer) } override fun RealTensor.minus(value: Double): RealTensor { - TODO("Alya") + val resBuffer = DoubleArray(this.buffer.size) { i -> + this.buffer.array[i] - value + } + return RealTensor(this.shape, resBuffer) } override fun RealTensor.minus(other: RealTensor): RealTensor { - TODO("Alya") + val broadcast = broadcastTensors(this, other) + val newThis = broadcast[0] + val newOther = broadcast[1] + val resBuffer = DoubleArray(newThis.buffer.size) { i -> + newThis.buffer.array[i] - newOther.buffer.array[i] + } + return RealTensor(newThis.shape, resBuffer) } override fun RealTensor.minusAssign(value: Double) { - TODO("Alya") + for (i in this.buffer.array.indices) { + this.buffer.array[i] -= value + } } override fun RealTensor.minusAssign(other: RealTensor) { @@ -147,19 +220,12 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra { // todo checks diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt index 0dbcae044..ec257e45c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt @@ -13,6 +13,9 @@ public interface TensorAlgebra> { public fun TensorType.copy(): TensorType + public fun broadcastShapes(vararg shapes: IntArray): IntArray + public fun broadcastTensors(vararg tensors: RealTensor): List + public operator fun T.plus(other: TensorType): TensorType public operator fun TensorType.plus(value: T): TensorType public operator fun TensorType.plus(other: TensorType): TensorType @@ -35,8 +38,6 @@ public interface TensorAlgebra> { //https://pytorch.org/docs/stable/generated/torch.matmul.html public infix fun TensorType.dot(other: TensorType): TensorType - public infix fun TensorType.dotAssign(other: TensorType): Unit - public infix fun TensorType.dotRightAssign(other: TensorType): Unit //https://pytorch.org/docs/stable/generated/torch.diag_embed.html public fun diagonalEmbedding( @@ -46,7 +47,6 @@ public interface TensorAlgebra> { //https://pytorch.org/docs/stable/generated/torch.transpose.html public fun TensorType.transpose(i: Int, j: Int): TensorType - public fun TensorType.transposeAssign(i: Int, j: Int): Unit //https://pytorch.org/docs/stable/tensor_view.html public fun TensorType.view(shape: IntArray): TensorType @@ -54,11 +54,9 @@ public interface TensorAlgebra> { //https://pytorch.org/docs/stable/generated/torch.abs.html public fun TensorType.abs(): TensorType - public fun TensorType.absAssign(): Unit //https://pytorch.org/docs/stable/generated/torch.sum.html public fun TensorType.sum(): TensorType - public fun TensorType.sumAssign(): Unit } // https://proofwiki.org/wiki/Definition:Division_Algebra @@ -72,11 +70,9 @@ public interface TensorPartialDivisionAlgebra //https://pytorch.org/docs/stable/generated/torch.exp.html public fun TensorType.exp(): TensorType - public fun TensorType.expAssign(): Unit //https://pytorch.org/docs/stable/generated/torch.log.html public fun TensorType.log(): TensorType - public fun TensorType.logAssign(): Unit // todo change type of pivots //https://pytorch.org/docs/stable/generated/torch.lu.html diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensorAlgebra.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensorAlgebra.kt index 8e95922b8..c9fd6bb3a 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensorAlgebra.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensorAlgebra.kt @@ -2,6 +2,8 @@ package space.kscience.kmath.tensors import space.kscience.kmath.structures.array import kotlin.test.Test +import kotlin.test.assertFails +import kotlin.test.assertFailsWith import kotlin.test.assertTrue class TestRealTensorAlgebra { @@ -48,20 +50,48 @@ class TestRealTensorAlgebra { } @Test - fun transposeAssign1x2() = RealTensorAlgebra { - val tensor = RealTensor(intArrayOf(1,2), doubleArrayOf(1.0, 2.0)) - tensor.transposeAssign(0, 1) + fun broadcastShapes() = RealTensorAlgebra { + assertTrue(this.broadcastShapes( + intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1) + ) contentEquals intArrayOf(1, 2, 3)) - assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 2.0)) - assertTrue(tensor.shape contentEquals intArrayOf(2, 1)) + assertTrue(this.broadcastShapes( + intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7,), intArrayOf(5, 1, 7) + ) contentEquals intArrayOf(5, 6, 7)) } @Test - fun transposeAssign2x3() = RealTensorAlgebra { - val tensor = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) - tensor.transposeAssign(1, 0) + fun broadcastTensors() = RealTensorAlgebra { + val tensor1 = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + val tensor2 = RealTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0)) + val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0)) - assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) - assertTrue(tensor.shape contentEquals intArrayOf(3, 2)) + val res = this.broadcastTensors(tensor1, tensor2, tensor3) + + assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3)) + assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3)) + assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3)) + + assertTrue(res[0].buffer.array contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + assertTrue(res[1].buffer.array contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0)) + assertTrue(res[2].buffer.array contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0)) } + + @Test + fun minusTensor() = RealTensorAlgebra { + val tensor1 = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + val tensor2 = RealTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0)) + val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0)) + + assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3)) + assertTrue((tensor2 - tensor1).buffer.array contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0)) + + assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3)) + assertTrue((tensor3 - tensor1).buffer.array + contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)) + + assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3)) + assertTrue((tensor3 - tensor2).buffer.array contentEquals doubleArrayOf(490.0, 480.0, 470.0)) + } + } \ No newline at end of file