Benchmarks for tensor matrix multiplication over Double
This commit is contained in:
parent
0cb2c3f0da
commit
7894799e8e
@ -163,6 +163,5 @@ TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
|
||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
auto lhs_tensor = ctorch::cast(lhs);
|
||||
lhs_tensor = lhs_tensor.matmul(ctorch::cast(rhs));
|
||||
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
package kscience.kmath.torch
|
||||
|
||||
import kotlin.test.Test
|
||||
import kotlin.time.measureTime
|
||||
|
||||
internal fun benchmarkingDoubleMatrixMultiplication(
|
||||
scale: Int,
|
||||
numIter: Int,
|
||||
device: TorchDevice = TorchDevice.TorchCPU
|
||||
): Unit {
|
||||
TorchTensorRealAlgebra {
|
||||
println("Benchmarking $scale x $scale matrices over Double's: ")
|
||||
setSeed(SEED)
|
||||
val lhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||
val rhs = randNormal(shape = intArrayOf(scale, scale), device = device)
|
||||
lhs dotAssign rhs
|
||||
val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } }
|
||||
println(" ${measuredTime / numIter} p.o. with $numIter iterations")
|
||||
}
|
||||
}
|
||||
|
||||
class BenchmarksDouble {
|
||||
|
||||
@Test
|
||||
fun benchmarkMatrixMultiplication20() = benchmarkingDoubleMatrixMultiplication(20, 100000)
|
||||
@Test
|
||||
fun benchmarkMatrixMultiplication200() = benchmarkingDoubleMatrixMultiplication(200, 10000)
|
||||
@Test
|
||||
fun benchmarkMatrixMultiplication2000() = benchmarkingDoubleMatrixMultiplication(2000, 10)
|
||||
|
||||
|
||||
}
|
@ -33,9 +33,12 @@ internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.Torch
|
||||
lhs dot rhs
|
||||
}
|
||||
|
||||
lhsTensor dotAssign rhsTensor
|
||||
|
||||
var error: Double = 0.0
|
||||
product.elements().forEach {
|
||||
error += abs(expected[it.first] - it.second)
|
||||
error += abs(expected[it.first] - it.second) +
|
||||
abs(expected[it.first] - lhsTensor[it.first])
|
||||
}
|
||||
assertTrue(error < TOLERANCE)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user