Merge remote-tracking branch 'ups/feature/tensor-algebra' into andrew

This commit is contained in:
Andrei Kislitsyn 2021-03-24 18:08:53 +03:00
commit 96a755071f
4 changed files with 179 additions and 4 deletions

View File

@ -45,7 +45,7 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
val newThis = broadcast[0]
val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.size) { i ->
newThis.buffer.array()[newOther.bufferStart + i] *
newThis.buffer.array()[newThis.bufferStart + i] *
newOther.buffer.array()[newOther.bufferStart + i]
}
return DoubleTensor(newThis.shape, resBuffer)
@ -85,7 +85,6 @@ public inline fun <R> BroadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorA
internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
println(shapes)
var totalDim = 0
for (shape in shapes) {
totalDim = max(totalDim, shape.size)
@ -181,3 +180,57 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
return res
}
internal inline fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
val onlyTwoDims = tensors.asSequence().onEach {
require(it.shape.size >= 2) {
throw RuntimeException("Tensors must have at least 2 dimensions")
}
}.any { it.shape.size != 2 }
if (!onlyTwoDims) {
return tensors.asList()
}
val totalShape = broadcastShapes(*(tensors.map { it.shape.sliceArray(0..it.shape.size - 3) }).toTypedArray())
val n = totalShape.reduce { acc, i -> acc * i }
val res = ArrayList<DoubleTensor>(0)
for (tensor in tensors) {
val matrixShape = tensor.shape.sliceArray(tensor.shape.size - 2 until tensor.shape.size).copyOf()
val matrixSize = matrixShape[0] * matrixShape[1]
val matrix = DoubleTensor(matrixShape, DoubleArray(matrixSize))
val outerTensor = DoubleTensor(totalShape, DoubleArray(n))
val resTensor = DoubleTensor(totalShape + matrixShape, DoubleArray(n * matrixSize))
for (linearIndex in 0 until n) {
val totalMultiIndex = outerTensor.linearStructure.index(linearIndex)
var curMultiIndex = tensor.shape.sliceArray(0..tensor.shape.size - 3).copyOf()
curMultiIndex = IntArray(totalMultiIndex.size - curMultiIndex.size) {1} + curMultiIndex
val newTensor = DoubleTensor(curMultiIndex + matrixShape, tensor.buffer.array())
for (i in curMultiIndex.indices) {
if (curMultiIndex[i] != 1) {
curMultiIndex[i] = totalMultiIndex[i]
} else {
curMultiIndex[i] = 0
}
}
for (i in 0 until matrixSize) {
val curLinearIndex = newTensor.linearStructure.offset(curMultiIndex +
matrix.linearStructure.index(i))
val newLinearIndex = resTensor.linearStructure.offset(totalMultiIndex +
matrix.linearStructure.index(i))
resTensor.buffer.array()[resTensor.bufferStart + newLinearIndex] =
newTensor.buffer.array()[newTensor.bufferStart + curLinearIndex]
}
}
res += resTensor
}
return res
}

View File

