added some tests for method dot
This commit is contained in:
parent
37922365b6
commit
eda477b2b5
@ -132,6 +132,27 @@ internal class TestDoubleTensorAlgebra {
|
|||||||
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
|
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
|
||||||
))
|
))
|
||||||
assertTrue(res45.shape contentEquals intArrayOf(2, 3, 3))
|
assertTrue(res45.shape contentEquals intArrayOf(2, 3, 3))
|
||||||
|
|
||||||
|
val oneDimTensor1 = fromArray(intArrayOf(3), doubleArrayOf(1.0, 2.0, 3.0))
|
||||||
|
val oneDimTensor2 = fromArray(intArrayOf(3), doubleArrayOf(4.0, 5.0, 6.0))
|
||||||
|
val resOneDimTensors = oneDimTensor1.dot(oneDimTensor2)
|
||||||
|
assertTrue(resOneDimTensors.mutableBuffer.array() contentEquals doubleArrayOf(32.0))
|
||||||
|
assertTrue(resOneDimTensors.shape contentEquals intArrayOf(1))
|
||||||
|
|
||||||
|
val twoDimTensor1 = fromArray(intArrayOf(2, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0))
|
||||||
|
val twoDimTensor2 = fromArray(intArrayOf(2, 2), doubleArrayOf(5.0, 6.0, 7.0, 8.0))
|
||||||
|
val resTwoDimTensors = twoDimTensor1.dot(twoDimTensor2)
|
||||||
|
assertTrue(resTwoDimTensors.mutableBuffer.array() contentEquals doubleArrayOf(19.0, 22.0, 43.0, 50.0))
|
||||||
|
assertTrue(resTwoDimTensors.shape contentEquals intArrayOf(2, 2))
|
||||||
|
|
||||||
|
val oneDimTensor3 = fromArray(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
||||||
|
val resOneDimTensorOnTwoDimTensor = oneDimTensor3.dot(twoDimTensor1)
|
||||||
|
assertTrue(resOneDimTensorOnTwoDimTensor.mutableBuffer.array() contentEquals doubleArrayOf(7.0, 10.0))
|
||||||
|
assertTrue(resOneDimTensorOnTwoDimTensor.shape contentEquals intArrayOf(2))
|
||||||
|
|
||||||
|
val resTwoDimTensorOnOneDimTensor = twoDimTensor1.dot(oneDimTensor3)
|
||||||
|
assertTrue(resTwoDimTensorOnOneDimTensor.mutableBuffer.array() contentEquals doubleArrayOf(5.0, 11.0))
|
||||||
|
assertTrue(resTwoDimTensorOnOneDimTensor.shape contentEquals intArrayOf(2))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
Loading…
Reference in New Issue
Block a user