v0.3.0-dev-9 #324
@ -80,7 +80,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||
|
||||
}
|
||||
|
||||
public inline fun <R> broadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R =
|
||||
public inline fun <R> BroadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R =
|
||||
BroadcastDoubleTensorAlgebra().block()
|
||||
|
||||
|
||||
|
@ -48,23 +48,14 @@ class TestBroadcasting {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun minusTensor() = DoubleTensorAlgebra {
|
||||
fun minusTensor() = BroadcastDoubleTensorAlgebra {
|
||||
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||
|
||||
val tensor21 = broadcastDoubleTensorAlgebra {
|
||||
tensor2 - tensor1
|
||||
}
|
||||
|
||||
val tensor31 = broadcastDoubleTensorAlgebra {
|
||||
tensor3 - tensor1
|
||||
}
|
||||
|
||||
val tensor32 = broadcastDoubleTensorAlgebra {
|
||||
tensor3 - tensor2
|
||||
}
|
||||
|
||||
val tensor21 = tensor2 - tensor1
|
||||
val tensor31 = tensor3 - tensor1
|
||||
val tensor32 = tensor3 - tensor2
|
||||
|
||||
assertTrue(tensor21.shape contentEquals intArrayOf(2, 3))
|
||||
assertTrue(tensor21.buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||
|
Loading…
Reference in New Issue
Block a user