KMP library for tensors #300

Merged
grinisrit merged 215 commits from feature/tensor-algebra into dev 2021-05-08 09:48:04 +03:00
3 changed files with 64 additions and 13 deletions
Showing only changes of commit 138d376289 - Show all commits

View File

@ -69,10 +69,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
} }
override fun DoubleTensor.plusAssign(other: DoubleTensor) { override fun DoubleTensor.plusAssign(other: DoubleTensor) {
//todo should be change with broadcasting val newOther = broadcastTo(other, this.shape)
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.strides.linearSize) {
this.buffer.array()[this.bufferStart + i] += this.buffer.array()[this.bufferStart + i] +=
other.buffer.array()[this.bufferStart + i] newOther.buffer.array()[this.bufferStart + i]
} }
} }
@ -107,41 +107,45 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
} }
override fun DoubleTensor.minusAssign(other: DoubleTensor) { override fun DoubleTensor.minusAssign(other: DoubleTensor) {
TODO("Alya") val newOther = broadcastTo(other, this.shape)
for (i in 0 until this.strides.linearSize) {
this.buffer.array()[this.bufferStart + i] -=
newOther.buffer.array()[this.bufferStart + i]
}
} }
override fun Double.times(other: DoubleTensor): DoubleTensor { override fun Double.times(other: DoubleTensor): DoubleTensor {
//todo should be change with broadcasting
val resBuffer = DoubleArray(other.strides.linearSize) { i -> val resBuffer = DoubleArray(other.strides.linearSize) { i ->
other.buffer.array()[other.bufferStart + i] * this other.buffer.array()[other.bufferStart + i] * this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
//todo should be change with broadcasting
override fun DoubleTensor.times(value: Double): DoubleTensor = value * this override fun DoubleTensor.times(value: Double): DoubleTensor = value * this
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor { override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
//todo should be change with broadcasting val broadcast = broadcastTensors(this, other)
val resBuffer = DoubleArray(this.strides.linearSize) { i -> val newThis = broadcast[0]
this.buffer.array()[other.bufferStart + i] * val newOther = broadcast[1]
other.buffer.array()[other.bufferStart + i]
val resBuffer = DoubleArray(newThis.strides.linearSize) { i ->
newThis.buffer.array()[newOther.bufferStart + i] *
newOther.buffer.array()[newOther.bufferStart + i]
} }
return DoubleTensor(this.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun DoubleTensor.timesAssign(value: Double) { override fun DoubleTensor.timesAssign(value: Double) {
//todo should be change with broadcasting
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.strides.linearSize) {
this.buffer.array()[this.bufferStart + i] *= value this.buffer.array()[this.bufferStart + i] *= value
} }
} }
override fun DoubleTensor.timesAssign(other: DoubleTensor) { override fun DoubleTensor.timesAssign(other: DoubleTensor) {
//todo should be change with broadcasting val newOther = broadcastTo(other, this.shape)
for (i in 0 until this.strides.linearSize) { for (i in 0 until this.strides.linearSize) {
this.buffer.array()[this.bufferStart + i] *= this.buffer.array()[this.bufferStart + i] *=
other.buffer.array()[this.bufferStart + i] newOther.buffer.array()[this.bufferStart + i]
} }
} }

View File

@ -32,6 +32,43 @@ internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
return totalShape return totalShape
} }
internal inline fun broadcastTo(tensor: DoubleTensor, newShape: IntArray): DoubleTensor {
if (tensor.shape.size > newShape.size) {
throw RuntimeException("Tensor is not compatible with the new shape")
}
val n = newShape.reduce { acc, i -> acc * i }
val resTensor = DoubleTensor(newShape, DoubleArray(n))
for (i in tensor.shape.indices) {
val curDim = tensor.shape[i]
val offset = newShape.size - tensor.shape.size
if (curDim != 1 && newShape[i + offset] != curDim) {
throw RuntimeException("Tensor is not compatible with the new shape and cannot be broadcast")
}
}
for (linearIndex in 0 until n) {
val totalMultiIndex = resTensor.strides.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i + offset]
} else {
curMultiIndex[i] = 0
}
}
val curLinearIndex = tensor.strides.offset(curMultiIndex)
resTensor.buffer.array()[linearIndex] =
tensor.buffer.array()[tensor.bufferStart + curLinearIndex]
}
return resTensor
}
internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleTensor> { internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray()) val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
val n = totalShape.reduce { acc, i -> acc * i } val n = totalShape.reduce { acc, i -> acc * i }

View File

@ -58,6 +58,16 @@ class TestDoubleTensorAlgebra {
) contentEquals intArrayOf(5, 6, 7)) ) contentEquals intArrayOf(5, 6, 7))
} }
@Test
fun broadcastTo() = DoubleTensorAlgebra {
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = DoubleTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val res = broadcastTo(tensor2, tensor1.shape)
assertTrue(res.shape contentEquals intArrayOf(2, 3))
assertTrue(res.buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
}
@Test @Test
fun broadcastTensors() = DoubleTensorAlgebra { fun broadcastTensors() = DoubleTensorAlgebra {
val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor1 = DoubleTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))