diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt index 1edbed28d..93d8d1143 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt @@ -191,7 +191,7 @@ internal class AsmBuilder( } val cls = classLoader.defineClass(className, classWriter.toByteArray()) - java.io.File("dump.class").writeBytes(classWriter.toByteArray()) + // java.io.File("dump.class").writeBytes(classWriter.toByteArray()) val l = MethodHandles.publicLookup() if (hasConstants) diff --git a/kmath-gsl/src/nativeTest/kotlin/ErrorHandler.kt b/kmath-gsl/src/nativeTest/kotlin/ErrorHandler.kt index 1259f58b9..0d8fbfbea 100644 --- a/kmath-gsl/src/nativeTest/kotlin/ErrorHandler.kt +++ b/kmath-gsl/src/nativeTest/kotlin/ErrorHandler.kt @@ -1,6 +1,5 @@ package kscience.kmath.gsl -import kotlinx.cinterop.memScoped import org.gnu.gsl.gsl_block_calloc import kotlin.test.Test import kotlin.test.assertFailsWith @@ -17,7 +16,7 @@ internal class ErrorHandler { @Test fun matrixAllocation() { assertFailsWith { - memScoped { GslRealMatrixContext(this).produce(Int.MAX_VALUE, Int.MAX_VALUE) { _, _ -> 0.0 } } + GslRealMatrixContext { produce(Int.MAX_VALUE, Int.MAX_VALUE) { _, _ -> 0.0 } } } } } diff --git a/kmath-gsl/src/nativeTest/kotlin/RealTest.kt b/kmath-gsl/src/nativeTest/kotlin/RealTest.kt index e0a7eadb7..01d9ef90d 100644 --- a/kmath-gsl/src/nativeTest/kotlin/RealTest.kt +++ b/kmath-gsl/src/nativeTest/kotlin/RealTest.kt @@ -1,46 +1,71 @@ package kscience.kmath.gsl -import kotlinx.cinterop.memScoped import kscience.kmath.linear.RealMatrixContext import kscience.kmath.operations.invoke import kscience.kmath.structures.* import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertTrue +import kotlin.time.measureTime internal class RealTest { @Test - fun testScale() = memScoped { - (GslRealMatrixContext(this)) { - val ma = produce(10, 10) { _, _ -> 0.1 } - val mb = (ma * 20.0) - assertEquals(mb[0, 1], 2.0) - } + fun testScale() = GslRealMatrixContext { + val ma = produce(10, 10) { _, _ -> 0.1 } + val mb = (ma * 20.0) + assertEquals(mb[0, 1], 2.0) } @Test - fun testDotOfMatrixAndVector() = memScoped { - (GslRealMatrixContext(this)) { - val ma = produce(2, 2) { _, _ -> 100.0 } - val vb = RealBuffer(2) { 0.1 } - val res1 = ma dot vb - val res2 = RealMatrixContext { ma dot vb } - println(res1.asSequence().toList()) - println(res2.asSequence().toList()) - assertTrue(res1.contentEquals(res2)) - } + fun testDotOfMatrixAndVector() = GslRealMatrixContext { + val ma = produce(2, 2) { _, _ -> 100.0 } + val vb = RealBuffer(2) { 0.1 } + val res1 = ma dot vb + val res2 = RealMatrixContext { ma dot vb } + println(res1.asSequence().toList()) + println(res2.asSequence().toList()) + assertTrue(res1.contentEquals(res2)) } @Test - fun testDotOfMatrixAndMatrix() = memScoped { - (GslRealMatrixContext(this)) { - val ma = produce(2, 2) { _, _ -> 100.0 } - val mb = produce(2, 2) { _, _ -> 100.0 } - val res1: Matrix = ma dot mb - val res2: Matrix = RealMatrixContext { ma dot mb } - println(res1.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList)) - println(res2.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList)) - assertEquals(res1, res2) + fun testDotOfMatrixAndMatrix() = GslRealMatrixContext { + val ma = produce(2, 2) { _, _ -> 100.0 } + val mb = produce(2, 2) { _, _ -> 100.0 } + val res1: Matrix = ma dot mb + val res2: Matrix = RealMatrixContext { ma dot mb } + println(res1.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList)) + println(res2.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList)) + assertEquals(res1, res2) + } + + @Test + fun testManyCalls() { + val r1 = GslRealMatrixContext { + var prod = produce(20, 20) { _, _ -> 100.0 } + val mult = produce(20, 20) { _, _ -> 3.0 } + + measureTime { + repeat(100) { + prod = prod dot mult + } + }.also(::println) + + prod } + + val r2 = RealMatrixContext { + var prod = produce(20, 20) { _, _ -> 100.0 } + val mult = produce(20, 20) { _, _ -> 3.0 } + + measureTime { + repeat(100) { + prod = prod dot mult + } + }.also(::println) + + prod + } + + assertTrue(NDStructure.equals(r1, r2)) } }