KMP library for tensors #300
@ -80,7 +80,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <R> broadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R =
|
public inline fun <R> BroadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R =
|
||||||
BroadcastDoubleTensorAlgebra().block()
|
BroadcastDoubleTensorAlgebra().block()
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,23 +48,14 @@ class TestBroadcasting {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun minusTensor() = DoubleTensorAlgebra {
|
fun minusTensor() = BroadcastDoubleTensorAlgebra {
|
||||||
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||||
|
|
||||||
val tensor21 = broadcastDoubleTensorAlgebra {
|
val tensor21 = tensor2 - tensor1
|
||||||
tensor2 - tensor1
|
val tensor31 = tensor3 - tensor1
|
||||||
}
|
val tensor32 = tensor3 - tensor2
|
||||||
|
|
||||||
val tensor31 = broadcastDoubleTensorAlgebra {
|
|
||||||
tensor3 - tensor1
|
|
||||||
}
|
|
||||||
|
|
||||||
val tensor32 = broadcastDoubleTensorAlgebra {
|
|
||||||
tensor3 - tensor2
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
assertTrue(tensor21.shape contentEquals intArrayOf(2, 3))
|
assertTrue(tensor21.shape contentEquals intArrayOf(2, 3))
|
||||||
assertTrue(tensor21.buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
assertTrue(tensor21.buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||||
|
Loading…
Reference in New Issue
Block a user