Benchmarks for tensor matrix multiplication over Double

This commit is contained in:
rgrit91 2021-01-10 14:00:35 +00:00
parent 0cb2c3f0da
commit 7894799e8e
3 changed files with 37 additions and 3 deletions

View File

@ -163,6 +163,5 @@ TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
{ {
auto lhs_tensor = ctorch::cast(lhs); ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
lhs_tensor = lhs_tensor.matmul(ctorch::cast(rhs));
} }

View File

@ -0,0 +1,32 @@
package kscience.kmath.torch
import kotlin.test.Test
import kotlin.time.measureTime
internal fun benchmarkingDoubleMatrixMultiplication(
scale: Int,
numIter: Int,
device: TorchDevice = TorchDevice.TorchCPU
): Unit {
TorchTensorRealAlgebra {
println("Benchmarking $scale x $scale matrices over Double's: ")
setSeed(SEED)
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
lhs dotAssign rhs
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
}
}
class BenchmarksDouble {
@Test
fun benchmarkMatrixMultiplication20() = benchmarkingDoubleMatrixMultiplication(20, 100000)
@Test
fun benchmarkMatrixMultiplication200() = benchmarkingDoubleMatrixMultiplication(200, 10000)
@Test
fun benchmarkMatrixMultiplication2000() = benchmarkingDoubleMatrixMultiplication(2000, 10)
}

View File

@ -33,9 +33,12 @@ internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.Torch
lhs dot rhs lhs dot rhs
} }
lhsTensor dotAssign rhsTensor
var error: Double = 0.0 var error: Double = 0.0
product.elements().forEach { product.elements().forEach {
error += abs(expected[it.first] - it.second) error += abs(expected[it.first] - it.second) +
abs(expected[it.first] - lhsTensor[it.first])
} }
assertTrue(error < TOLERANCE) assertTrue(error < TOLERANCE)
} }