diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index bb456e6eb..af39d9091 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -47,6 +47,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< expectedArity = 1 ) } + is MST.Binary -> { loadAlgebra() if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index ebe25e19f..89c9dca9d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -354,31 +354,23 @@ internal class AsmBuilder internal constructor( * Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { - load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) + load(invokeArgumentsVar, MAP_TYPE) + aconst(name) - if (defaultValue != null) { - loadStringConstant(name) + if (defaultValue != null) loadTConstant(defaultValue) + else + aconst(null) - invokeinterface( - MAP_TYPE.internalName, - "getOrDefault", - Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) - ) - - invokeMethodVisitor.checkcast(tType) - return - } - - loadStringConstant(name) - - invokeinterface( - MAP_TYPE.internalName, - "get", - Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE) + invokestatic( + MAP_INTRINSICS_TYPE.internalName, + "getOrFail", + Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE), + false ) - invokeMethodVisitor.checkcast(tType) + checkcast(tType) + val expectedType = expectationStack.pop()!! if (expectedType.sort == Type.OBJECT) @@ -517,5 +509,10 @@ internal class AsmBuilder internal constructor( * ASM type for [java.lang.String]. */ internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } + + /** + * ASM type for MapIntrinsics. + */ + internal val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("scientifik/kmath/asm/internal/MapIntrinsics") } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt new file mode 100644 index 000000000..7f7126b55 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt @@ -0,0 +1,7 @@ +@file:JvmName("MapIntrinsics") + +package scientifik.kmath.asm.internal + +internal fun Map.getOrFail(key: K, default: V?): V { + return this[key] ?: default ?: error("Parameter not found: $key") +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index e54acf6f9..a8d5a605f 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -4,8 +4,15 @@ import org.objectweb.asm.Opcodes import org.objectweb.asm.Type import scientifik.kmath.operations.Algebra -private val methodNameAdapters: Map by lazy { - hashMapOf("+" to "add", "*" to "multiply", "/" to "divide") +private val methodNameAdapters: Map, String> by lazy { + hashMapOf( + "+" to 2 to "add", + "*" to 2 to "multiply", + "/" to 2 to "divide", + "+" to 1 to "unaryPlus", + "-" to 1 to "unaryMinus", + "-" to 2 to "minus" + ) } /** @@ -15,7 +22,7 @@ private val methodNameAdapters: Map by lazy { * @return `true` if contains, else `false`. */ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { - val theName = methodNameAdapters[name] ?: name + val theName = methodNameAdapters[name to arity] ?: name val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null val t = if (primitiveMode && hasSpecific) primitiveMask else tType repeat(arity) { expectationStack.push(t) } @@ -29,7 +36,7 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: * @return `true` if contains, else `false`. */ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { - val theName = methodNameAdapters[name] ?: name + val theName = methodNameAdapters[name to arity] ?: name context.javaClass.methods.find { var suitableSignature = it.name == theName && it.parameters.size == arity diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 079a6be0f..3acc6eb28 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -10,7 +10,7 @@ import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals -class TestAsmAlgebras { +internal class TestAsmAlgebras { @Test fun space() { val res1 = ByteRing.mstInSpace { diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt index 824201aa7..36c254c38 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -8,7 +8,7 @@ import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals -class TestAsmExpressions { +internal class TestAsmExpressions { @Test fun testUnaryOperationInvocation() { val expression = RealField.mstInSpace { -symbol("x") }.compile() diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt new file mode 100644 index 000000000..f3b07df56 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestSpecialization.kt @@ -0,0 +1,45 @@ +package scietifik.kmath.asm + +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class TestSpecialization { + @Test + fun testUnaryPlus() { + val expr = RealField.mstInField { unaryOperation("+", symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(2.0, res) + } + + @Test + fun testUnaryMinus() { + val expr = RealField.mstInField { unaryOperation("-", symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(-2.0, res) + } + + @Test + fun testAdd() { + val expr = RealField.mstInField { binaryOperation("+", symbol("x"), symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(4.0, res) + } + + @Test + fun testMinus() { + val expr = RealField.mstInField { binaryOperation("-", symbol("x"), symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(0.0, res) + } + + @Test + fun testDivide() { + val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() + val res = expr("x" to 2.0) + assertEquals(1.0, res) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt index 08d7fff47..23203172e 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -10,7 +10,7 @@ import scientifik.kmath.operations.ComplexField import kotlin.test.Test import kotlin.test.assertEquals -class AsmTest { +internal class AsmTest { @Test fun `compile MST`() { val mst = "2+2*(2+2)".parseMath()