add broadcast to functions

This commit is contained in:
AlyaNovikova 2021-03-16 14:57:19 +03:00
parent 0553a28ee8
commit 99ee5aa54a
3 changed files with 64 additions and 13 deletions

View File

@ -69,10 +69,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
}
override fun DoubleTensor.plusAssign(other: DoubleTensor) {
//todo should be change with broadcasting
val newOther = broadcastTo(other, this.shape)
for (i in 0 until this.strides.linearSize) {
this.buffer.array()[this.bufferStart + i] +=
other.buffer.array()[this.bufferStart + i]
newOther.buffer.array()[this.bufferStart + i]
}
}
@ -107,41 +107,45 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
}
override fun DoubleTensor.minusAssign(other: DoubleTensor) {
TODO("Alya")
val newOther = broadcastTo(other, this.shape)
for (i in 0 until this.strides.linearSize) {
this.buffer.array()[this.bufferStart + i] -=
newOther.buffer.array()[this.bufferStart + i]
}
}
override fun Double.times(other: DoubleTensor): DoubleTensor {
//todo should be change with broadcasting
val resBuffer = DoubleArray(other.strides.linearSize) { i ->
other.buffer.array()[other.bufferStart + i] * this
}
return DoubleTensor(other.shape, resBuffer)
}
//todo should be change with broadcasting
override fun DoubleTensor.times(value: Double): DoubleTensor = value * this
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
//todo should be change with broadcasting
val resBuffer = DoubleArray(this.strides.linearSize) { i ->
this.buffer.array()[other.bufferStart + i] *
other.buffer.array()[other.bufferStart + i]
val broadcast = broadcastTensors(this, other)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
newThis.buffer.array()[newOther.bufferStart + i] *
newOther.buffer.array()[newOther.bufferStart + i]
}
return DoubleTensor(this.shape, resBuffer)
return DoubleTensor(newThis.shape, resBuffer)
}
override fun DoubleTensor.timesAssign(value: Double) {
//todo should be change with broadcasting
for (i in 0 until this.strides.linearSize) {
this.buffer.array()[this.bufferStart + i] *= value
}
}
override fun DoubleTensor.timesAssign(other: DoubleTensor) {
//todo should be change with broadcasting
val newOther = broadcastTo(other, this.shape)
for (i in 0 until this.strides.linearSize) {
this.buffer.array()[this.bufferStart + i] *=
other.buffer.array()[this.bufferStart + i]
newOther.buffer.array()[this.bufferStart + i]
}
}

View File

@ -32,6 +32,43 @@ internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
return totalShape
}
internal inline fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): DoubleTensor {
if (tensor.shape.size > newShape.size) {
throw RuntimeException("Tensor is not compatible with the new shape")
}
val n = newShape.reduce { acc, i -> acc * i }
val resTensor = DoubleTensor(newShape, DoubleArray(n))
for (i in tensor.shape.indices) {
val curDim = tensor.shape[i]
val offset = newShape.size - tensor.shape.size
if (curDim != 1 && newShape[i + offset] != curDim) {
throw RuntimeException("Tensor is not compatible with the new shape and cannot be broadcast")
}
}
for (linearIndex in 0 until n) {
val totalMultiIndex = resTensor.strides.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i + offset]
} else {
curMultiIndex[i] = 0
}
}
val curLinearIndex = tensor.strides.offset(curMultiIndex)
resTensor.buffer.array()[linearIndex] =
tensor.buffer.array()[tensor.bufferStart + curLinearIndex]
}
return resTensor
}
internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
val n = totalShape.reduce { acc, i -> acc * i }

View File

@ -58,6 +58,16 @@ class TestDoubleTensorAlgebra {
) contentEquals intArrayOf(5, 6, 7))
}
@Test
fun broadcastTo() = DoubleTensorAlgebra {
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val res = broadcastTo(tensor2, tensor1.shape)
assertTrue(res.shape contentEquals intArrayOf(2, 3))
assertTrue(res.buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
}
@Test
fun broadcastTensors() = DoubleTensorAlgebra {
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))