From 3e98240b94b99170ffd8525d9b307ac6c715ebd6 Mon Sep 17 00:00:00 2001 From: AlyaNovikova Date: Thu, 1 Apr 2021 20:21:14 +0300 Subject: [PATCH] add function diagonalEmbedding with tests --- .../kmath/tensors/core/DoubleTensorAlgebra.kt | 42 ++++++++++++++++++- .../tensors/core/TestDoubleTensorAlgebra.kt | 33 +++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 323ade45d..ab692c0f9 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -283,7 +283,47 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra n || dim2 > n) { + throw RuntimeException("Dimension out of range") + } + + var lessDim = dim1 + var greaterDim = dim2 + var realOffset = offset + if (lessDim > greaterDim) { + realOffset *= -1 + lessDim = greaterDim.also {greaterDim = lessDim} + } + + val resShape = diagonalEntries.shape.slice(0 until lessDim).toIntArray() + + intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) + + diagonalEntries.shape.slice(lessDim until greaterDim - 1).toIntArray() + + intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) + + diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray() + val resTensor = zeros(resShape) + + for (i in 0 until diagonalEntries.linearStructure.size) { + val multiIndex = diagonalEntries.linearStructure.index(i) + + var offset1 = 0 + var offset2 = abs(realOffset) + if (realOffset < 0) { + offset1 = offset2.also {offset2 = offset1} + } + val diagonalMultiIndex = multiIndex.slice(0 until lessDim).toIntArray() + + intArrayOf(multiIndex[n - 1] + offset1) + + multiIndex.slice(lessDim until greaterDim - 1).toIntArray() + + intArrayOf(multiIndex[n - 1] + offset2) + + multiIndex.slice(greaterDim - 1 until n - 1).toIntArray() + + resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex] + } + + return resTensor } diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt index 168d80a9d..692db69af 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt @@ -115,6 +115,39 @@ class TestDoubleTensorAlgebra { assertTrue(tensor4.dot(tensor5).shape contentEquals intArrayOf(5, 4, 2, 8, 3, 8, 5)) } + @Test + fun diagonalEmbedding() = DoubleTensorAlgebra { + val tensor1 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0)) + val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) + val tensor3 = zeros(intArrayOf(2, 3, 4, 5)) + + assertTrue(diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals + intArrayOf(2, 3, 4, 5, 5)) + assertTrue(diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals + intArrayOf(2, 3, 4, 6, 6)) + assertTrue(diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals + intArrayOf(7, 2, 3, 7, 4)) + + val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0) + assertTrue(diagonal1.shape contentEquals intArrayOf(3, 3)) + assertTrue(diagonal1.buffer.array() contentEquals + doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0)) + + val diagonal1_offset = diagonalEmbedding(tensor1, 1, 1, 0) + assertTrue(diagonal1_offset.shape contentEquals intArrayOf(4, 4)) + assertTrue(diagonal1_offset.buffer.array() contentEquals + doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0)) + + val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2) + assertTrue(diagonal2.shape contentEquals intArrayOf(4, 2, 4)) + assertTrue(diagonal2.buffer.array() contentEquals + doubleArrayOf( + 0.0, 1.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, + 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 5.0, 0.0, + 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 6.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)) + } + @Test fun testContentEqual() = DoubleTensorAlgebra { //TODO()