add broadcast of all dims except the last 2, add tensors dot, fix bug in function times
This commit is contained in:
parent
0365d41f31
commit
2d2c4bd474
@ -85,7 +85,6 @@ public inline fun <R> BroadcastDoubleTensorAlgebra(block: BroadcastDoubleTensorA
|
|||||||
|
|
||||||
|
|
||||||
internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
internal inline fun broadcastShapes(vararg shapes: IntArray): IntArray {
|
||||||
println(shapes)
|
|
||||||
var totalDim = 0
|
var totalDim = 0
|
||||||
for (shape in shapes) {
|
for (shape in shapes) {
|
||||||
totalDim = max(totalDim, shape.size)
|
totalDim = max(totalDim, shape.size)
|
||||||
@ -181,3 +180,61 @@ internal inline fun broadcastTensors(vararg tensors: DoubleTensor): List<DoubleT
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal inline fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTensor> {
|
||||||
|
var onlyTwoDims = true
|
||||||
|
for (tensor in tensors) {
|
||||||
|
if (tensor.shape.size < 2) {
|
||||||
|
throw RuntimeException("Tensors must have at least 2 dimensions")
|
||||||
|
}
|
||||||
|
if (tensor.shape.size != 2) {
|
||||||
|
onlyTwoDims = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (onlyTwoDims) {
|
||||||
|
return tensors.asList()
|
||||||
|
}
|
||||||
|
|
||||||
|
val totalShape = broadcastShapes(*(tensors.map { it.shape.sliceArray(0..it.shape.size - 3) }).toTypedArray())
|
||||||
|
val n = totalShape.reduce { acc, i -> acc * i }
|
||||||
|
|
||||||
|
val res = ArrayList<DoubleTensor>(0)
|
||||||
|
for (tensor in tensors) {
|
||||||
|
val matrixShape = tensor.shape.sliceArray(tensor.shape.size - 2 until tensor.shape.size).copyOf()
|
||||||
|
val matrixSize = matrixShape[0] * matrixShape[1]
|
||||||
|
val matrix = DoubleTensor(matrixShape, DoubleArray(matrixSize))
|
||||||
|
|
||||||
|
val outerTensor = DoubleTensor(totalShape, DoubleArray(n))
|
||||||
|
val resTensor = DoubleTensor(totalShape + matrixShape, DoubleArray(n * matrixSize))
|
||||||
|
|
||||||
|
for (linearIndex in 0 until n) {
|
||||||
|
val totalMultiIndex = outerTensor.linearStructure.index(linearIndex)
|
||||||
|
var curMultiIndex = tensor.shape.sliceArray(0..tensor.shape.size - 3).copyOf()
|
||||||
|
curMultiIndex = IntArray(totalMultiIndex.size - curMultiIndex.size) {1} + curMultiIndex
|
||||||
|
|
||||||
|
val newTensor = DoubleTensor(curMultiIndex + matrixShape, tensor.buffer.array())
|
||||||
|
|
||||||
|
for (i in curMultiIndex.indices) {
|
||||||
|
if (curMultiIndex[i] != 1) {
|
||||||
|
curMultiIndex[i] = totalMultiIndex[i]
|
||||||
|
} else {
|
||||||
|
curMultiIndex[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i in 0 until matrixSize) {
|
||||||
|
val curLinearIndex = newTensor.linearStructure.offset(curMultiIndex +
|
||||||
|
matrix.linearStructure.index(i))
|
||||||
|
val newLinearIndex = resTensor.linearStructure.offset(totalMultiIndex +
|
||||||
|
matrix.linearStructure.index(i))
|
||||||
|
|
||||||
|
resTensor.buffer.array()[resTensor.bufferStart + newLinearIndex] =
|
||||||
|
newTensor.buffer.array()[newTensor.bufferStart + curLinearIndex]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
res.add(resTensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
@ -132,7 +132,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor {
|
||||||
checkShapesCompatible(this, other)
|
checkShapesCompatible(this, other)
|
||||||
val resBuffer = DoubleArray(this.linearStructure.size) { i ->
|
val resBuffer = DoubleArray(this.linearStructure.size) { i ->
|
||||||
this.buffer.array()[other.bufferStart + i] *
|
this.buffer.array()[this.bufferStart + i] *
|
||||||
other.buffer.array()[other.bufferStart + i]
|
other.buffer.array()[other.bufferStart + i]
|
||||||
}
|
}
|
||||||
return DoubleTensor(this.shape, resBuffer)
|
return DoubleTensor(this.shape, resBuffer)
|
||||||
@ -241,8 +241,84 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun DoubleTensor.dotTwoDimensionalTensors(other: DoubleTensor): DoubleTensor {
|
||||||
|
if (this.shape.size > 2 || other.shape.size > 2) {
|
||||||
|
throw RuntimeException("Both tensors must have a maximum of 2 dimensions")
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.shape[1] != other.shape[0]) {
|
||||||
|
throw RuntimeException("Tensors dot operation dimension mismatch: " +
|
||||||
|
"(${this.shape[0]}, ${this.shape[1]}) x (${other.shape[0]}, ${other.shape[1]})")
|
||||||
|
}
|
||||||
|
|
||||||
|
val l = this.shape[0]
|
||||||
|
val m = this.shape[1]
|
||||||
|
val n = other.shape[1]
|
||||||
|
|
||||||
|
val res = DoubleTensor(intArrayOf(l, n), DoubleArray(l * n))
|
||||||
|
|
||||||
|
for (i in 0 until l) {
|
||||||
|
for (j in 0 until n) {
|
||||||
|
var curr = 0.0
|
||||||
|
for (k in 0 until m) {
|
||||||
|
val ik = this.linearStructure.offset(intArrayOf(i, k))
|
||||||
|
val kj = other.linearStructure.offset(intArrayOf(k, j))
|
||||||
|
curr += this.buffer.array()[ik] * other.buffer.array()[kj]
|
||||||
|
}
|
||||||
|
val linearIndex = res.linearStructure.offset(intArrayOf(i, j))
|
||||||
|
res.buffer.array()[linearIndex] = curr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
|
override fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor {
|
||||||
TODO("Alya")
|
if (this.shape.size == 1 && other.shape.size == 1) {
|
||||||
|
return DoubleTensor(intArrayOf(1), doubleArrayOf(this.times(other).buffer.array().sum()))
|
||||||
|
}
|
||||||
|
|
||||||
|
var newThis = this.copy()
|
||||||
|
var newOther = other.copy()
|
||||||
|
if (this.shape.size == 1) {
|
||||||
|
newThis = this.view(intArrayOf(1) + this.shape)
|
||||||
|
}
|
||||||
|
if (other.shape.size == 1) {
|
||||||
|
newOther = other.view(other.shape + intArrayOf(1) )
|
||||||
|
}
|
||||||
|
|
||||||
|
val broadcastTensors = broadcastOuterTensors(newThis, newOther)
|
||||||
|
newThis = broadcastTensors[0]
|
||||||
|
newOther = broadcastTensors[1]
|
||||||
|
|
||||||
|
val l = newThis.shape[newThis.shape.size - 2]
|
||||||
|
val m1= newThis.shape[newThis.shape.size - 1]
|
||||||
|
val m2 = newOther.shape[newOther.shape.size - 2]
|
||||||
|
val n = newOther.shape[newOther.shape.size - 1]
|
||||||
|
if (m1 != m2) {
|
||||||
|
throw RuntimeException("Tensors dot operation dimension mismatch: ($l, $m1) x ($m2, $n)")
|
||||||
|
}
|
||||||
|
val m = m1
|
||||||
|
|
||||||
|
var resShape = newThis.shape.sliceArray(0..(newThis.shape.size - 2)) + intArrayOf(newOther.shape.last())
|
||||||
|
val resSize = resShape.reduce { acc, i -> acc * i }
|
||||||
|
val resTensor = DoubleTensor(resShape, DoubleArray(resSize))
|
||||||
|
|
||||||
|
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
|
||||||
|
val a = ab.first
|
||||||
|
val b = ab.second
|
||||||
|
|
||||||
|
for (i in 0 until l) {
|
||||||
|
for (j in 0 until n) {
|
||||||
|
var curr = 0.0
|
||||||
|
for (k in 0 until m) {
|
||||||
|
curr += a[i, k] * b[k, j]
|
||||||
|
}
|
||||||
|
res[i, j] = curr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
|
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
|
||||||
|
@ -47,6 +47,36 @@ class TestBroadcasting {
|
|||||||
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun broadcastOuterTensors() = DoubleTensorAlgebra {
|
||||||
|
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val tensor2 = fromArray(intArrayOf(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
|
val tensor3 = fromArray(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||||
|
|
||||||
|
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
|
||||||
|
|
||||||
|
assertTrue(res[0].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
|
assertTrue(res[1].shape contentEquals intArrayOf(1, 1, 3))
|
||||||
|
assertTrue(res[2].shape contentEquals intArrayOf(1, 1, 1))
|
||||||
|
|
||||||
|
assertTrue(res[0].buffer.array() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
assertTrue(res[1].buffer.array() contentEquals doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
|
assertTrue(res[2].buffer.array() contentEquals doubleArrayOf(500.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun broadcastOuterTensorsShapes() = DoubleTensorAlgebra {
|
||||||
|
val tensor1 = fromArray(intArrayOf(2, 1, 3, 2, 3), DoubleArray(2 * 1 * 3 * 2 * 3) {0.0})
|
||||||
|
val tensor2 = fromArray(intArrayOf(4, 2, 5, 1, 3, 3), DoubleArray(4 * 2 * 5 * 1 * 3 * 3) {0.0})
|
||||||
|
val tensor3 = fromArray(intArrayOf(1, 1), doubleArrayOf(500.0))
|
||||||
|
|
||||||
|
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
|
||||||
|
|
||||||
|
assertTrue(res[0].shape contentEquals intArrayOf(4, 2, 5, 3, 2, 3))
|
||||||
|
assertTrue(res[1].shape contentEquals intArrayOf(4, 2, 5, 3, 3, 3))
|
||||||
|
assertTrue(res[2].shape contentEquals intArrayOf(4, 2, 5, 3, 1, 1))
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun minusTensor() = BroadcastDoubleTensorAlgebra {
|
fun minusTensor() = BroadcastDoubleTensorAlgebra {
|
||||||
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
@ -79,6 +79,20 @@ class TestDoubleTensorAlgebra {
|
|||||||
assertTrue(expected.buffer.array() contentEquals assignResult.buffer.array())
|
assertTrue(expected.buffer.array() contentEquals assignResult.buffer.array())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun dot() = DoubleTensorAlgebra {
|
||||||
|
val tensor1 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val tensor2 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
|
val res12 = tensor1.dot(tensor2)
|
||||||
|
|
||||||
|
assertTrue(res12.buffer.array() contentEquals doubleArrayOf(140.0, 320.0))
|
||||||
|
assertTrue(res12.shape contentEquals intArrayOf(2, 1))
|
||||||
|
|
||||||
|
val tensor4 = fromArray(intArrayOf(10, 3, 4), DoubleArray(10 * 3 * 4) {0.0})
|
||||||
|
val tensor5 = fromArray(intArrayOf(10, 4, 5), DoubleArray(10 * 4 * 5) {0.0})
|
||||||
|
assertTrue(tensor4.dot(tensor5).shape contentEquals intArrayOf(10, 3, 5))
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testContentEqual() = DoubleTensorAlgebra {
|
fun testContentEqual() = DoubleTensorAlgebra {
|
||||||
//TODO()
|
//TODO()
|
||||||
|
Loading…
Reference in New Issue
Block a user