@ -135,7 +135,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
checkShapesCompatible(this, other)
val resBuffer = DoubleArray(this.linearStructure.size) { i ->
this.buffer.array()[other.bufferStart + i] *
this.buffer.array()[this.bufferStart + i] *
other.buffer.array()[other.bufferStart + i]
}
return DoubleTensor(this.shape, resBuffer)
@ -245,7 +245,63 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
}
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
TODO("Alya")
if (this.shape.size == 1 && other.shape.size == 1) {
return DoubleTensor(intArrayOf(1), doubleArrayOf(this.times(other).buffer.array().sum()))
}
var newThis = this.copy()
var newOther = other.copy()
var penultimateDim = false
var lastDim = false
if (this.shape.size == 1) {
penultimateDim = true
newThis = this.view(intArrayOf(1) + this.shape)
}
if (other.shape.size == 1) {
lastDim = true
newOther = other.view(other.shape + intArrayOf(1) )
}
val broadcastTensors = broadcastOuterTensors(newThis, newOther)
newThis = broadcastTensors[0]
newOther = broadcastTensors[1]
val l = newThis.shape[newThis.shape.size - 2]
val m1= newThis.shape[newThis.shape.size - 1]
val m2 = newOther.shape[newOther.shape.size - 2]
val n = newOther.shape[newOther.shape.size - 1]
if (m1 != m2) {
throw RuntimeException("Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)")
}
val m = m1
val resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
val resSize = resShape.reduce { acc, i -> acc * i }
val resTensor = DoubleTensor(resShape, DoubleArray(resSize))
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
val (a, b) = ab
for (i in 0 until l) {
for (j in 0 until n) {
var curr = 0.0
for (k in 0 until m) {
curr += a[i, k] * b[k, j]
}
res[i, j] = curr
}
}
}
if (penultimateDim) {
return resTensor.view(resTensor.shape.dropLast(2).toIntArray() +
intArrayOf(resTensor.shape.last()))
}
if (lastDim) {
return resTensor.view(resTensor.shape.dropLast(1).toIntArray())
}
return resTensor
}
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {

View File

@ -47,6 +47,36 @@ class TestBroadcasting {
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
}
@Test
fun broadcastOuterTensors() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[1].shape contentEquals intArrayOf(1, 1, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 1, 1))
assertTrue(res[0].buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0))
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0))
}
@Test
fun broadcastOuterTensorsShapes() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 1, 3, 2, 3), DoubleArray(2 * 1 * 3 * 2 * 3) {0.0})
val tensor2 = fromArray(intArrayOf(4, 2, 5, 1, 3, 3), DoubleArray(4 * 2 * 5 * 1 * 3 * 3) {0.0})
val tensor3 = fromArray(intArrayOf(1, 1), doubleArrayOf(500.0))
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
assertTrue(res[0].shape contentEquals intArrayOf(4, 2, 5, 3, 2, 3))
assertTrue(res[1].shape contentEquals intArrayOf(4, 2, 5, 3, 3, 3))
assertTrue(res[2].shape contentEquals intArrayOf(4, 2, 5, 3, 1, 1))
}
@Test
fun minusTensor() = BroadcastDoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))

View File

@ -79,6 +79,42 @@ class TestDoubleTensorAlgebra {
assertTrue(expected.buffer.array() contentEquals assignResult.buffer.array())
}
@Test
fun dot() = DoubleTensorAlgebra {
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor11 = fromArray(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = fromArray(intArrayOf(1, 1, 3), doubleArrayOf(-1.0, -2.0, -3.0))
val res12 = tensor1.dot(tensor2)
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(140.0, 320.0))
assertTrue(res12.shape contentEquals intArrayOf(2))
val res32 = tensor3.dot(tensor2)
assertTrue(res32.buffer.array() contentEquals doubleArrayOf(-140.0))
assertTrue(res32.shape contentEquals intArrayOf(1, 1))
val res22 = tensor2.dot(tensor2)
assertTrue(res22.buffer.array() contentEquals doubleArrayOf(1400.0))
assertTrue(res22.shape contentEquals intArrayOf(1))
val res11 = tensor1.dot(tensor11)
assertTrue(res11.buffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
assertTrue(res11.shape contentEquals intArrayOf(2, 2))
var tensor4 = fromArray(intArrayOf(10, 3, 4), DoubleArray(10 * 3 * 4) {0.0})
var tensor5 = fromArray(intArrayOf(10, 4, 5), DoubleArray(10 * 4 * 5) {0.0})
assertTrue(tensor4.dot(tensor5).shape contentEquals intArrayOf(10, 3, 5))
tensor4 = fromArray(intArrayOf(10, 3, 4), DoubleArray(10 * 3 * 4) {0.0})
tensor5 = fromArray(intArrayOf(4, 5), DoubleArray(4 * 5) {0.0})
assertTrue(tensor4.dot(tensor5).shape contentEquals intArrayOf(10, 3, 5))
tensor4 = fromArray(intArrayOf(4, 2, 1, 3, 8, 1), DoubleArray(4 * 2 * 1 * 3 * 8 * 1) {0.0})
tensor5 = fromArray(intArrayOf(5, 1, 2, 8, 3, 1, 5), DoubleArray(5 * 1 * 2 * 8 * 3 * 1 * 5) {0.0})
assertTrue(tensor4.dot(tensor5).shape contentEquals intArrayOf(5, 4, 2, 8, 3, 8, 5))
}
@Test
fun testContentEqual() = DoubleTensorAlgebra {
//TODO()