diff --git a/kmath-torch/ctorch/src/ctorch.cc b/kmath-torch/ctorch/src/ctorch.cc index dec64c354..f5d1ce6ee 100644 --- a/kmath-torch/ctorch/src/ctorch.cc +++ b/kmath-torch/ctorch/src/ctorch.cc @@ -163,6 +163,5 @@ TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs) void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs) { - auto lhs_tensor = ctorch::cast(lhs); - lhs_tensor = lhs_tensor.matmul(ctorch::cast(rhs)); + ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs)); } \ No newline at end of file diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/BenchmarksDouble.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/BenchmarksDouble.kt new file mode 100644 index 000000000..8024b3eb3 --- /dev/null +++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/BenchmarksDouble.kt @@ -0,0 +1,32 @@ +package kscience.kmath.torch + +import kotlin.test.Test +import kotlin.time.measureTime + +internal fun benchmarkingDoubleMatrixMultiplication( + scale: Int, + numIter: Int, + device: TorchDevice = TorchDevice.TorchCPU +): Unit { + TorchTensorRealAlgebra { + println("Benchmarking $scale x $scale matrices over Double's: ") + setSeed(SEED) + val lhs = randNormal(shape = intArrayOf(scale, scale), device = device) + val rhs = randNormal(shape = intArrayOf(scale, scale), device = device) + lhs dotAssign rhs + val measuredTime = measureTime { repeat(numIter) { lhs dotAssign rhs } } + println(" ${measuredTime / numIter} p.o. with $numIter iterations") + } +} + +class BenchmarksDouble { + + @Test + fun benchmarkMatrixMultiplication20() = benchmarkingDoubleMatrixMultiplication(20, 100000) + @Test + fun benchmarkMatrixMultiplication200() = benchmarkingDoubleMatrixMultiplication(200, 10000) + @Test + fun benchmarkMatrixMultiplication2000() = benchmarkingDoubleMatrixMultiplication(2000, 10) + + +} \ No newline at end of file diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt index b49c6f53e..c0518c056 100644 --- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt +++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestTorchTensorAlgebra.kt @@ -33,9 +33,12 @@ internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.Torch lhs dot rhs } + lhsTensor dotAssign rhsTensor + var error: Double = 0.0 product.elements().forEach { - error += abs(expected[it.first] - it.second) + error += abs(expected[it.first] - it.second) + + abs(expected[it.first] - lhsTensor[it.first]) } assertTrue(error < TOLERANCE) }