Add broadcast #231
@ -3,6 +3,7 @@ package space.kscience.kmath.tensors
|
|||||||
import space.kscience.kmath.nd.MutableNDBuffer
|
import space.kscience.kmath.nd.MutableNDBuffer
|
||||||
import space.kscience.kmath.structures.RealBuffer
|
import space.kscience.kmath.structures.RealBuffer
|
||||||
import space.kscience.kmath.structures.array
|
import space.kscience.kmath.structures.array
|
||||||
|
import kotlin.math.max
|
||||||
|
|
||||||
|
|
||||||
public class RealTensor(
|
public class RealTensor(
|
||||||
@ -54,27 +55,85 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
||||||
|
var totalDim = 0
|
||||||
|
for (shape in shapes) {
|
||||||
|
totalDim = max(totalDim, shape.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
val totalShape = IntArray(totalDim) {0}
|
||||||
|
for (shape in shapes) {
|
||||||
|
for (i in shape.indices) {
|
||||||
|
val curDim = shape[i]
|
||||||
|
val offset = totalDim - shape.size
|
||||||
|
totalShape[i + offset] = max(totalShape[i + offset], curDim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (shape in shapes) {
|
||||||
|
for (i in shape.indices) {
|
||||||
|
val curDim = shape[i]
|
||||||
|
val offset = totalDim - shape.size
|
||||||
|
if (curDim != 1 && totalShape[i + offset] != curDim) {
|
||||||
|
throw RuntimeException("Shapes are not compatible and cannot be broadcast")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return totalShape
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun broadcastTensors(vararg tensors: RealTensor): List<RealTensor> {
|
||||||
|
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
|
||||||
|
val n = totalShape.reduce{ acc, i -> acc * i }
|
||||||
|
|
||||||
|
val res = ArrayList<RealTensor>(0)
|
||||||
|
for (tensor in tensors) {
|
||||||
|
val resTensor = RealTensor(totalShape, DoubleArray(n))
|
||||||
|
|
||||||
|
for (linearIndex in 0 until n) {
|
||||||
|
val totalMultiIndex = resTensor.strides.index(linearIndex)
|
||||||
|
val curMultiIndex = tensor.shape.copyOf()
|
||||||
|
|
||||||
|
val offset = totalMultiIndex.size - curMultiIndex.size
|
||||||
|
|
||||||
|
for (i in curMultiIndex.indices) {
|
||||||
|
if (curMultiIndex[i] != 1) {
|
||||||
|
curMultiIndex[i] = totalMultiIndex[i + offset]
|
||||||
|
} else {
|
||||||
|
curMultiIndex[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val curLinearIndex = tensor.strides.offset(curMultiIndex)
|
||||||
|
resTensor.buffer.array[linearIndex] = tensor.buffer.array[curLinearIndex]
|
||||||
|
}
|
||||||
|
res.add(resTensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
override fun Double.plus(other: RealTensor): RealTensor {
|
override fun Double.plus(other: RealTensor): RealTensor {
|
||||||
//todo should be change with broadcasting
|
|
||||||
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
||||||
other.buffer.array[i] + this
|
other.buffer.array[i] + this
|
||||||
}
|
}
|
||||||
return RealTensor(other.shape, resBuffer)
|
return RealTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
//todo should be change with broadcasting
|
|
||||||
override fun RealTensor.plus(value: Double): RealTensor = value + this
|
override fun RealTensor.plus(value: Double): RealTensor = value + this
|
||||||
|
|
||||||
override fun RealTensor.plus(other: RealTensor): RealTensor {
|
override fun RealTensor.plus(other: RealTensor): RealTensor {
|
||||||
//todo should be change with broadcasting
|
val broadcast = broadcastTensors(this, other)
|
||||||
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
val newThis = broadcast[0]
|
||||||
this.buffer.array[i] + other.buffer.array[i]
|
val newOther = broadcast[1]
|
||||||
|
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
|
||||||
|
newThis.buffer.array[i] + newOther.buffer.array[i]
|
||||||
}
|
}
|
||||||
return RealTensor(this.shape, resBuffer)
|
return RealTensor(newThis.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.plusAssign(value: Double) {
|
override fun RealTensor.plusAssign(value: Double) {
|
||||||
//todo should be change with broadcasting
|
|
||||||
for (i in this.buffer.array.indices) {
|
for (i in this.buffer.array.indices) {
|
||||||
this.buffer.array[i] += value
|
this.buffer.array[i] += value
|
||||||
}
|
}
|
||||||
@ -88,19 +147,33 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun Double.minus(other: RealTensor): RealTensor {
|
override fun Double.minus(other: RealTensor): RealTensor {
|
||||||
TODO("Alya")
|
val resBuffer = DoubleArray(other.buffer.size) { i ->
|
||||||
|
this - other.buffer.array[i]
|
||||||
|
}
|
||||||
|
return RealTensor(other.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.minus(value: Double): RealTensor {
|
override fun RealTensor.minus(value: Double): RealTensor {
|
||||||
TODO("Alya")
|
val resBuffer = DoubleArray(this.buffer.size) { i ->
|
||||||
|
this.buffer.array[i] - value
|
||||||
|
}
|
||||||
|
return RealTensor(this.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.minus(other: RealTensor): RealTensor {
|
override fun RealTensor.minus(other: RealTensor): RealTensor {
|
||||||
TODO("Alya")
|
val broadcast = broadcastTensors(this, other)
|
||||||
|
val newThis = broadcast[0]
|
||||||
|
val newOther = broadcast[1]
|
||||||
|
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
|
||||||
|
newThis.buffer.array[i] - newOther.buffer.array[i]
|
||||||
|
}
|
||||||
|
return RealTensor(newThis.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.minusAssign(value: Double) {
|
override fun RealTensor.minusAssign(value: Double) {
|
||||||
TODO("Alya")
|
for (i in this.buffer.array.indices) {
|
||||||
|
this.buffer.array[i] -= value
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.minusAssign(other: RealTensor) {
|
override fun RealTensor.minusAssign(other: RealTensor) {
|
||||||
|
@ -13,6 +13,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
|
|
||||||
public fun TensorType.copy(): TensorType
|
public fun TensorType.copy(): TensorType
|
||||||
|
|
||||||
|
public fun broadcastShapes(vararg shapes: IntArray): IntArray
|
||||||
|
public fun broadcastTensors(vararg tensors: RealTensor): List<TensorType>
|
||||||
|
|
||||||
public operator fun T.plus(other: TensorType): TensorType
|
public operator fun T.plus(other: TensorType): TensorType
|
||||||
public operator fun TensorType.plus(value: T): TensorType
|
public operator fun TensorType.plus(value: T): TensorType
|
||||||
public operator fun TensorType.plus(other: TensorType): TensorType
|
public operator fun TensorType.plus(other: TensorType): TensorType
|
||||||
|
@ -2,6 +2,8 @@ package space.kscience.kmath.tensors
|
|||||||
|
|
||||||
import space.kscience.kmath.structures.array
|
import space.kscience.kmath.structures.array
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertFails
|
||||||
|
import kotlin.test.assertFailsWith
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
class TestRealTensorAlgebra {
|
class TestRealTensorAlgebra {
|
||||||
@ -64,4 +66,49 @@ class TestRealTensorAlgebra {
|
|||||||
assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
assertTrue(tensor.shape contentEquals intArrayOf(3, 2))
|
assertTrue(tensor.shape contentEquals intArrayOf(3, 2))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun broadcastShapes() = RealTensorAlgebra {
|
||||||
|
assertTrue(this.broadcastShapes(
|
||||||
|
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
|
||||||
|
) contentEquals intArrayOf(1, 2, 3))
|
||||||
|
|
||||||
|
assertTrue(this.broadcastShapes(
|
||||||
|
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7,), intArrayOf(5, 1, 7)
|
||||||
|
) contentEquals intArrayOf(5, 6, 7))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun broadcastTensors() = RealTensorAlgebra {
|
||||||
|
val tensor1 = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val tensor2 = RealTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
|
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||||
|
|
||||||
|
val res = this.broadcastTensors(tensor1, tensor2, tensor3)
|
||||||
|
|
||||||
|
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
|
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
|
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
|
|
||||||
|
assertTrue(res[0].buffer.array contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
assertTrue(res[1].buffer.array contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||||
|
assertTrue(res[2].buffer.array contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun minusTensor() = RealTensorAlgebra {
|
||||||
|
val tensor1 = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val tensor2 = RealTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
|
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||||
|
|
||||||
|
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
|
||||||
|
assertTrue((tensor2 - tensor1).buffer.array contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||||
|
|
||||||
|
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
|
||||||
|
assertTrue((tensor3 - tensor1).buffer.array
|
||||||
|
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
|
||||||
|
|
||||||
|
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
|
||||||
|
assertTrue((tensor3 - tensor2).buffer.array contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user