fix div + simple tests
This commit is contained in:
parent
4f593aec63
commit
2c001cb1b3
@ -180,12 +180,17 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
|||||||
|
|
||||||
override fun Double.div(other: TensorStructure<Double>): DoubleTensor {
|
override fun Double.div(other: TensorStructure<Double>): DoubleTensor {
|
||||||
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
|
val resBuffer = DoubleArray(other.tensor.numElements) { i ->
|
||||||
other.tensor.buffer.array()[other.tensor.bufferStart + i] / this
|
this / other.tensor.buffer.array()[other.tensor.bufferStart + i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(other.shape, resBuffer)
|
return DoubleTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun TensorStructure<Double>.div(value: Double): DoubleTensor = value / tensor
|
override fun TensorStructure<Double>.div(value: Double): DoubleTensor {
|
||||||
|
val resBuffer = DoubleArray(tensor.numElements) { i ->
|
||||||
|
tensor.buffer.array()[tensor.bufferStart + i] / value
|
||||||
|
}
|
||||||
|
return DoubleTensor(shape, resBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
override fun TensorStructure<Double>.div(other: TensorStructure<Double>): DoubleTensor {
|
override fun TensorStructure<Double>.div(other: TensorStructure<Double>): DoubleTensor {
|
||||||
checkShapesCompatible(tensor, other)
|
checkShapesCompatible(tensor, other)
|
||||||
|
@ -15,6 +15,20 @@ class TestDoubleTensorAlgebra {
|
|||||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0, 12.0))
|
assertTrue(res.buffer.array() contentEquals doubleArrayOf(11.0, 12.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun doubleDiv() = DoubleTensorAlgebra {
|
||||||
|
val tensor = fromArray(intArrayOf(2), doubleArrayOf(2.0, 4.0))
|
||||||
|
val res = 2.0/tensor
|
||||||
|
assertTrue(res.buffer.array() contentEquals doubleArrayOf(1.0, 0.5))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun divDouble() = DoubleTensorAlgebra {
|
||||||
|
val tensor = fromArray(intArrayOf(2), doubleArrayOf(10.0, 5.0))
|
||||||
|
val res = tensor / 2.5
|
||||||
|
assertTrue(res.buffer.array() contentEquals doubleArrayOf(4.0, 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun transpose1x1() = DoubleTensorAlgebra {
|
fun transpose1x1() = DoubleTensorAlgebra {
|
||||||
val tensor = fromArray(intArrayOf(1), doubleArrayOf(0.0))
|
val tensor = fromArray(intArrayOf(1), doubleArrayOf(0.0))
|
||||||
|
Loading…
Reference in New Issue
Block a user