forked from kscience/kmath
Implement much faster dot product algorithm for tensors
This commit is contained in:
parent
8974164ec0
commit
a78e361b17
@ -15,7 +15,9 @@ import space.kscience.kmath.linear.invoke
|
||||
import space.kscience.kmath.linear.linearSpace
|
||||
import space.kscience.kmath.multik.multikAlgebra
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
import kotlin.random.Random
|
||||
|
||||
@State(Scope.Benchmark)
|
||||
@ -32,6 +34,9 @@ internal class DotBenchmark {
|
||||
random.nextDouble()
|
||||
}
|
||||
|
||||
val tensor1 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12224)
|
||||
val tensor2 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12225)
|
||||
|
||||
val cmMatrix1 = CMLinearSpace { matrix1.toCM() }
|
||||
val cmMatrix2 = CMLinearSpace { matrix2.toCM() }
|
||||
|
||||
@ -78,4 +83,9 @@ internal class DotBenchmark {
|
||||
fun doubleDot(blackhole: Blackhole) = with(DoubleField.linearSpace) {
|
||||
blackhole.consume(matrix1 dot matrix2)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun doubleTensorDot(blackhole: Blackhole) = DoubleTensorAlgebra.invoke {
|
||||
blackhole.consume(tensor1 dot tensor2)
|
||||
}
|
||||
}
|
||||
|
@ -421,7 +421,7 @@ public open class DoubleTensorAlgebra :
|
||||
|
||||
for ((res, ab) in resTensor.matrixSequence().zip(newThis.matrixSequence().zip(newOther.matrixSequence()))) {
|
||||
val (a, b) = ab
|
||||
dotTo(a.as2D(), b.as2D(), res.as2D(), l, m1, n)
|
||||
dotTo(a, b, res, l, m1, n)
|
||||
}
|
||||
|
||||
return if (penultimateDim) {
|
||||
|
@ -54,18 +54,26 @@ internal val <T> BufferedTensor<T>.matrices: VirtualBuffer<BufferedTensor<T>>
|
||||
internal fun <T> BufferedTensor<T>.matrixSequence(): Sequence<BufferedTensor<T>> = matrices.asSequence()
|
||||
|
||||
internal fun dotTo(
|
||||
a: MutableStructure2D<Double>,
|
||||
b: MutableStructure2D<Double>,
|
||||
res: MutableStructure2D<Double>,
|
||||
a: BufferedTensor<Double>,
|
||||
b: BufferedTensor<Double>,
|
||||
res: BufferedTensor<Double>,
|
||||
l: Int, m: Int, n: Int,
|
||||
) {
|
||||
val aStart = a.bufferStart
|
||||
val bStart = b.bufferStart
|
||||
val resStart = res.bufferStart
|
||||
|
||||
val aBuffer = a.mutableBuffer
|
||||
val bBuffer = b.mutableBuffer
|
||||
val resBuffer = res.mutableBuffer
|
||||
|
||||
for (i in 0 until l) {
|
||||
for (j in 0 until n) {
|
||||
var curr = 0.0
|
||||
for (k in 0 until m) {
|
||||
curr += a[i, k] * b[k, j]
|
||||
curr += aBuffer[aStart + i * m + k] * bBuffer[bStart + k * n + j]
|
||||
}
|
||||
res[i, j] = curr
|
||||
resBuffer[resStart + i * n + j] = curr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -107,6 +107,8 @@ internal class TestDoubleTensorAlgebra {
|
||||
val tensor11 = fromArray(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor2 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor3 = fromArray(intArrayOf(1, 1, 3), doubleArrayOf(-1.0, -2.0, -3.0))
|
||||
val tensor4 = fromArray(intArrayOf(2, 3, 3), (1..18).map { it.toDouble() }.toDoubleArray())
|
||||
val tensor5 = fromArray(intArrayOf(2, 3, 3), (1..18).map { 1 + it.toDouble() }.toDoubleArray())
|
||||
|
||||
val res12 = tensor1.dot(tensor2)
|
||||
assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(140.0, 320.0))
|
||||
@ -123,6 +125,13 @@ internal class TestDoubleTensorAlgebra {
|
||||
val res11 = tensor1.dot(tensor11)
|
||||
assertTrue(res11.mutableBuffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
|
||||
assertTrue(res11.shape contentEquals intArrayOf(2, 2))
|
||||
|
||||
val res45 = tensor4.dot(tensor5)
|
||||
assertTrue(res45.mutableBuffer.array() contentEquals doubleArrayOf(
|
||||
36.0, 42.0, 48.0, 81.0, 96.0, 111.0, 126.0, 150.0, 174.0,
|
||||
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
|
||||
))
|
||||
assertTrue(res45.shape contentEquals intArrayOf(2, 3, 3))
|
||||
}
|
||||
|
||||
@Test
|
||||
|
Loading…
Reference in New Issue
Block a user