Add more test cases for tensor operations #475
@ -49,6 +49,18 @@ internal class TestDoubleTensor {
|
|||||||
assertTrue { tensor.rowsByIndices(intArrayOf(0, 1)) eq tensor }
|
assertTrue { tensor.rowsByIndices(intArrayOf(0, 1)) eq tensor }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTimes() = DoubleTensorAlgebra {
|
||||||
|
val shape = intArrayOf(2, 2)
|
||||||
|
val buffer = doubleArrayOf(1.0, 2.0, -3.0, 4.0)
|
||||||
|
val tensor = DoubleTensor(shape, buffer)
|
||||||
|
val value = 3
|
||||||
|
assertTrue { tensor.times(value).toBufferedTensor() eq DoubleTensor(shape, buffer.map { x -> 3 * x }.toDoubleArray()) }
|
||||||
|
val buffer2 = doubleArrayOf(7.0, -8.0, -5.0, 2.0)
|
||||||
|
val tensor2 = DoubleTensor(shape, buffer2)
|
||||||
|
assertTrue {tensor.times(tensor2).toBufferedTensor() eq DoubleTensor(shape, doubleArrayOf(7.0, -16.0, 15.0, 8.0)) }
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testValue() = DoubleTensorAlgebra {
|
fun testValue() = DoubleTensorAlgebra {
|
||||||
val value = 12.5
|
val value = 12.5
|
||||||
|
Loading…
Reference in New Issue
Block a user