From 3ca94ec4ec235ffaabfbca874a7fd2a845fcdcec Mon Sep 17 00:00:00 2001 From: Margarita Date: Tue, 24 May 2022 23:44:09 +0300 Subject: [PATCH] added some tests for method dot --- .../tensors/core/TestDoubleTensorAlgebra.kt | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt index 205ae2fee..8657a8dd9 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt @@ -132,6 +132,27 @@ internal class TestDoubleTensorAlgebra { 468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0 )) assertTrue(res45.shape contentEquals intArrayOf(2, 3, 3)) + + val oneDimTensor1 = fromArray(intArrayOf(3), doubleArrayOf(1.0, 2.0, 3.0)) + val oneDimTensor2 = fromArray(intArrayOf(3), doubleArrayOf(4.0, 5.0, 6.0)) + val resOneDimTensors = oneDimTensor1.dot(oneDimTensor2) + assertTrue(resOneDimTensors.mutableBuffer.array() contentEquals doubleArrayOf(32.0)) + assertTrue(resOneDimTensors.shape contentEquals intArrayOf(1)) + + val twoDimTensor1 = fromArray(intArrayOf(2, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0)) + val twoDimTensor2 = fromArray(intArrayOf(2, 2), doubleArrayOf(5.0, 6.0, 7.0, 8.0)) + val resTwoDimTensors = twoDimTensor1.dot(twoDimTensor2) + assertTrue(resTwoDimTensors.mutableBuffer.array() contentEquals doubleArrayOf(19.0, 22.0, 43.0, 50.0)) + assertTrue(resTwoDimTensors.shape contentEquals intArrayOf(2, 2)) + + val oneDimTensor3 = fromArray(intArrayOf(2), doubleArrayOf(1.0, 2.0)) + val resOneDimTensorOnTwoDimTensor = oneDimTensor3.dot(twoDimTensor1) + assertTrue(resOneDimTensorOnTwoDimTensor.mutableBuffer.array() contentEquals doubleArrayOf(7.0, 10.0)) + assertTrue(resOneDimTensorOnTwoDimTensor.shape contentEquals intArrayOf(2)) + + val resTwoDimTensorOnOneDimTensor = twoDimTensor1.dot(oneDimTensor3) + assertTrue(resTwoDimTensorOnOneDimTensor.mutableBuffer.array() contentEquals doubleArrayOf(5.0, 11.0)) + assertTrue(resTwoDimTensorOnOneDimTensor.shape contentEquals intArrayOf(2)) } @Test