diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index 468ed01ba..ef2330533 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -1,10 +1,8 @@ package scientifik.kmath.asm -import org.objectweb.asm.Type import scientifik.kmath.asm.internal.AsmBuilder -import scientifik.kmath.asm.internal.buildExpectationStack +import scientifik.kmath.asm.internal.buildAlgebraOperationCall import scientifik.kmath.asm.internal.buildName -import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST import scientifik.kmath.ast.MstExpression import scientifik.kmath.expressions.Expression @@ -29,44 +27,21 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< loadTConstant(constant) } - is MST.Unary -> { - loadAlgebra() - if (!buildExpectationStack(algebra, node.operation, 1)) loadStringConstant(node.operation) - visit(node.value) + is MST.Unary -> buildAlgebraOperationCall( + context = algebra, + name = node.operation, + fallbackMethodName = "unaryOperation", + arity = 1 + ) { visit(node.value) } - if (!tryInvokeSpecific(algebra, node.operation, 1)) invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_TYPE.internalName, - method = "unaryOperation", - - descriptor = Type.getMethodDescriptor( - AsmBuilder.OBJECT_TYPE, - AsmBuilder.STRING_TYPE, - AsmBuilder.OBJECT_TYPE - ), - - expectedArity = 1 - ) - } - - is MST.Binary -> { - loadAlgebra() - if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation) + is MST.Binary -> buildAlgebraOperationCall( + context = algebra, + name = node.operation, + fallbackMethodName = "binaryOperation", + arity = 2 + ) { visit(node.left) visit(node.right) - - if (!tryInvokeSpecific(algebra, node.operation, 2)) invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_TYPE.internalName, - method = "binaryOperation", - - descriptor = Type.getMethodDescriptor( - AsmBuilder.OBJECT_TYPE, - AsmBuilder.STRING_TYPE, - AsmBuilder.OBJECT_TYPE, - AsmBuilder.OBJECT_TYPE - ), - - expectedArity = 2 - ) } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index a6d2c045b..003dc47dd 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -21,7 +21,7 @@ private val methodNameAdapters: Map, String> by lazy { * * @return `true` if contains, else `false`. */ -internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { +private fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { val theName = methodNameAdapters[name to arity] ?: name val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null val t = if (primitiveMode && hasSpecific) primitiveMask else tType @@ -35,7 +35,7 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: * * @return `true` if contains, else `false`. */ -internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { +private fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val theName = methodNameAdapters[name to arity] ?: name context.javaClass.methods.find { @@ -59,3 +59,28 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri return true } + +internal fun AsmBuilder.buildAlgebraOperationCall( + context: Algebra, + name: String, + fallbackMethodName: String, + arity: Int, + parameters: AsmBuilder.() -> Unit +) { + loadAlgebra() + if (!buildExpectationStack(context, name, arity)) loadStringConstant(name) + parameters() + + if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation( + owner = AsmBuilder.ALGEBRA_TYPE.internalName, + method = fallbackMethodName, + + descriptor = Type.getMethodDescriptor( + AsmBuilder.OBJECT_TYPE, + AsmBuilder.STRING_TYPE, + *Array(arity) { AsmBuilder.OBJECT_TYPE } + ), + + expectedArity = arity + ) +}