diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/NDStructure.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/NDStructure.kt index 54e410ade..e458d0606 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/NDStructure.kt @@ -174,6 +174,11 @@ public interface Strides { */ public fun index(offset: Int): IntArray + /** + * Get next multidimensional index from the current multidimensional index + */ + public fun nextIndex(index: IntArray): IntArray + /** * The size of linear buffer to accommodate all elements of ND-structure corresponding to strides */ @@ -232,6 +237,10 @@ public class DefaultStrides private constructor(override val shape: IntArray) : return res } + override fun nextIndex(index: IntArray): IntArray { + TODO("Not yet implemented") + } + override fun equals(other: Any?): Boolean { if (this === other) return true if (other !is DefaultStrides) return false diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt index 366acfb62..a5c00e8ec 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt @@ -164,11 +164,33 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra= shape[current]) { + carry = 1 + res[current] = 0 + } + current-- + } while(carry != 0 && current >= 0) + + return res +} + public class TensorStrides(override val shape: IntArray): Strides @@ -47,6 +64,9 @@ public class TensorStrides(override val shape: IntArray): Strides override fun index(offset: Int): IntArray = indexFromOffset(offset, strides, shape.size) + override fun nextIndex(index: IntArray): IntArray = + nextIndex(index, shape, shape.size) + override val linearSize: Int get() = shape.reduce(Int::times) } \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensorAlgebra.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensorAlgebra.kt index 86caa0338..8e95922b8 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensorAlgebra.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/TestRealTensorAlgebra.kt @@ -13,4 +13,55 @@ class TestRealTensorAlgebra { assertTrue(res.buffer.array contentEquals doubleArrayOf(11.0,12.0)) } + @Test + fun transpose1x1() = RealTensorAlgebra { + val tensor = RealTensor(intArrayOf(1), doubleArrayOf(0.0)) + val res = tensor.transpose(0, 0) + + assertTrue(res.buffer.array contentEquals doubleArrayOf(0.0)) + assertTrue(res.shape contentEquals intArrayOf(1)) + } + + @Test + fun transpose3x2() = RealTensorAlgebra { + val tensor = RealTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + val res = tensor.transpose(1, 0) + + assertTrue(res.buffer.array contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) + assertTrue(res.shape contentEquals intArrayOf(2, 3)) + } + + @Test + fun transpose1x2x3() = RealTensorAlgebra { + val tensor = RealTensor(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + val res01 = tensor.transpose(0, 1) + val res02 = tensor.transpose(0, 2) + val res12 = tensor.transpose(1, 2) + + assertTrue(res01.shape contentEquals intArrayOf(2, 1, 3)) + assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1)) + assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2)) + + assertTrue(res01.buffer.array contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + assertTrue(res02.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) + assertTrue(res12.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) + } + + @Test + fun transposeAssign1x2() = RealTensorAlgebra { + val tensor = RealTensor(intArrayOf(1,2), doubleArrayOf(1.0, 2.0)) + tensor.transposeAssign(0, 1) + + assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 2.0)) + assertTrue(tensor.shape contentEquals intArrayOf(2, 1)) + } + + @Test + fun transposeAssign2x3() = RealTensorAlgebra { + val tensor = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + tensor.transposeAssign(1, 0) + + assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) + assertTrue(tensor.shape contentEquals intArrayOf(3, 2)) + } } \ No newline at end of file