From 174f6566e116b43e70622785c58291532b170369 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Tue, 6 Apr 2021 12:07:39 +0100 Subject: [PATCH] Negative indices --- .../kscience/kmath/tensors/TensorAlgebra.kt | 4 ++-- .../kmath/tensors/core/DoubleTensorAlgebra.kt | 20 +++++++++++-------- .../tensors/core/TensorLinearStructure.kt | 3 +++ .../kscience/kmath/tensors/core/utils.kt | 12 ++++++++++- .../core/TestDoubleLinearOpsAlgebra.kt | 4 ++-- .../tensors/core/TestDoubleTensorAlgebra.kt | 4 ++-- 6 files changed, 32 insertions(+), 15 deletions(-) diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt index 966806de1..8fd1cf2ed 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt @@ -28,7 +28,7 @@ public interface TensorAlgebra> { public operator fun TensorType.get(i: Int): TensorType //https://pytorch.org/docs/stable/generated/torch.transpose.html - public fun TensorType.transpose(i: Int, j: Int): TensorType + public fun TensorType.transpose(i: Int = -2, j: Int = -1): TensorType //https://pytorch.org/docs/stable/tensor_view.html public fun TensorType.view(shape: IntArray): TensorType @@ -40,7 +40,7 @@ public interface TensorAlgebra> { //https://pytorch.org/docs/stable/generated/torch.diag_embed.html public fun diagonalEmbedding( diagonalEntries: TensorType, - offset: Int = 0, dim1: Int = 0, dim2: Int = 1 + offset: Int = 0, dim1: Int = -2, dim2: Int = -1 ): TensorType } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 53063a066..870dbe8a7 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -198,19 +198,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra n || dim2 > n) { + if (d1 > n || d2 > n) { throw RuntimeException("Dimension out of range") } - var lessDim = dim1 - var greaterDim = dim2 + var lessDim = d1 + var greaterDim = d2 var realOffset = offset if (lessDim > greaterDim) { realOffset *= -1 diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/TensorLinearStructure.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/TensorLinearStructure.kt index 97ce29657..47745c2be 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/TensorLinearStructure.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/TensorLinearStructure.kt @@ -73,6 +73,9 @@ public class TensorLinearStructure(public val shape: IntArray) public val size: Int get() = shape.reduce(Int::times) + public val dim: Int + get() = shape.size + public fun indices(): Sequence = (0 until size).asSequence().map { index(it) } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/utils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/utils.kt index 5fd3cfd28..785b59ede 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/utils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/utils.kt @@ -39,4 +39,14 @@ internal fun Buffer.array(): DoubleArray = when (this) { internal inline fun getRandomNormals(n: Int, seed: Long): DoubleArray { val u = Random(seed) return (0 until n).map { sqrt(-2.0 * ln(u.nextDouble())) * cos(2.0 * PI * u.nextDouble()) }.toDoubleArray() -} \ No newline at end of file +} + +internal inline fun minusIndexFrom(n: Int, i: Int) : Int = if (i >= 0) i else { + val ii = n + i + check(ii >= 0) { + "Out of bound index $i for tensor of dim $n" + } + ii +} + +internal inline fun BufferedTensor.minusIndex(i: Int): Int = minusIndexFrom(this.linearStructure.dim, i) \ No newline at end of file diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt index 29632e771..56f9332f6 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt @@ -130,7 +130,7 @@ class TestDoubleLinearOpsTensorAlgebra { fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra { val tensor = randNormal(intArrayOf(7, 5, 3), 0) val (tensorU, tensorS, tensorV) = tensor.svd() - val tensorSVD = tensorU dot (diagonalEmbedding(tensorS,0,1,2) dot tensorV) + val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV) assertTrue(tensor.eq(tensorSVD)) } @@ -139,7 +139,7 @@ class TestDoubleLinearOpsTensorAlgebra { val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0) val tensorSigma = tensor + tensor.transpose(1, 2) val (tensorS, tensorV) = tensorSigma.symEig() - val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS, 0,1,2) dot tensorV.transpose(1, 2)) + val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose(1, 2)) assertTrue(tensorSigma.eq(tensorSigmaCalc)) } diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt index 692db69af..fa7a8fd32 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt @@ -35,8 +35,8 @@ class TestDoubleTensorAlgebra { fun transpose1x2x3() = DoubleTensorAlgebra { val tensor = fromArray(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val res01 = tensor.transpose(0, 1) - val res02 = tensor.transpose(0, 2) - val res12 = tensor.transpose(1, 2) + val res02 = tensor.transpose(-3, 2) + val res12 = tensor.transpose() assertTrue(res01.shape contentEquals intArrayOf(2, 1, 3)) assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))