KMP library for tensors #300

Merged
grinisrit merged 215 commits from feature/tensor-algebra into dev 2021-05-08 09:48:04 +03:00
2 changed files with 5 additions and 14 deletions
Showing only changes of commit 3535e51248 - Show all commits

View File

@ -80,7 +80,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
} }
public inline fun <R> broadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R = public inline fun <R> BroadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R =
BroadcastDoubleTensorAlgebra().block() BroadcastDoubleTensorAlgebra().block()

View File

@ -48,23 +48,14 @@ class TestBroadcasting {
} }
@Test @Test
fun minusTensor() = DoubleTensorAlgebra { fun minusTensor() = BroadcastDoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0)) val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0)) val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
val tensor21 = broadcastDoubleTensorAlgebra { val tensor21 = tensor2 - tensor1
tensor2 - tensor1 val tensor31 = tensor3 - tensor1
} val tensor32 = tensor3 - tensor2
val tensor31 = broadcastDoubleTensorAlgebra {
tensor3 - tensor1
}
val tensor32 = broadcastDoubleTensorAlgebra {
tensor3 - tensor2
}
assertTrue(tensor21.shape contentEquals intArrayOf(2, 3)) assertTrue(tensor21.shape contentEquals intArrayOf(2, 3))
assertTrue(tensor21.buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0)) assertTrue(tensor21.buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))