add broadcast to functions
This commit is contained in:
parent
0553a28ee8
commit
99ee5aa54a
@ -69,10 +69,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
}
|
||||
|
||||
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
|
||||
//todo should be change with broadcasting
|
||||
val newOther = broadcastTo(other, this.shape)
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] +=
|
||||
other.buffer.array()[this.bufferStart + i]
|
||||
newOther.buffer.array()[this.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,41 +107,45 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
}
|
||||
|
||||
override fun DoubleTensor.minusAssign(other: DoubleTensor) {
|
||||
TODO("Alya")
|
||||
val newOther = broadcastTo(other, this.shape)
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] -=
|
||||
newOther.buffer.array()[this.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
override fun Double.times(other: DoubleTensor): DoubleTensor {
|
||||
//todo should be change with broadcasting
|
||||
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
|
||||
other.buffer.array()[other.bufferStart + i] * this
|
||||
}
|
||||
return DoubleTensor(other.shape, resBuffer)
|
||||
}
|
||||
|
||||
//todo should be change with broadcasting
|
||||
override fun DoubleTensor.times(value: Double): DoubleTensor = value * this
|
||||
|
||||
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
||||
//todo should be change with broadcasting
|
||||
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
|
||||
this.buffer.array()[other.bufferStart + i] *
|
||||
other.buffer.array()[other.bufferStart + i]
|
||||
val broadcast = broadcastTensors(this, other)
|
||||
val newThis = broadcast[0]
|
||||
val newOther = broadcast[1]
|
||||
|
||||
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
|
||||
newThis.buffer.array()[newOther.bufferStart + i] *
|
||||
newOther.buffer.array()[newOther.bufferStart + i]
|
||||
}
|
||||
return DoubleTensor(this.shape, resBuffer)
|
||||
return DoubleTensor(newThis.shape, resBuffer)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.timesAssign(value: Double) {
|
||||
//todo should be change with broadcasting
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] *= value
|
||||
}
|
||||
}
|
||||
|
||||
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
|
||||
//todo should be change with broadcasting
|
||||
val newOther = broadcastTo(other, this.shape)
|
||||
for (i in 0 until this.strides.linearSize) {
|
||||
this.buffer.array()[this.bufferStart + i] *=
|
||||
other.buffer.array()[this.bufferStart + i]
|
||||
newOther.buffer.array()[this.bufferStart + i]
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -32,6 +32,43 @@ internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
||||
return totalShape
|
||||
}
|
||||
|
||||
internal inline fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): DoubleTensor {
|
||||
if (tensor.shape.size > newShape.size) {
|
||||
throw RuntimeException("Tensor is not compatible with the new shape")
|
||||
}
|
||||
|
||||
val n = newShape.reduce { acc, i -> acc * i }
|
||||
val resTensor = DoubleTensor(newShape, DoubleArray(n))
|
||||
|
||||
for (i in tensor.shape.indices) {
|
||||
val curDim = tensor.shape[i]
|
||||
val offset = newShape.size - tensor.shape.size
|
||||
if (curDim != 1 && newShape[i + offset] != curDim) {
|
||||
throw RuntimeException("Tensor is not compatible with the new shape and cannot be broadcast")
|
||||
}
|
||||
}
|
||||
|
||||
for (linearIndex in 0 until n) {
|
||||
val totalMultiIndex = resTensor.strides.index(linearIndex)
|
||||
val curMultiIndex = tensor.shape.copyOf()
|
||||
|
||||
val offset = totalMultiIndex.size - curMultiIndex.size
|
||||
|
||||
for (i in curMultiIndex.indices) {
|
||||
if (curMultiIndex[i] != 1) {
|
||||
curMultiIndex[i] = totalMultiIndex[i + offset]
|
||||
} else {
|
||||
curMultiIndex[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
||||
resTensor.buffer.array()[linearIndex] =
|
||||
tensor.buffer.array()[tensor.bufferStart + curLinearIndex]
|
||||
}
|
||||
return resTensor
|
||||
}
|
||||
|
||||
internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
|
||||
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
|
||||
val n = totalShape.reduce { acc, i -> acc * i }
|
||||
|
@ -58,6 +58,16 @@ class TestDoubleTensorAlgebra {
|
||||
) contentEquals intArrayOf(5, 6, 7))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun broadcastTo() = DoubleTensorAlgebra {
|
||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
|
||||
val res = broadcastTo(tensor2, tensor1.shape)
|
||||
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||
assertTrue(res.buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun broadcastTensors() = DoubleTensorAlgebra {
|
||||
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
|
Loading…
Reference in New Issue
Block a user