added some tests for method dot

This commit is contained in:
Margarita 2022-05-24 23:44:09 +03:00
parent 37922365b6
commit eda477b2b5

View File

@ -132,6 +132,27 @@ internal class TestDoubleTensorAlgebra {
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0 468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
)) ))
assertTrue(res45.shape contentEquals intArrayOf(2, 3, 3)) assertTrue(res45.shape contentEquals intArrayOf(2, 3, 3))
val oneDimTensor1 = fromArray(intArrayOf(3), doubleArrayOf(1.0, 2.0, 3.0))
val oneDimTensor2 = fromArray(intArrayOf(3), doubleArrayOf(4.0, 5.0, 6.0))
val resOneDimTensors = oneDimTensor1.dot(oneDimTensor2)
assertTrue(resOneDimTensors.mutableBuffer.array() contentEquals doubleArrayOf(32.0))
assertTrue(resOneDimTensors.shape contentEquals intArrayOf(1))
val twoDimTensor1 = fromArray(intArrayOf(2, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0))
val twoDimTensor2 = fromArray(intArrayOf(2, 2), doubleArrayOf(5.0, 6.0, 7.0, 8.0))
val resTwoDimTensors = twoDimTensor1.dot(twoDimTensor2)
assertTrue(resTwoDimTensors.mutableBuffer.array() contentEquals doubleArrayOf(19.0, 22.0, 43.0, 50.0))
assertTrue(resTwoDimTensors.shape contentEquals intArrayOf(2, 2))
val oneDimTensor3 = fromArray(intArrayOf(2), doubleArrayOf(1.0, 2.0))
val resOneDimTensorOnTwoDimTensor = oneDimTensor3.dot(twoDimTensor1)
assertTrue(resOneDimTensorOnTwoDimTensor.mutableBuffer.array() contentEquals doubleArrayOf(7.0, 10.0))
assertTrue(resOneDimTensorOnTwoDimTensor.shape contentEquals intArrayOf(2))
val resTwoDimTensorOnOneDimTensor = twoDimTensor1.dot(oneDimTensor3)
assertTrue(resTwoDimTensorOnOneDimTensor.mutableBuffer.array() contentEquals doubleArrayOf(5.0, 11.0))
assertTrue(resTwoDimTensorOnOneDimTensor.shape contentEquals intArrayOf(2))
} }
@Test @Test