forked from kscience/kmath
add broadcast to functions
This commit is contained in:
parent
0553a28ee8
commit
99ee5aa54a
@ -69,10 +69,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
|
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
|
||||||
//todo should be change with broadcasting
|
val newOther = broadcastTo(other, this.shape)
|
||||||
for (i in 0 until this.strides.linearSize) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.array()[this.bufferStart + i] +=
|
this.buffer.array()[this.bufferStart + i] +=
|
||||||
other.buffer.array()[this.bufferStart + i]
|
newOther.buffer.array()[this.bufferStart + i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,41 +107,45 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.minusAssign(other: DoubleTensor) {
|
override fun DoubleTensor.minusAssign(other: DoubleTensor) {
|
||||||
TODO("Alya")
|
val newOther = broadcastTo(other, this.shape)
|
||||||
|
for (i in 0 until this.strides.linearSize) {
|
||||||
|
this.buffer.array()[this.bufferStart + i] -=
|
||||||
|
newOther.buffer.array()[this.bufferStart + i]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Double.times(other: DoubleTensor): DoubleTensor {
|
override fun Double.times(other: DoubleTensor): DoubleTensor {
|
||||||
//todo should be change with broadcasting
|
|
||||||
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
|
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
|
||||||
other.buffer.array()[other.bufferStart + i] * this
|
other.buffer.array()[other.bufferStart + i] * this
|
||||||
}
|
}
|
||||||
return DoubleTensor(other.shape, resBuffer)
|
return DoubleTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
//todo should be change with broadcasting
|
|
||||||
override fun DoubleTensor.times(value: Double): DoubleTensor = value * this
|
override fun DoubleTensor.times(value: Double): DoubleTensor = value * this
|
||||||
|
|
||||||
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
||||||
//todo should be change with broadcasting
|
val broadcast = broadcastTensors(this, other)
|
||||||
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
val newThis = broadcast[0]
|
||||||
this.buffer.array()[other.bufferStart + i] *
|
val newOther = broadcast[1]
|
||||||
other.buffer.array()[other.bufferStart + i]
|
|
||||||
|
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||||
|
newThis.buffer.array()[newOther.bufferStart + i] *
|
||||||
|
newOther.buffer.array()[newOther.bufferStart + i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(this.shape, resBuffer)
|
return DoubleTensor(newThis.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.timesAssign(value: Double) {
|
override fun DoubleTensor.timesAssign(value: Double) {
|
||||||
//todo should be change with broadcasting
|
|
||||||
for (i in 0 until this.strides.linearSize) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.array()[this.bufferStart + i] *= value
|
this.buffer.array()[this.bufferStart + i] *= value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
|
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
|
||||||
//todo should be change with broadcasting
|
val newOther = broadcastTo(other, this.shape)
|
||||||
for (i in 0 until this.strides.linearSize) {
|
for (i in 0 until this.strides.linearSize) {
|
||||||
this.buffer.array()[this.bufferStart + i] *=
|
this.buffer.array()[this.bufferStart + i] *=
|
||||||
other.buffer.array()[this.bufferStart + i]
|
newOther.buffer.array()[this.bufferStart + i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -32,6 +32,43 @@ internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
|||||||
return totalShape
|
return totalShape
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal inline fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): DoubleTensor {
|
||||||
|
if (tensor.shape.size > newShape.size) {
|
||||||
|
throw RuntimeException("Tensor is not compatible with the new shape")
|
||||||
|
}
|
||||||
|
|
||||||
|
val n = newShape.reduce { acc, i -> acc * i }
|
||||||
|
val resTensor = DoubleTensor(newShape, DoubleArray(n))
|
||||||
|
|
||||||
|
for (i in tensor.shape.indices) {
|
||||||
|
val curDim = tensor.shape[i]
|
||||||
|
val offset = newShape.size - tensor.shape.size
|
||||||
|
if (curDim != 1 && newShape[i + offset] != curDim) {
|
||||||
|
throw RuntimeException("Tensor is not compatible with the new shape and cannot be broadcast")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (linearIndex in 0 until n) {
|
||||||
|
val totalMultiIndex = resTensor.strides.index(linearIndex)
|
||||||
|
val curMultiIndex = tensor.shape.copyOf()
|
||||||
|
|
||||||
|
val offset = totalMultiIndex.size - curMultiIndex.size
|
||||||
|
|
||||||
|
for (i in curMultiIndex.indices) {
|
||||||
|
if (curMultiIndex[i] != 1) {
|
||||||
|
curMultiIndex[i] = totalMultiIndex[i + offset]
|
||||||
|
} else {
|
||||||
|
curMultiIndex[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
||||||
|
resTensor.buffer.array()[linearIndex] =
|
||||||
|
tensor.buffer.array()[tensor.bufferStart + curLinearIndex]
|
||||||
|
}
|
||||||
|
return resTensor
|
||||||
|
}
|
||||||
|
|
||||||
internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
|
internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
|
||||||
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
|
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
|
||||||
val n = totalShape.reduce { acc, i -> acc * i }
|
val n = totalShape.reduce { acc, i -> acc * i }
|
||||||
|
@ -58,6 +58,16 @@ class TestDoubleTensorAlgebra {
|
|||||||
) contentEquals intArrayOf(5, 6, 7))
|
) contentEquals intArrayOf(5, 6, 7))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun broadcastTo() = DoubleTensorAlgebra {
|
||||||
|
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
|
|
||||||
|
val res = broadcastTo(tensor2, tensor1.shape)
|
||||||
|
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||||
|
assertTrue(res.buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun broadcastTensors() = DoubleTensorAlgebra {
|
fun broadcastTensors() = DoubleTensorAlgebra {
|
||||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user