From 2c001cb1b32f25dc4c245bc32b4046bf78e85cca Mon Sep 17 00:00:00 2001 From: Andrei Kislitsyn Date: Mon, 26 Apr 2021 17:07:49 +0300 Subject: [PATCH] fix div + simple tests --- .../tensors/core/algebras/DoubleTensorAlgebra.kt | 9 +++++++-- .../kmath/tensors/core/TestDoubleTensorAlgebra.kt | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleTensorAlgebra.kt index b61a484d0..7414a4469 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleTensorAlgebra.kt @@ -180,12 +180,17 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { override fun Double.div(other: TensorStructure): DoubleTensor { val resBuffer = DoubleArray(other.tensor.numElements) { i -> - other.tensor.buffer.array()[other.tensor.bufferStart + i] / this + this / other.tensor.buffer.array()[other.tensor.bufferStart + i] } return DoubleTensor(other.shape, resBuffer) } - override fun TensorStructure.div(value: Double): DoubleTensor = value / tensor + override fun TensorStructure.div(value: Double): DoubleTensor { + val resBuffer = DoubleArray(tensor.numElements) { i -> + tensor.buffer.array()[tensor.bufferStart + i] / value + } + return DoubleTensor(shape, resBuffer) + } override fun TensorStructure.div(other: TensorStructure): DoubleTensor { checkShapesCompatible(tensor, other) diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt index 10a9176dc..1333a7a1f 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt @@ -15,6 +15,20 @@ class TestDoubleTensorAlgebra { assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0, 12.0)) } + @Test + fun doubleDiv() = DoubleTensorAlgebra { + val tensor = fromArray(intArrayOf(2), doubleArrayOf(2.0, 4.0)) + val res = 2.0/tensor + assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 0.5)) + } + + @Test + fun divDouble() = DoubleTensorAlgebra { + val tensor = fromArray(intArrayOf(2), doubleArrayOf(10.0, 5.0)) + val res = tensor / 2.5 + assertTrue(res.buffer.array() contentEquals doubleArrayOf(4.0, 2.0)) + } + @Test fun transpose1x1() = DoubleTensorAlgebra { val tensor = fromArray(intArrayOf(1), doubleArrayOf(0.0))