fix div + simple tests

This commit is contained in:
Andrei Kislitsyn 2021-04-26 17:07:49 +03:00
parent 4f593aec63
commit 2c001cb1b3
2 changed files with 21 additions and 2 deletions

View File

@ -180,12 +180,17 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
override fun Double.div(other: TensorStructure<Double>): DoubleTensor { override fun Double.div(other: TensorStructure<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
other.tensor.buffer.array()[other.tensor.bufferStart + i] / this this / other.tensor.buffer.array()[other.tensor.bufferStart + i]
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun TensorStructure<Double>.div(value: Double): DoubleTensor = value / tensor override fun TensorStructure<Double>.div(value: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.buffer.array()[tensor.bufferStart + i] / value
}
return DoubleTensor(shape, resBuffer)
}
override fun TensorStructure<Double>.div(other: TensorStructure<Double>): DoubleTensor { override fun TensorStructure<Double>.div(other: TensorStructure<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)

View File

@ -15,6 +15,20 @@ class TestDoubleTensorAlgebra {
assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0, 12.0)) assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0, 12.0))
} }
@Test
fun doubleDiv() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(2), doubleArrayOf(2.0, 4.0))
val res = 2.0/tensor
assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 0.5))
}
@Test
fun divDouble() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(2), doubleArrayOf(10.0, 5.0))
val res = tensor / 2.5
assertTrue(res.buffer.array() contentEquals doubleArrayOf(4.0, 2.0))
}
@Test @Test
fun transpose1x1() = DoubleTensorAlgebra { fun transpose1x1() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(1), doubleArrayOf(0.0)) val tensor = fromArray(intArrayOf(1), doubleArrayOf(0.0))