Broadcasting as its own algebra

This commit is contained in:
Roland Grinis 2021-03-19 19:52:58 +00:00
parent 274be61330
commit 3535e51248
2 changed files with 5 additions and 14 deletions

View File

@ -80,7 +80,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
}
public inline fun <R> broadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R =
public inline fun <R> BroadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R =
BroadcastDoubleTensorAlgebra().block()

View File

@ -48,23 +48,14 @@ class TestBroadcasting {
}
@Test
fun minusTensor() = DoubleTensorAlgebra {
fun minusTensor() = BroadcastDoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
val tensor21 = broadcastDoubleTensorAlgebra {
tensor2 - tensor1
}
val tensor31 = broadcastDoubleTensorAlgebra {
tensor3 - tensor1
}
val tensor32 = broadcastDoubleTensorAlgebra {
tensor3 - tensor2
}
val tensor21 = tensor2 - tensor1
val tensor31 = tensor3 - tensor1
val tensor32 = tensor3 - tensor2
assertTrue(tensor21.shape contentEquals intArrayOf(2, 3))
assertTrue(tensor21.buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))