From 5402ba47c90be0f24f4642dd78adb5ed208c4c08 Mon Sep 17 00:00:00 2001
From: Alexander Nozik
Date: Wed, 3 Aug 2022 18:10:44 +0300
Subject: [PATCH] Restrict tensor dot to vectors and matrices only. Introduce
 bdot to DoubleTensorAlgebra for broadcasting operations.

---
 CHANGELOG.md                                  |  3 +-
 .../kmath/tensors/api/TensorAlgebra.kt        | 11 +----
 .../kmath/tensors/core/DoubleTensorAlgebra.kt | 46 +++++++++++++++++--
 .../core/TestDoubleLinearOpsAlgebra.kt        | 12 ++---
 .../tensors/core/TestDoubleTensorAlgebra.kt   |  4 +-
 5 files changed, 52 insertions(+), 24 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 66d83ec34..03fb4cf8c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,8 +7,9 @@
 - Algebra now has an obligatory `bufferFactory` (#477).
 
 ### Changed
-- Kotlin 1.7
+- Kotlin 1.7.20
 - `LazyStructure` `deffered` -> `async` to comply with coroutines code style
+- The default `dot` operation in tensor algebra no longer supports broadcasting. Instead, a `bdot` operation is added to `DoubleTensorAlgebra`.
 
 ### Deprecated
 
diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt
index 86d4eaa4e..1b11dc54f 100644
--- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt
+++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt
@@ -213,16 +213,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
      * 4. If the first argument is 2-dimensional and the second argument is 1-dimensional,
      * the matrix-vector product is returned.
      *
-     * 5. If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2),
-     * then a batched matrix multiply is returned. If the first argument is 1-dimensional,
-     * a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.
-     * If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix
-     * multiple and removed after.
-     * The non-matrix (i.e., batch) dimensions are broadcast (and thus must be broadcastable).
-     * For example, if `input` is a (j × 1 × n × n) tensor and `other` is a
-     * (k × n × n) tensor, out will be a (j × k × n × n) tensor.
-     *
-     * For more information: https://pytorch.org/docs/stable/generated/torch.matmul.html
+     * Otherwise, an exception is thrown.
      *
      * @param other tensor to be multiplied.
      * @return a mathematical product of two tensors.
diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt
index e9dc34748..af4150f5b 100644
--- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt
+++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt
@@ -381,7 +381,36 @@ public open class DoubleTensorAlgebra :
     override fun Tensor<Double>.viewAs(other: StructureND<Double>): DoubleTensor = tensor.view(other.shape)
 
-    override infix fun StructureND<Double>.dot(other: StructureND<Double>): DoubleTensor {
+    /**
+     * Broadcasting matrix product of two tensors.
+     *
+     * The behavior depends on the dimensionality of the tensors as follows:
+     * 1. If both tensors are 1-dimensional, the dot product (scalar) is returned.
+     *
+     * 2. If both arguments are 2-dimensional, the matrix-matrix product is returned.
+     *
+     * 3. If the first argument is 1-dimensional and the second argument is 2-dimensional,
+     * a 1 is prepended to its dimension for the purpose of the matrix multiply.
+     * After the matrix multiply, the prepended dimension may be removed, depending on the implementation.
+     *
+     * 4. If the first argument is 2-dimensional and the second argument is 1-dimensional,
+     * the matrix-vector product is returned.
+     *
+     * 5. If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2),
+     * then a batched matrix multiply is returned. If the first argument is 1-dimensional,
+     * a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.
+     * If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix
+     * multiply and removed after.
+     * The non-matrix (i.e., batch) dimensions are broadcast (and thus must be broadcastable).
+     * For example, if `input` is a (j × 1 × n × n) tensor and `other` is a
+     * (k × n × n) tensor, out will be a (j × k × n × n) tensor.
+     *
+     * For more information: https://pytorch.org/docs/stable/generated/torch.matmul.html
+     *
+     * @param other tensor to be multiplied.
+     * @return a mathematical product of two tensors.
+     */
+    public infix fun StructureND<Double>.bdot(other: StructureND<Double>): DoubleTensor {
         if (tensor.shape.size == 1 && other.shape.size == 1) {
             return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum()))
         }
@@ -430,6 +459,11 @@
         }
     }
 
+    override fun StructureND<Double>.dot(other: StructureND<Double>): DoubleTensor {
+        return if (dimension in 0..2 && other.dimension in 0..2) bdot(other)
+        else error("Only vectors and matrices are allowed in non-broadcasting dot operation")
+    }
+
     override fun diagonalEmbedding(
         diagonalEntries: Tensor<Double>,
         offset: Int,
@@ -587,7 +621,8 @@
         val resNumElements = resShape.reduce(Int::times)
         val init = foldFunction(DoubleArray(1) { 0.0 })
         val resTensor = BufferedTensor(resShape,
-            MutableBuffer.auto(resNumElements) { init }, 0)
+            MutableBuffer.auto(resNumElements) { init }, 0
+        )
         for (index in resTensor.indices) {
             val prefix = index.take(dim).toIntArray()
             val suffix = index.takeLast(dimension - dim - 1).toIntArray()
@@ -882,7 +917,8 @@
         return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
     }
 
-    override fun StructureND<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = symEigJacobi(maxIteration = 50, epsilon = 1e-15)
+    override fun StructureND<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> =
+        symEigJacobi(maxIteration = 50, epsilon = 1e-15)
 
     /**
      * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
@@ -909,7 +945,7 @@
 
         val (u, s, v) = tensor.svd(epsilon)
         val shp = s.shape + intArrayOf(1)
-        val utv = u.transpose() dot v
+        val utv = u.transpose() bdot v
         val n = s.shape.last()
         for (matrix in utv.matrixSequence()) {
             matrix.as2D().cleanSym(n)
@@ -951,7 +987,7 @@
 
     private fun MutableStructure2D<Double>.jacobiHelper(
         maxIteration: Int,
-        epsilon: Double
+        epsilon: Double,
     ): Pair<Structure1D<Double>, Structure2D<Double>> {
         val n = this.shape[0]
         val A_ = this.copy()
diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt
index e025d4b71..6ae7ae8ef 100644
--- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt
+++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt
@@ -115,7 +115,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
         assertTrue { q.shape contentEquals shape }
         assertTrue { r.shape contentEquals shape }
 
-        assertTrue((q dot r).eq(tensor))
+        assertTrue((q bdot r).eq(tensor))
 
     }
 
@@ -136,17 +136,17 @@ internal class TestDoubleLinearOpsTensorAlgebra {
         assertTrue { l.shape contentEquals shape }
         assertTrue { u.shape contentEquals shape }
 
-        assertTrue((p dot tensor).eq(l dot u))
+        assertTrue((p bdot tensor).eq(l bdot u))
     }
 
     @Test
     fun testCholesky() = DoubleTensorAlgebra {
         val tensor = randomNormal(intArrayOf(2, 5, 5), 0)
-        val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding(
+        val sigma = (tensor bdot tensor.transpose()) + diagonalEmbedding(
             fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
         )
         val low = sigma.cholesky()
-        val sigmChol = low dot low.transpose()
+        val sigmChol = low bdot low.transpose()
         assertTrue(sigma.eq(sigmChol))
     }
 
@@ -171,7 +171,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
     fun testBatchedSVD() = DoubleTensorAlgebra {
         val tensor = randomNormal(intArrayOf(2, 5, 3), 0)
         val (tensorU, tensorS, tensorV) = tensor.svd()
-        val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
+        val tensorSVD = tensorU bdot (diagonalEmbedding(tensorS) bdot tensorV.transpose())
         assertTrue(tensor.eq(tensorSVD))
     }
 
@@ -180,7 +180,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
         val tensor = randomNormal(shape = intArrayOf(2, 3, 3), 0)
         val tensorSigma = tensor + tensor.transpose()
         val (tensorS, tensorV) = tensorSigma.symEig()
-        val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
+        val tensorSigmaCalc = tensorV bdot (diagonalEmbedding(tensorS) bdot tensorV.transpose())
         assertTrue(tensorSigma.eq(tensorSigmaCalc))
     }
 
diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt
index 205ae2fee..2f3c8e2de 100644
--- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt
+++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt
@@ -114,7 +114,7 @@ internal class TestDoubleTensorAlgebra {
         assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(140.0, 320.0))
         assertTrue(res12.shape contentEquals intArrayOf(2))
 
-        val res32 = tensor3.dot(tensor2)
+        val res32 = tensor3.bdot(tensor2)
         assertTrue(res32.mutableBuffer.array() contentEquals doubleArrayOf(-140.0))
         assertTrue(res32.shape contentEquals intArrayOf(1, 1))
 
@@ -126,7 +126,7 @@ internal class TestDoubleTensorAlgebra {
         assertTrue(res11.mutableBuffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
         assertTrue(res11.shape contentEquals intArrayOf(2, 2))
 
-        val res45 = tensor4.dot(tensor5)
+        val res45 = tensor4.bdot(tensor5)
         assertTrue(res45.mutableBuffer.array() contentEquals doubleArrayOf(
             36.0, 42.0, 48.0, 81.0, 96.0, 111.0, 126.0, 150.0, 174.0,
             468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
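
A minimal usage sketch of the new contract, assuming the `DoubleTensorAlgebra` API shown in this patch (`fromArray`, the restricted `dot`, and the broadcasting `bdot`); this is not part of the patch, and the shapes and values below are purely illustrative:

import space.kscience.kmath.tensors.core.DoubleTensorAlgebra

fun main(): Unit = with(DoubleTensorAlgebra) {
    // Rule 4: a matrix-vector product stays within the restricted `dot`.
    val matrix = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
    val vector = fromArray(intArrayOf(3), doubleArrayOf(1.0, 1.0, 1.0))
    println(matrix.dot(vector)) // shape [2]: [6.0, 15.0]

    // A rank-3 operand is rejected by `dot` after this change:
    val batch = fromArray(intArrayOf(2, 2, 3), DoubleArray(12) { it + 1.0 })
    // batch.dot(vector) // throws: "Only vectors and matrices are allowed in non-broadcasting dot operation"

    // ...but `bdot` still performs the batched, broadcasting product (rule 5):
    println(batch bdot vector) // shape [2, 2]: [[6.0, 15.0], [24.0, 33.0]]
}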