diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index 46739c49c..e4bd705b4 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -1,6 +1,7 @@ package scientifik.kmath.asm import scientifik.kmath.asm.internal.AsmBuilder +import scientifik.kmath.asm.internal.buildName import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST @@ -10,44 +11,30 @@ import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.NumericAlgebra import kotlin.reflect.KClass - /** * Compile given MST to an Expression using AST compiler */ fun MST.compileWith(type: KClass, algebra: Algebra): Expression { - - fun buildName(mst: MST, collision: Int = 0): String { - val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" - - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name - } - - return buildName(mst, collision + 1) - } - - fun AsmBuilder.visit(node: MST): Unit { + fun AsmBuilder.visit(node: MST) { when (node) { - is MST.Symbolic -> visitLoadFromVariables(node.value) + is MST.Symbolic -> loadVariable(node.value) is MST.Numeric -> { val constant = if (algebra is NumericAlgebra) { algebra.number(node.value) } else { error("Number literals are not supported in $algebra") } - visitLoadFromConstants(constant) + loadTConstant(constant) } is MST.Unary -> { - visitLoadAlgebra() + loadAlgebra() - if (!hasSpecific(algebra, node.operation, 1)) visitStringConstant(node.operation) + if (!hasSpecific(algebra, node.operation, 1)) loadStringConstant(node.operation) visit(node.value) if (!tryInvokeSpecific(algebra, node.operation, 1)) { - visitAlgebraOperation( + invokeAlgebraOperation( owner = AsmBuilder.ALGEBRA_CLASS, method = "unaryOperation", descriptor = "(L${AsmBuilder.STRING_CLASS};" + @@ -57,17 +44,17 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< } } is MST.Binary -> { - visitLoadAlgebra() + loadAlgebra() if (!hasSpecific(algebra, node.operation, 2)) - visitStringConstant(node.operation) + loadStringConstant(node.operation) visit(node.left) visit(node.right) if (!tryInvokeSpecific(algebra, node.operation, 2)) { - visitAlgebraOperation( + invokeAlgebraOperation( owner = AsmBuilder.ALGEBRA_CLASS, method = "binaryOperation", descriptor = "(L${AsmBuilder.STRING_CLASS};" + @@ -80,9 +67,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< } } - val builder = AsmBuilder(type.java, algebra, buildName(this)) - builder.visit(this) - return builder.generate() + return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() } inline fun Algebra.compile(mst: MST): Expression = mst.compileWith(T::class, this) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 076093ee1..337219029 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -4,41 +4,29 @@ import org.objectweb.asm.ClassWriter import org.objectweb.asm.Label import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes -import scientifik.kmath.asm.internal.AsmBuilder.AsmClassLoader -import scientifik.kmath.expressions.Expression +import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.operations.Algebra - -private abstract class FunctionalCompiledExpression internal constructor( - @JvmField protected val algebra: Algebra, - @JvmField protected val constants: Array -) : Expression { - abstract override fun invoke(arguments: Map): T -} - - /** - * AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java - * expression. This class uses [AsmClassLoader] for loading the generated class, then it is able to instantiate the new - * class. + * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. + * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * * @param T the type of AsmExpression to unwrap. * @param algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class. - * - * @author [Iaroslav Postovalov](https://github.com/CommanderTvis) */ -internal class AsmBuilder( +internal class AsmBuilder internal constructor( private val classOfT: Class<*>, private val algebra: Algebra, - private val className: String + private val className: String, + private val invokeLabel0Visitor: AsmBuilder.() -> Unit ) { - private class AsmClassLoader(parent: ClassLoader) : ClassLoader(parent) { + private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) } - private val classLoader: AsmClassLoader = - AsmClassLoader(javaClass.classLoader) + private val classLoader: ClassLoader = + ClassLoader(javaClass.classLoader) @Suppress("PrivatePropertyName") private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') @@ -49,45 +37,49 @@ internal class AsmBuilder( private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') private val invokeThisVar: Int = 0 private val invokeArgumentsVar: Int = 1 - private var maxStack: Int = 0 private val constants: MutableList = mutableListOf() - private val asmCompiledClassWriter: ClassWriter = ClassWriter(0) - private val invokeMethodVisitor: MethodVisitor - private val invokeL0: Label - private lateinit var invokeL1: Label + private lateinit var invokeMethodVisitor: MethodVisitor + private var generatedInstance: AsmCompiledExpression? = null - init { - asmCompiledClassWriter.visit( - Opcodes.V1_8, - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, - slashesClassName, - "L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS;", - FUNCTIONAL_COMPILED_EXPRESSION_CLASS, - arrayOf() - ) + @Suppress("UNCHECKED_CAST") + fun getInstance(): AsmCompiledExpression { + generatedInstance?.let { return it } - asmCompiledClassWriter.run { - visitMethod(Opcodes.ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", null, null).run { + val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { + visit( + Opcodes.V1_8, + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, + slashesClassName, + "L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS;", + FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + arrayOf() + ) + + visitMethod( + access = Opcodes.ACC_PUBLIC, + name = "", + descriptor = "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", + signature = null, + exceptions = null + ) { val thisVar = 0 val algebraVar = 1 val constantsVar = 2 val l0 = Label() visitLabel(l0) - visitVarInsn(Opcodes.ALOAD, thisVar) - visitVarInsn(Opcodes.ALOAD, algebraVar) - visitVarInsn(Opcodes.ALOAD, constantsVar) + visitLoadObjectVar(thisVar) + visitLoadObjectVar(algebraVar) + visitLoadObjectVar(constantsVar) - visitMethodInsn( - Opcodes.INVOKESPECIAL, + visitInvokeSpecial( FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "", - "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", - false + "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V" ) val l1 = Label() visitLabel(l1) - visitInsn(Opcodes.RETURN) + visitReturn() val l2 = Label() visitLabel(l2) visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) @@ -102,220 +94,210 @@ internal class AsmBuilder( ) visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar) - visitMaxs(3, 3) + visitMaxs(0, 3) visitEnd() } - invokeMethodVisitor = visitMethod( - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, - "invoke", - "(L$MAP_CLASS;)L$T_CLASS;", - "(L$MAP_CLASS;)L$T_CLASS;", - null - ) - - invokeMethodVisitor.run { + visitMethod( + access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, + name = "invoke", + descriptor = "(L$MAP_CLASS;)L$T_CLASS;", + signature = "(L$MAP_CLASS;)L$T_CLASS;", + exceptions = null + ) { + invokeMethodVisitor = this visitCode() - invokeL0 = Label() - visitLabel(invokeL0) + val l0 = Label() + visitLabel(l0) + invokeLabel0Visitor() + visitReturnObject() + val l1 = Label() + visitLabel(l1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + null, + l0, + l1, + invokeThisVar + ) + + visitLocalVariable( + "arguments", + "L$MAP_CLASS;", + "L$MAP_CLASS;", + l0, + l1, + invokeArgumentsVar + ) + + visitMaxs(0, 2) + visitEnd() } - } - } - @Suppress("UNCHECKED_CAST") - fun generate(): Expression { + visitMethod( + access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + name = "invoke", + descriptor = "(L$MAP_CLASS;)L$OBJECT_CLASS;", + signature = null, + exceptions = null + ) { + val thisVar = 0 + val argumentsVar = 1 + visitCode() + val l0 = Label() + visitLabel(l0) + visitLoadObjectVar(thisVar) + visitLoadObjectVar(argumentsVar) + visitInvokeVirtual(owner = slashesClassName, name = "invoke", descriptor = "(L$MAP_CLASS;)L$T_CLASS;") + visitReturnObject() + val l1 = Label() + visitLabel(l1) - invokeMethodVisitor.run { - visitInsn(Opcodes.ARETURN) - invokeL1 = Label() - visitLabel(invokeL1) + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + l0, + l1, + thisVar + ) - visitLocalVariable( - "this", - "L$slashesClassName;", - T_CLASS, - invokeL0, - invokeL1, - invokeThisVar - ) + visitMaxs(0, 2) + visitEnd() + } - visitLocalVariable( - "arguments", - "L$MAP_CLASS;", - "L$MAP_CLASS;", - invokeL0, - invokeL1, - invokeArgumentsVar - ) - - visitMaxs(maxStack + 1, 2) visitEnd() } - asmCompiledClassWriter.visitMethod( - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, - "invoke", - "(L$MAP_CLASS;)L$OBJECT_CLASS;", - null, - null - ).run { - val thisVar = 0 - visitCode() - val l0 = Label() - visitLabel(l0) - visitVarInsn(Opcodes.ALOAD, 0) - visitVarInsn(Opcodes.ALOAD, 1) - visitMethodInsn(Opcodes.INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false) - visitInsn(Opcodes.ARETURN) - val l1 = Label() - visitLabel(l1) - - visitLocalVariable( - "this", - "L$slashesClassName;", - T_CLASS, - l0, - l1, - thisVar - ) - - visitMaxs(2, 2) - visitEnd() - } - - asmCompiledClassWriter.visitEnd() - val new = classLoader - .defineClass(className, asmCompiledClassWriter.toByteArray()) + .defineClass(className, classWriter.toByteArray()) .constructors .first() - .newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression + .newInstance(algebra, constants.toTypedArray()) as AsmCompiledExpression + generatedInstance = new return new } - fun visitLoadFromConstants(value: T) { + internal fun loadTConstant(value: T) { if (classOfT in INLINABLE_NUMBERS) { - visitNumberConstant(value as Number) - visitCastToT() + loadNumberConstant(value as Number) + invokeMethodVisitor.visitCheckCast(T_CLASS) return } - visitLoadAnyFromConstants(value as Any, T_CLASS) + loadConstant(value as Any, T_CLASS) } - private fun visitLoadAnyFromConstants(value: Any, type: String) { + private fun loadConstant(value: Any, type: String) { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex - maxStack++ invokeMethodVisitor.run { - visitLoadThis() - visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "[L$OBJECT_CLASS;") - visitLdcOrIConstInsn(idx) - visitInsn(Opcodes.AALOAD) - invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type) + loadThis() + visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;") + visitLdcOrIntConstant(idx) + visitGetObjectArrayElement() + invokeMethodVisitor.visitCheckCast(type) } } - private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) + private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) - fun visitNumberConstant(value: Number) { - maxStack++ + internal fun loadNumberConstant(value: Number) { val clazz = value.javaClass val c = clazz.name.replace('.', '/') val sigLetter = SIGNATURE_LETTERS[clazz] if (sigLetter != null) { when (value) { - is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value) - is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value) - is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value) + is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value) + is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value) + is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value) else -> invokeMethodVisitor.visitLdcInsn(value) } - invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) + invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") return } - visitLoadAnyFromConstants(value, c) + loadConstant(value, c) } - fun visitLoadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { - maxStack += 2 - visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar) + internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { + visitLoadObjectVar(invokeArgumentsVar) if (defaultValue != null) { visitLdcInsn(name) - visitLoadFromConstants(defaultValue) + loadTConstant(defaultValue) - visitMethodInsn( - Opcodes.INVOKEINTERFACE, - MAP_CLASS, - "getOrDefault", - "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;", - true + visitInvokeInterface( + owner = MAP_CLASS, + name = "getOrDefault", + descriptor = "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;" ) - visitCastToT() + invokeMethodVisitor.visitCheckCast(T_CLASS) return } visitLdcInsn(name) - visitMethodInsn( - Opcodes.INVOKEINTERFACE, - MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true - ) - visitCastToT() - } - fun visitLoadAlgebra() { - maxStack++ - invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) - - invokeMethodVisitor.visitFieldInsn( - Opcodes.GETFIELD, - FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;" + visitInvokeInterface( + owner = MAP_CLASS, + name = "get", + descriptor = "(L$OBJECT_CLASS;)L$OBJECT_CLASS;" ) - invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS) + invokeMethodVisitor.visitCheckCast(T_CLASS) } - fun visitAlgebraOperation( + internal fun loadAlgebra() { + loadThis() + + invokeMethodVisitor.visitGetField( + owner = FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + name = "algebra", + descriptor = "L$ALGEBRA_CLASS;" + ) + + invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS) + } + + internal fun invokeAlgebraOperation( owner: String, method: String, descriptor: String, opcode: Int = Opcodes.INVOKEINTERFACE, isInterface: Boolean = true ) { - maxStack++ invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) - visitCastToT() + invokeMethodVisitor.visitCheckCast(T_CLASS) } - private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS) + internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) - fun visitStringConstant(string: String) { - invokeMethodVisitor.visitLdcInsn(string) - } + internal companion object { + private val SIGNATURE_LETTERS: Map, String> by lazy { + mapOf( + java.lang.Byte::class.java to "B", + java.lang.Short::class.java to "S", + java.lang.Integer::class.java to "I", + java.lang.Long::class.java to "J", + java.lang.Float::class.java to "F", + java.lang.Double::class.java to "D" + ) + } - companion object { - private val SIGNATURE_LETTERS = mapOf( - java.lang.Byte::class.java to "B", - java.lang.Short::class.java to "S", - java.lang.Integer::class.java to "I", - java.lang.Long::class.java to "J", - java.lang.Float::class.java to "F", - java.lang.Double::class.java to "D" - ) + private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } - private val INLINABLE_NUMBERS = SIGNATURE_LETTERS.keys + internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = + "scientifik/kmath/asm/internal/AsmCompiledExpression" - const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/asm/FunctionalCompiledExpression" - const val MAP_CLASS = "java/util/Map" - const val OBJECT_CLASS = "java/lang/Object" - const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" - const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" - const val STRING_CLASS = "java/lang/String" - const val NUMBER_CLASS = "java/lang/Number" + internal const val MAP_CLASS = "java/util/Map" + internal const val OBJECT_CLASS = "java/lang/Object" + internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" + internal const val STRING_CLASS = "java/lang/String" } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt new file mode 100644 index 000000000..0e113099a --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt @@ -0,0 +1,12 @@ +package scientifik.kmath.asm.internal + +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Algebra + +internal abstract class AsmCompiledExpression internal constructor( + @JvmField protected val algebra: Algebra, + @JvmField protected val constants: Array +) : Expression { + abstract override fun invoke(arguments: Map): T +} + diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt new file mode 100644 index 000000000..65652d72b --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt @@ -0,0 +1,15 @@ +package scientifik.kmath.asm.internal + +import scientifik.kmath.ast.MST + +internal fun buildName(mst: MST, collision: Int = 0): String { + val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(mst, collision + 1) +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt new file mode 100644 index 000000000..95d713b18 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt @@ -0,0 +1,15 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.MethodVisitor + +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block) + +internal inline fun ClassWriter.visitMethod( + access: Int, + name: String, + descriptor: String, + signature: String?, + exceptions: Array?, + block: MethodVisitor.() -> Unit +): MethodVisitor = visitMethod(access, name, descriptor, signature, exceptions).apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt index 356de4765..70fb3bd44 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt @@ -3,7 +3,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.* -internal fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) { +internal fun MethodVisitor.visitLdcOrIntConstant(value: Int): Unit = when (value) { -1 -> visitInsn(ICONST_M1) 0 -> visitInsn(ICONST_0) 1 -> visitInsn(ICONST_1) @@ -14,15 +14,39 @@ internal fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) { else -> visitLdcInsn(value) } -internal fun MethodVisitor.visitLdcOrDConstInsn(value: Double) = when (value) { +internal fun MethodVisitor.visitLdcOrDoubleConstant(value: Double): Unit = when (value) { 0.0 -> visitInsn(DCONST_0) 1.0 -> visitInsn(DCONST_1) else -> visitLdcInsn(value) } -internal fun MethodVisitor.visitLdcOrFConstInsn(value: Float) = when (value) { +internal fun MethodVisitor.visitLdcOrFloatConstant(value: Float): Unit = when (value) { 0f -> visitInsn(FCONST_0) 1f -> visitInsn(FCONST_1) 2f -> visitInsn(FCONST_2) else -> visitLdcInsn(value) } + +internal fun MethodVisitor.visitInvokeInterface(owner: String, name: String, descriptor: String): Unit = + visitMethodInsn(INVOKEINTERFACE, owner, name, descriptor, true) + +internal fun MethodVisitor.visitInvokeVirtual(owner: String, name: String, descriptor: String): Unit = + visitMethodInsn(INVOKEVIRTUAL, owner, name, descriptor, false) + +internal fun MethodVisitor.visitInvokeStatic(owner: String, name: String, descriptor: String): Unit = + visitMethodInsn(INVOKESTATIC, owner, name, descriptor, false) + +internal fun MethodVisitor.visitInvokeSpecial(owner: String, name: String, descriptor: String): Unit = + visitMethodInsn(INVOKESPECIAL, owner, name, descriptor, false) + +internal fun MethodVisitor.visitCheckCast(type: String): Unit = visitTypeInsn(CHECKCAST, type) + +internal fun MethodVisitor.visitGetField(owner: String, name: String, descriptor: String): Unit = + visitFieldInsn(GETFIELD, owner, name, descriptor) + +internal fun MethodVisitor.visitLoadObjectVar(`var`: Int): Unit = visitVarInsn(ALOAD, `var`) + +internal fun MethodVisitor.visitGetObjectArrayElement(): Unit = visitInsn(AALOAD) + +internal fun MethodVisitor.visitReturn(): Unit = visitInsn(RETURN) +internal fun MethodVisitor.visitReturnObject(): Unit = visitInsn(ARETURN) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt index 029939f16..a39dba6b1 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt @@ -29,7 +29,7 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri append("L${AsmBuilder.OBJECT_CLASS};") } - visitAlgebraOperation( + invokeAlgebraOperation( owner = owner, method = aName, descriptor = sig, @@ -39,8 +39,3 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri return true } -// -//internal fun AsmExpression.optimize(): AsmExpression { -// val a = tryEvaluate() -// return if (a == null) this else AsmConstantExpression(type, algebra, a) -//}