From db06d10cc28a12bf7529375939878c53cdeb96c2 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Thu, 28 Oct 2021 19:25:10 +0100 Subject: [PATCH] dot fixed for tensorflow --- .../space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt | 4 +++- .../kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt index e73620d01..90485a9ad 100644 --- a/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt +++ b/kmath-tensorflow/src/main/kotlin/space/kscience/kmath/tensorflow/TensorFlowAlgebra.kt @@ -180,7 +180,9 @@ public abstract class TensorFlowAlgebra> internal cons } override fun StructureND.dot(other: StructureND): TensorFlowOutput = biOp(other) { l, r -> - ops.linalg.matMul(l, r) + ops.linalg.matMul( + if (l.asTensor().shape().numDimensions() == 1) ops.expandDims(l,ops.constant(0)) else l, + if (r.asTensor().shape().numDimensions() == 1) ops.expandDims(r,ops.constant(-1)) else r) } override fun diagonalEmbedding(diagonalEntries: Tensor, offset: Int, dim1: Int, dim2: Int): Tensor = ops.run { diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt index 8c445cf2d..672e37b98 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt @@ -208,7 +208,7 @@ public interface TensorAlgebra> : RingOpsND { * * 3. If the first argument is 1-dimensional and the second argument is 2-dimensional, * a 1 is prepended to its dimension for the purpose of the matrix multiply. - * After the matrix multiply, the prepended dimension is removed. + * After the matrix multiply, depending on the implementation the prepended dimension might be removed. * * 4. If the first argument is 2-dimensional and the second argument is 1-dimensional, * the matrix-vector product is returned.