Implement much faster dot product algorithm for tensors

This commit is contained in:
Ivan Kylchik 2022-02-13 16:01:05 +03:00 committed by Iaroslav Postovalov
parent 8974164ec0
commit a78e361b17
4 changed files with 33 additions and 6 deletions

View File

@ -15,7 +15,9 @@ import space.kscience.kmath.linear.invoke
import space.kscience.kmath.linear.linearSpace
import space.kscience.kmath.multik.multikAlgebra
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import kotlin.random.Random
@State(Scope.Benchmark)
@ -32,6 +34,9 @@ internal class DotBenchmark {
random.nextDouble()
}
val tensor1 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12224)
val tensor2 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12225)
val cmMatrix1 = CMLinearSpace { matrix1.toCM() }
val cmMatrix2 = CMLinearSpace { matrix2.toCM() }
@ -78,4 +83,9 @@ internal class DotBenchmark {
fun doubleDot(blackhole: Blackhole) = with(DoubleField.linearSpace) {
blackhole.consume(matrix1 dot matrix2)
}
@Benchmark
fun doubleTensorDot(blackhole: Blackhole) = DoubleTensorAlgebra.invoke {
blackhole.consume(tensor1 dot tensor2)
}
}

View File

@ -421,7 +421,7 @@ public open class DoubleTensorAlgebra :
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
val (a, b) = ab
dotTo(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
dotTo(a, b, res, l, m1, n)
}
return if (penultimateDim) {

View File

@ -54,18 +54,26 @@ internal val <T> BufferedTensor<T>.matrices: VirtualBuffer<BufferedTensor<T>>
internal fun <T> BufferedTensor<T>.matrixSequence(): Sequence<BufferedTensor<T>> = matrices.asSequence()
internal fun dotTo(
a: MutableStructure2D<Double>,
b: MutableStructure2D<Double>,
res: MutableStructure2D<Double>,
a: BufferedTensor<Double>,
b: BufferedTensor<Double>,
res: BufferedTensor<Double>,
l: Int, m: Int, n: Int,
) {
val aStart = a.bufferStart
val bStart = b.bufferStart
val resStart = res.bufferStart
val aBuffer = a.mutableBuffer
val bBuffer = b.mutableBuffer
val resBuffer = res.mutableBuffer
for (i in 0 until l) {
for (j in 0 until n) {
var curr = 0.0
for (k in 0 until m) {
curr += a[i, k] * b[k, j]
curr += aBuffer[aStart + i * m + k] * bBuffer[bStart + k * n + j]
}
res[i, j] = curr
resBuffer[resStart + i * n + j] = curr
}
}
}

View File

@ -107,6 +107,8 @@ internal class TestDoubleTensorAlgebra {
val tensor11 = fromArray(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor2 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0))
val tensor3 = fromArray(intArrayOf(1, 1, 3), doubleArrayOf(-1.0, -2.0, -3.0))
val tensor4 = fromArray(intArrayOf(2, 3, 3), (1..18).map { it.toDouble() }.toDoubleArray())
val tensor5 = fromArray(intArrayOf(2, 3, 3), (1..18).map { 1 + it.toDouble() }.toDoubleArray())
val res12 = tensor1.dot(tensor2)
assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(140.0, 320.0))
@ -123,6 +125,13 @@ internal class TestDoubleTensorAlgebra {
val res11 = tensor1.dot(tensor11)
assertTrue(res11.mutableBuffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
assertTrue(res11.shape contentEquals intArrayOf(2, 2))
val res45 = tensor4.dot(tensor5)
assertTrue(res45.mutableBuffer.array() contentEquals doubleArrayOf(
36.0, 42.0, 48.0, 81.0, 96.0, 111.0, 126.0, 150.0, 174.0,
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
))
assertTrue(res45.shape contentEquals intArrayOf(2, 3, 3))
}
@Test