resolve mc

This commit is contained in:
Andrei Kislitsyn 2021-03-13 19:32:34 +03:00
commit 3e0d152c1b
3 changed files with 129 additions and 58 deletions

View File

@ -3,6 +3,7 @@ package space.kscience.kmath.tensors
import space.kscience.kmath.structures.RealBuffer
import space.kscience.kmath.structures.array
import kotlin.math.abs
import kotlin.math.max
public class RealTensor(
@ -50,27 +51,85 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
return RealTensor(this.shape, this.buffer.array.copyOf())
}
override fun broadcastShapes(vararg shapes: IntArray): IntArray {
var totalDim = 0
for (shape in shapes) {
totalDim = max(totalDim, shape.size)
}
val totalShape = IntArray(totalDim) {0}
for (shape in shapes) {
for (i in shape.indices) {
val curDim = shape[i]
val offset = totalDim - shape.size
totalShape[i + offset] = max(totalShape[i + offset], curDim)
}
}
for (shape in shapes) {
for (i in shape.indices) {
val curDim = shape[i]
val offset = totalDim - shape.size
if (curDim != 1 && totalShape[i + offset] != curDim) {
throw RuntimeException("Shapes are not compatible and cannot be broadcast")
}
}
}
return totalShape
}
override fun broadcastTensors(vararg tensors: RealTensor): List<RealTensor> {
val totalShape = broadcastShapes(*(tensors.map { it.shape }).toTypedArray())
val n = totalShape.reduce{ acc, i -> acc * i }
val res = ArrayList<RealTensor>(0)
for (tensor in tensors) {
val resTensor = RealTensor(totalShape, DoubleArray(n))
for (linearIndex in 0 until n) {
val totalMultiIndex = resTensor.strides.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i + offset]
} else {
curMultiIndex[i] = 0
}
}
val curLinearIndex = tensor.strides.offset(curMultiIndex)
resTensor.buffer.array[linearIndex] = tensor.buffer.array[curLinearIndex]
}
res.add(resTensor)
}
return res
}
override fun Double.plus(other: RealTensor): RealTensor {
//todo should be change with broadcasting
val resBuffer = DoubleArray(other.buffer.size) { i ->
other.buffer.array[i] + this
}
return RealTensor(other.shape, resBuffer)
}
//todo should be change with broadcasting
override fun RealTensor.plus(value: Double): RealTensor = value + this
override fun RealTensor.plus(other: RealTensor): RealTensor {
//todo should be change with broadcasting
val resBuffer = DoubleArray(this.buffer.size) { i ->
this.buffer.array[i] + other.buffer.array[i]
val broadcast = broadcastTensors(this, other)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
newThis.buffer.array[i] + newOther.buffer.array[i]
}
return RealTensor(this.shape, resBuffer)
return RealTensor(newThis.shape, resBuffer)
}
override fun RealTensor.plusAssign(value: Double) {
//todo should be change with broadcasting
for (i in this.buffer.array.indices) {
this.buffer.array[i] += value
}
@ -84,19 +143,33 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
}
override fun Double.minus(other: RealTensor): RealTensor {
TODO("Alya")
val resBuffer = DoubleArray(other.buffer.size) { i ->
this - other.buffer.array[i]
}
return RealTensor(other.shape, resBuffer)
}
override fun RealTensor.minus(value: Double): RealTensor {
TODO("Alya")
val resBuffer = DoubleArray(this.buffer.size) { i ->
this.buffer.array[i] - value
}
return RealTensor(this.shape, resBuffer)
}
override fun RealTensor.minus(other: RealTensor): RealTensor {
TODO("Alya")
val broadcast = broadcastTensors(this, other)
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.buffer.size) { i ->
newThis.buffer.array[i] - newOther.buffer.array[i]
}
return RealTensor(newThis.shape, resBuffer)
}
override fun RealTensor.minusAssign(value: Double) {
TODO("Alya")
for (i in this.buffer.array.indices) {
this.buffer.array[i] -= value
}
}
override fun RealTensor.minusAssign(other: RealTensor) {
@ -147,19 +220,12 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
TODO("Alya")
}
override fun RealTensor.dotAssign(other: RealTensor) {
TODO("Alya")
}
override fun RealTensor.dotRightAssign(other: RealTensor) {
TODO("Alya")
}
override fun diagonalEmbedding(diagonalEntries: RealTensor, offset: Int, dim1: Int, dim2: Int): RealTensor {
TODO("Alya")
}
override fun RealTensor.transpose(i: Int, j: Int): RealTensor {
checkTranspose(this.dimension, i, j)
val n = this.buffer.size
val resBuffer = DoubleArray(n)
@ -179,15 +245,6 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
return resTensor
}
override fun RealTensor.transposeAssign(i: Int, j: Int) {
val transposedTensor = this.transpose(i, j)
for (i in transposedTensor.shape.indices) {
this.shape[i] = transposedTensor.shape[i]
}
for (i in transposedTensor.buffer.array.indices) {
this.buffer.array[i] = transposedTensor.buffer.array[i]
}
}
override fun RealTensor.view(shape: IntArray): RealTensor {
return RealTensor(shape, this.buffer.array)
@ -201,17 +258,12 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
TODO("Not yet implemented")
}
override fun RealTensor.absAssign() {
TODO("Not yet implemented")
}
override fun RealTensor.sum(): RealTensor {
TODO("Not yet implemented")
}
override fun RealTensor.sumAssign() {
TODO("Not yet implemented")
}
override fun RealTensor.div(value: Double): RealTensor {
TODO("Not yet implemented")
@ -233,17 +285,10 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
TODO("Not yet implemented")
}
override fun RealTensor.expAssign() {
TODO("Not yet implemented")
}
override fun RealTensor.log(): RealTensor {
TODO("Not yet implemented")
}
override fun RealTensor.logAssign() {
TODO("Not yet implemented")
}
override fun RealTensor.lu(): Pair<RealTensor, IntTensor> {
// todo checks

View File

@ -13,6 +13,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public fun TensorType.copy(): TensorType
public fun broadcastShapes(vararg shapes: IntArray): IntArray
public fun broadcastTensors(vararg tensors: RealTensor): List<TensorType>
public operator fun T.plus(other: TensorType): TensorType
public operator fun TensorType.plus(value: T): TensorType
public operator fun TensorType.plus(other: TensorType): TensorType
@ -35,8 +38,6 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
//https://pytorch.org/docs/stable/generated/torch.matmul.html
public infix fun TensorType.dot(other: TensorType): TensorType
public infix fun TensorType.dotAssign(other: TensorType): Unit
public infix fun TensorType.dotRightAssign(other: TensorType): Unit
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
public fun diagonalEmbedding(
@ -46,7 +47,6 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
//https://pytorch.org/docs/stable/generated/torch.transpose.html
public fun TensorType.transpose(i: Int, j: Int): TensorType
public fun TensorType.transposeAssign(i: Int, j: Int): Unit
//https://pytorch.org/docs/stable/tensor_view.html
public fun TensorType.view(shape: IntArray): TensorType
@ -54,11 +54,9 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
//https://pytorch.org/docs/stable/generated/torch.abs.html
public fun TensorType.abs(): TensorType
public fun TensorType.absAssign(): Unit
//https://pytorch.org/docs/stable/generated/torch.sum.html
public fun TensorType.sum(): TensorType
public fun TensorType.sumAssign(): Unit
}
// https://proofwiki.org/wiki/Definition:Division_Algebra
@ -72,11 +70,9 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
//https://pytorch.org/docs/stable/generated/torch.exp.html
public fun TensorType.exp(): TensorType
public fun TensorType.expAssign(): Unit
//https://pytorch.org/docs/stable/generated/torch.log.html
public fun TensorType.log(): TensorType
public fun TensorType.logAssign(): Unit
// todo change type of pivots
//https://pytorch.org/docs/stable/generated/torch.lu.html

View File

@ -2,6 +2,8 @@ package space.kscience.kmath.tensors
import space.kscience.kmath.structures.array
import kotlin.test.Test
import kotlin.test.assertFails
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
class TestRealTensorAlgebra {
@ -48,20 +50,48 @@ class TestRealTensorAlgebra {
}
@Test
fun transposeAssign1x2() = RealTensorAlgebra {
val tensor = RealTensor(intArrayOf(1,2), doubleArrayOf(1.0, 2.0))
tensor.transposeAssign(0, 1)
fun broadcastShapes() = RealTensorAlgebra {
assertTrue(this.broadcastShapes(
intArrayOf(2, 3), intArrayOf(1, 3), intArrayOf(1, 1, 1)
) contentEquals intArrayOf(1, 2, 3))
assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 2.0))
assertTrue(tensor.shape contentEquals intArrayOf(2, 1))
assertTrue(this.broadcastShapes(
intArrayOf(6, 7), intArrayOf(5, 6, 1), intArrayOf(7,), intArrayOf(5, 1, 7)
) contentEquals intArrayOf(5, 6, 7))
}
@Test
fun transposeAssign2x3() = RealTensorAlgebra {
val tensor = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
tensor.transposeAssign(1, 0)
fun broadcastTensors() = RealTensorAlgebra {
val tensor1 = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = RealTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
assertTrue(tensor.buffer.array contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(tensor.shape contentEquals intArrayOf(3, 2))
val res = this.broadcastTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[0].buffer.array contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].buffer.array contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
assertTrue(res[2].buffer.array contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
}
@Test
fun minusTensor() = RealTensorAlgebra {
val tensor1 = RealTensor(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = RealTensor(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
assertTrue((tensor2 - tensor1).buffer.array contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
assertTrue((tensor3 - tensor1).buffer.array
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
assertTrue((tensor3 - tensor2).buffer.array contentEquals doubleArrayOf(490.0, 480.0, 470.0))
}
}