Merge PR and check transpose

This commit is contained in:
Roland Grinis 2021-03-12 15:11:33 +00:00
commit 03cc6a310b
3 changed files with 136 additions and 11 deletions

View File

@ -3,6 +3,7 @@ package space.kscience.kmath.tensors
import space.kscience.kmath.nd.MutableNDBuffer import space.kscience.kmath.nd.MutableNDBuffer
import space.kscience.kmath.structures.RealBuffer import space.kscience.kmath.structures.RealBuffer
import space.kscience.kmath.structures.array import space.kscience.kmath.structures.array
import kotlin.math.max
public class RealTensor( public class RealTensor(
@ -54,27 +55,85 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun broadcastShapes(vararg shapes: IntArray): IntArray {
var totalDim = 0
for (shape in shapes) {
totalDim = max(totalDim, shape.size)
}
val totalShape = IntArray(totalDim) {0}
for (shape in shapes) {
for (i in shape.indices) {
val curDim = shape[i]
val offset = totalDim - shape.size
totalShape[i + offset] = max(totalShape[i + offset], curDim)
}
}
for (shape in shapes) {
for (i in shape.indices) {
val curDim = shape[i]
val offset = totalDim - shape.size
if (curDim != 1 && totalShape[i + offset] != curDim) {
throw RuntimeException("Shapes are not compatible and cannot be broadcast")
}
}
}
return totalShape
}
override fun broadcastTensors(vararg tensors: RealTensor): List<RealTensor> {
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
val n = totalShape.reduce{ acc, i -> acc * i }
val res = ArrayList<RealTensor>(0)
for (tensor in tensors) {
val resTensor = RealTensor(totalShape, DoubleArray(n))
for (linearIndex in 0 until n) {
val totalMultiIndex = resTensor.strides.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i + offset]
} else {
curMultiIndex[i] = 0
}
}
val curLinearIndex = tensor.strides.offset(curMultiIndex)
resTensor.buffer.array[linearIndex] = tensor.buffer.array[curLinearIndex]
}
res.add(resTensor)
}
return res
}
override fun Double.plus(other: RealTensor): RealTensor { override fun Double.plus(other: RealTensor): RealTensor {
//todo should be change with broadcasting
val resBuffer = DoubleArray(other.buffer.size) { i -> val resBuffer = DoubleArray(other.buffer.size) { i ->
other.buffer.array[i] + this other.buffer.array[i] + this
} }
return RealTensor(other.shape, resBuffer) return RealTensor(other.shape, resBuffer)
} }
//todo should be change with broadcasting
override fun RealTensor.plus(value: Double): RealTensor = value + this override fun RealTensor.plus(value: Double): RealTensor = value + this
override fun RealTensor.plus(other: RealTensor): RealTensor { override fun RealTensor.plus(other: RealTensor): RealTensor {
//todo should be change with broadcasting val broadcast = broadcastTensors(this, other)
val resBuffer = DoubleArray(this.buffer.size) { i -> val newThis = broadcast[0]
this.buffer.array[i] + other.buffer.array[i] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
newThis.buffer.array[i] + newOther.buffer.array[i]
} }
return RealTensor(this.shape, resBuffer) return RealTensor(newThis.shape, resBuffer)
} }
override fun RealTensor.plusAssign(value: Double) { override fun RealTensor.plusAssign(value: Double) {
//todo should be change with broadcasting
for (i in this.buffer.array.indices) { for (i in this.buffer.array.indices) {
this.buffer.array[i] += value this.buffer.array[i] += value
} }
@ -88,19 +147,33 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
} }
override fun Double.minus(other: RealTensor): RealTensor { override fun Double.minus(other: RealTensor): RealTensor {
TODO("Alya") val resBuffer = DoubleArray(other.buffer.size) { i ->
this - other.buffer.array[i]
}
return RealTensor(other.shape, resBuffer)
} }
override fun RealTensor.minus(value: Double): RealTensor { override fun RealTensor.minus(value: Double): RealTensor {
TODO("Alya") val resBuffer = DoubleArray(this.buffer.size) { i ->
this.buffer.array[i] - value
}
return RealTensor(this.shape, resBuffer)
} }
override fun RealTensor.minus(other: RealTensor): RealTensor { override fun RealTensor.minus(other: RealTensor): RealTensor {
TODO("Alya") val broadcast = broadcastTensors(this, other)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
newThis.buffer.array[i] - newOther.buffer.array[i]
}
return RealTensor(newThis.shape, resBuffer)
} }
override fun RealTensor.minusAssign(value: Double) { override fun RealTensor.minusAssign(value: Double) {
TODO("Alya") for (i in this.buffer.array.indices) {
this.buffer.array[i] -= value
}
} }
override fun RealTensor.minusAssign(other: RealTensor) { override fun RealTensor.minusAssign(other: RealTensor) {
@ -156,6 +229,7 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
} }
override fun RealTensor.transpose(i: Int, j: Int): RealTensor { override fun RealTensor.transpose(i: Int, j: Int): RealTensor {
checkTranspose(this.dimension, i, j)
val n = this.buffer.size val n = this.buffer.size
val resBuffer = DoubleArray(n) val resBuffer = DoubleArray(n)

View File

@ -13,6 +13,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public fun TensorType.copy(): TensorType public fun TensorType.copy(): TensorType
public fun broadcastShapes(vararg shapes: IntArray): IntArray
public fun broadcastTensors(vararg tensors: RealTensor): List<TensorType>
public operator fun T.plus(other: TensorType): TensorType public operator fun T.plus(other: TensorType): TensorType
public operator fun TensorType.plus(value: T): TensorType public operator fun TensorType.plus(value: T): TensorType
public operator fun TensorType.plus(other: TensorType): TensorType public operator fun TensorType.plus(other: TensorType): TensorType

View File

@ -2,6 +2,8 @@ package space.kscience.kmath.tensors
import space.kscience.kmath.structures.array import space.kscience.kmath.structures.array
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertFails
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue import kotlin.test.assertTrue
class TestRealTensorAlgebra { class TestRealTensorAlgebra {
@ -46,4 +48,50 @@ class TestRealTensorAlgebra {
assertTrue(res02.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res02.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res12.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res12.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
} }
@Test
fun broadcastShapes() = RealTensorAlgebra {
assertTrue(this.broadcastShapes(
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
) contentEquals intArrayOf(1, 2, 3))
assertTrue(this.broadcastShapes(
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7,), intArrayOf(5, 1, 7)
) contentEquals intArrayOf(5, 6, 7))
}
@Test
fun broadcastTensors() = RealTensorAlgebra {
val tensor1 = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = RealTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
val res = this.broadcastTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[0].buffer.array contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].buffer.array contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
assertTrue(res[2].buffer.array contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
}
@Test
fun minusTensor() = RealTensorAlgebra {
val tensor1 = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = RealTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
assertTrue((tensor2 - tensor1).buffer.array contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
assertTrue((tensor3 - tensor1).buffer.array
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
assertTrue((tensor3 - tensor2).buffer.array contentEquals doubleArrayOf(490.0, 480.0, 470.0))
}
} }