added some tests for method dot
This commit is contained in:
parent
37922365b6
commit
eda477b2b5
@ -132,6 +132,27 @@ internal class TestDoubleTensorAlgebra {
|
||||
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
|
||||
))
|
||||
assertTrue(res45.shape contentEquals intArrayOf(2, 3, 3))
|
||||
|
||||
val oneDimTensor1 = fromArray(intArrayOf(3), doubleArrayOf(1.0, 2.0, 3.0))
|
||||
val oneDimTensor2 = fromArray(intArrayOf(3), doubleArrayOf(4.0, 5.0, 6.0))
|
||||
val resOneDimTensors = oneDimTensor1.dot(oneDimTensor2)
|
||||
assertTrue(resOneDimTensors.mutableBuffer.array() contentEquals doubleArrayOf(32.0))
|
||||
assertTrue(resOneDimTensors.shape contentEquals intArrayOf(1))
|
||||
|
||||
val twoDimTensor1 = fromArray(intArrayOf(2, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0))
|
||||
val twoDimTensor2 = fromArray(intArrayOf(2, 2), doubleArrayOf(5.0, 6.0, 7.0, 8.0))
|
||||
val resTwoDimTensors = twoDimTensor1.dot(twoDimTensor2)
|
||||
assertTrue(resTwoDimTensors.mutableBuffer.array() contentEquals doubleArrayOf(19.0, 22.0, 43.0, 50.0))
|
||||
assertTrue(resTwoDimTensors.shape contentEquals intArrayOf(2, 2))
|
||||
|
||||
val oneDimTensor3 = fromArray(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
||||
val resOneDimTensorOnTwoDimTensor = oneDimTensor3.dot(twoDimTensor1)
|
||||
assertTrue(resOneDimTensorOnTwoDimTensor.mutableBuffer.array() contentEquals doubleArrayOf(7.0, 10.0))
|
||||
assertTrue(resOneDimTensorOnTwoDimTensor.shape contentEquals intArrayOf(2))
|
||||
|
||||
val resTwoDimTensorOnOneDimTensor = twoDimTensor1.dot(oneDimTensor3)
|
||||
assertTrue(resTwoDimTensorOnOneDimTensor.mutableBuffer.array() contentEquals doubleArrayOf(5.0, 11.0))
|
||||
assertTrue(resTwoDimTensorOnOneDimTensor.shape contentEquals intArrayOf(2))
|
||||
}
|
||||
|
||||
@Test
|
||||
|
Loading…
Reference in New Issue
Block a user