From 3535e512483b74f00867858b18e38a2150f3451c Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Fri, 19 Mar 2021 19:52:58 +0000 Subject: [PATCH] Broadcasting as its own algebra --- .../core/BroadcastDoubleTensorAlgebra.kt | 2 +- .../kmath/tensors/core/TestBroadcasting.kt | 17 ++++------------- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt index a4767a612..20f64f469 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt @@ -80,7 +80,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { } -public inline fun broadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R = +public inline fun BroadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorAlgebra.() -> R): R = BroadcastDoubleTensorAlgebra().block() diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt index 2633229ea..73e3993a1 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestBroadcasting.kt @@ -48,23 +48,14 @@ class TestBroadcasting { } @Test - fun minusTensor() = DoubleTensorAlgebra { + fun minusTensor() = BroadcastDoubleTensorAlgebra { val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0)) val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0)) - val tensor21 = broadcastDoubleTensorAlgebra { - tensor2 - tensor1 - } - - val tensor31 = broadcastDoubleTensorAlgebra { - tensor3 - tensor1 - } - - val tensor32 = broadcastDoubleTensorAlgebra { - tensor3 - tensor2 - } - + val tensor21 = tensor2 - tensor1 + val tensor31 = tensor3 - tensor1 + val tensor32 = tensor3 - tensor2 assertTrue(tensor21.shape contentEquals intArrayOf(2, 3)) assertTrue(tensor21.buffer.array() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))