diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index 3ee071c77..a3af80ccd 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -1,14 +1,15 @@ package scientifik.kmath.asm +import org.objectweb.asm.Type import scientifik.kmath.asm.internal.AsmBuilder +import scientifik.kmath.asm.internal.buildExpectationStack import scientifik.kmath.asm.internal.buildName -import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.tryInvokeSpecific -import scientifik.kmath.ast.* +import scientifik.kmath.ast.MST +import scientifik.kmath.ast.MSTExpression import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.NumericAlgebra -import scientifik.kmath.operations.RealField import kotlin.reflect.KClass /** @@ -18,6 +19,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< fun AsmBuilder.visit(node: MST) { when (node) { is MST.Symbolic -> loadVariable(node.value) + is MST.Numeric -> { val constant = if (algebra is NumericAlgebra) algebra.number(node.value) @@ -26,48 +28,49 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< loadTConstant(constant) } + is MST.Unary -> { loadAlgebra() - - if (!hasSpecific(algebra, node.operation, 1)) loadStringConstant(node.operation) - + if (!buildExpectationStack(algebra, node.operation, 1)) loadStringConstant(node.operation) visit(node.value) - if (!tryInvokeSpecific(algebra, node.operation, 1)) { - invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_CLASS, - method = "unaryOperation", - descriptor = "(L${AsmBuilder.STRING_CLASS};" + - "L${AsmBuilder.OBJECT_CLASS};)" + - "L${AsmBuilder.OBJECT_CLASS};" - ) - } + if (!tryInvokeSpecific(algebra, node.operation, 1)) invokeAlgebraOperation( + owner = AsmBuilder.ALGEBRA_TYPE.internalName, + method = "unaryOperation", + + descriptor = Type.getMethodDescriptor( + AsmBuilder.OBJECT_TYPE, + AsmBuilder.STRING_TYPE, + AsmBuilder.OBJECT_TYPE + ), + + tArity = 1 + ) } is MST.Binary -> { loadAlgebra() - - if (!hasSpecific(algebra, node.operation, 2)) - loadStringConstant(node.operation) - + if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation) visit(node.left) visit(node.right) - if (!tryInvokeSpecific(algebra, node.operation, 2)) { + if (!tryInvokeSpecific(algebra, node.operation, 2)) invokeAlgebraOperation( + owner = AsmBuilder.ALGEBRA_TYPE.internalName, + method = "binaryOperation", - invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_CLASS, - method = "binaryOperation", - descriptor = "(L${AsmBuilder.STRING_CLASS};" + - "L${AsmBuilder.OBJECT_CLASS};" + - "L${AsmBuilder.OBJECT_CLASS};)" + - "L${AsmBuilder.OBJECT_CLASS};" - ) - } + descriptor = Type.getMethodDescriptor( + AsmBuilder.OBJECT_TYPE, + AsmBuilder.STRING_TYPE, + AsmBuilder.OBJECT_TYPE, + AsmBuilder.OBJECT_TYPE + ), + + tArity = 2 + ) } } } - return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() + return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() } /** @@ -79,7 +82,3 @@ inline fun Algebra.expression(mst: MST): Expression = ms * Optimize performance of an [MSTExpression] using ASM codegen */ inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) - -fun main() { - RealField.mstInField { symbol("x") + 2 }.compile() -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 9577d19e1..8f45c4044 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -1,25 +1,25 @@ package scientifik.kmath.asm.internal -import org.objectweb.asm.ClassWriter -import org.objectweb.asm.Label -import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Opcodes +import org.objectweb.asm.* +import org.objectweb.asm.Opcodes.AALOAD +import org.objectweb.asm.Opcodes.RETURN +import org.objectweb.asm.commons.InstructionAdapter import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader +import scientifik.kmath.ast.MST import scientifik.kmath.operations.Algebra +import java.util.* +import kotlin.reflect.KClass /** - * ASM Builder is a structure that abstracts building a class for unwrapping [MST] to plain Java expression. + * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * - * @param T the type generated [AsmCompiledExpression] operates. - * @property classOfT the [Class] of T. - * @property algebra the algebra the applied AsmExpressions use. - * @property className the unique class name of new loaded class. - * @property invokeLabel0Visitor the function applied to this object when the L0 label of invoke implementation is - * being written. + * @param T the type of AsmExpression to unwrap. + * @param algebra the algebra the applied AsmExpressions use. + * @param className the unique class name of new loaded class. */ internal class AsmBuilder internal constructor( - private val classOfT: Class<*>, + private val classOfT: KClass<*>, private val algebra: Algebra, private val className: String, private val invokeLabel0Visitor: AsmBuilder.() -> Unit @@ -34,16 +34,16 @@ internal class AsmBuilder internal constructor( /** * The instance of [ClassLoader] used by this builder. */ - private val classLoader: ClassLoader = - ClassLoader(javaClass.classLoader) + private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) @Suppress("PrivatePropertyName") - private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') + private val T_ALGEBRA_TYPE: Type = algebra::class.asm @Suppress("PrivatePropertyName") - private val T_CLASS: String = classOfT.name.replace('.', '/') + internal val T_TYPE: Type = classOfT.asm - private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') + @Suppress("PrivatePropertyName") + private val CLASS_TYPE: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! /** * Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass. @@ -63,7 +63,16 @@ internal class AsmBuilder internal constructor( /** * Method visitor of `invoke` method of [AsmCompiledExpression] subclass. */ - private lateinit var invokeMethodVisitor: MethodVisitor + private lateinit var invokeMethodVisitor: InstructionAdapter + internal var primitiveMode = false + + @Suppress("PropertyName") + internal var PRIMITIVE_MASK: Type = OBJECT_TYPE + + @Suppress("PropertyName") + internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE + private val typeStack = Stack() + internal val expectationStack = Stack().apply { push(T_TYPE) } /** * The cache of [AsmCompiledExpression] subclass built by this builder. @@ -76,81 +85,88 @@ internal class AsmBuilder internal constructor( * The built instance is cached. */ @Suppress("UNCHECKED_CAST") - internal fun getInstance(): AsmCompiledExpression { + fun getInstance(): AsmCompiledExpression { generatedInstance?.let { return it } + if (SIGNATURE_LETTERS.containsKey(classOfT.java)) { + primitiveMode = true + PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT.java) + PRIMITIVE_MASK_BOXED = T_TYPE + } + val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( Opcodes.V1_8, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, - slashesClassName, - "L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS;", - FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + CLASS_TYPE.internalName, + "L${ASM_COMPILED_EXPRESSION_TYPE.internalName}<${T_TYPE.descriptor}>;", + ASM_COMPILED_EXPRESSION_TYPE.internalName, arrayOf() ) visitMethod( - access = Opcodes.ACC_PUBLIC, - name = "", - descriptor = "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", - signature = null, - exceptions = null - ) { + Opcodes.ACC_PUBLIC, + "", + Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), + null, + null + ).instructionAdapter { val thisVar = 0 val algebraVar = 1 val constantsVar = 2 val l0 = Label() visitLabel(l0) - visitLoadObjectVar(thisVar) - visitLoadObjectVar(algebraVar) - visitLoadObjectVar(constantsVar) + load(thisVar, CLASS_TYPE) + load(algebraVar, ALGEBRA_TYPE) + load(constantsVar, OBJECT_ARRAY_TYPE) - visitInvokeSpecial( - FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + invokespecial( + ASM_COMPILED_EXPRESSION_TYPE.internalName, "", - "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V" + Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), + false ) val l1 = Label() visitLabel(l1) - visitReturn() + visitInsn(RETURN) val l2 = Label() visitLabel(l2) - visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) + visitLocalVariable("this", CLASS_TYPE.descriptor, null, l0, l2, thisVar) visitLocalVariable( "algebra", - "L$ALGEBRA_CLASS;", - "L$ALGEBRA_CLASS;", + ALGEBRA_TYPE.descriptor, + "L${ALGEBRA_TYPE.internalName}<${T_TYPE.descriptor}>;", l0, l2, algebraVar ) - visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l2, constantsVar) visitMaxs(0, 3) visitEnd() } visitMethod( - access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, - name = "invoke", - descriptor = "(L$MAP_CLASS;)L$T_CLASS;", - signature = "(L$MAP_CLASS;)L$T_CLASS;", - exceptions = null - ) { + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, + "invoke", + Type.getMethodDescriptor(T_TYPE, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;)${T_TYPE.descriptor}", + null + ).instructionAdapter { invokeMethodVisitor = this visitCode() val l0 = Label() visitLabel(l0) invokeLabel0Visitor() - visitReturnObject() + areturn(T_TYPE) val l1 = Label() visitLabel(l1) visitLocalVariable( "this", - "L$slashesClassName;", + CLASS_TYPE.descriptor, null, l0, l1, @@ -159,8 +175,8 @@ internal class AsmBuilder internal constructor( visitLocalVariable( "arguments", - "L$MAP_CLASS;", - "L$MAP_CLASS;", + MAP_TYPE.descriptor, + "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;", l0, l1, invokeArgumentsVar @@ -171,28 +187,28 @@ internal class AsmBuilder internal constructor( } visitMethod( - access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, - name = "invoke", - descriptor = "(L$MAP_CLASS;)L$OBJECT_CLASS;", - signature = null, - exceptions = null - ) { + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + "invoke", + Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), + null, + null + ).instructionAdapter { val thisVar = 0 val argumentsVar = 1 visitCode() val l0 = Label() visitLabel(l0) - visitLoadObjectVar(thisVar) - visitLoadObjectVar(argumentsVar) - visitInvokeVirtual(owner = slashesClassName, name = "invoke", descriptor = "(L$MAP_CLASS;)L$T_CLASS;") - visitReturnObject() + load(thisVar, OBJECT_TYPE) + load(argumentsVar, MAP_TYPE) + invokevirtual(CLASS_TYPE.internalName, "invoke", Type.getMethodDescriptor(T_TYPE, MAP_TYPE), false) + areturn(T_TYPE) val l1 = Label() visitLabel(l1) visitLocalVariable( "this", - "L$slashesClassName;", - T_CLASS, + CLASS_TYPE.descriptor, + null, l0, l1, thisVar @@ -219,89 +235,111 @@ internal class AsmBuilder internal constructor( * Loads a constant from */ internal fun loadTConstant(value: T) { - if (classOfT in INLINABLE_NUMBERS) { - loadNumberConstant(value as Number) - invokeMethodVisitor.visitCheckCast(T_CLASS) + if (classOfT.java in INLINABLE_NUMBERS) { + val expectedType = expectationStack.pop()!! + val mustBeBoxed = expectedType.sort == Type.OBJECT + loadNumberConstant(value as Number, mustBeBoxed) + if (mustBeBoxed) typeStack.push(T_TYPE) else typeStack.push(PRIMITIVE_MASK) return } - loadConstant(value as Any, T_CLASS) + loadConstant(value as Any, T_TYPE) } - /** - * Loads an object constant [value] stored in [AsmCompiledExpression.constants] and casts it to [type]. - */ - private fun loadConstant(value: Any, type: String) { + private fun box(): Unit = invokeMethodVisitor.invokestatic( + T_TYPE.internalName, + "valueOf", + Type.getMethodDescriptor(T_TYPE, PRIMITIVE_MASK), + false + ) + + private fun unbox(): Unit = invokeMethodVisitor.invokevirtual( + NUMBER_TYPE.internalName, + NUMBER_CONVERTER_METHODS.getValue(PRIMITIVE_MASK), + Type.getMethodDescriptor(PRIMITIVE_MASK), + false + ) + + private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex - - invokeMethodVisitor.run { - loadThis() - visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;") - visitLdcOrIntConstant(idx) - visitGetObjectArrayElement() - invokeMethodVisitor.visitCheckCast(type) - } + loadThis() + getfield(CLASS_TYPE.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + iconst(idx) + visitInsn(AALOAD) + checkcast(type) } - private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) + private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, CLASS_TYPE) /** - * Either loads a numeric constant [value] from [AsmCompiledExpression.constants] field or boxes a primitive + * Either loads a numeric constant [value] from [AsmCompiledExpression] constants field or boxes a primitive * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded * from it). */ - private fun loadNumberConstant(value: Number) { - val clazz = value.javaClass - val c = clazz.name.replace('.', '/') - val sigLetter = SIGNATURE_LETTERS[clazz] + private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) { + val boxed = value::class.asm + val primitive = BOXED_TO_PRIMITIVES[boxed] - if (sigLetter != null) { - when (value) { - is Byte -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) - is Short -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) - is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value) - is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value) - is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value) - is Long -> invokeMethodVisitor.visitLdcOrLongConstant(value) - else -> invokeMethodVisitor.visitLdcInsn(value) + if (primitive != null) { + when (primitive) { + Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) + Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) + Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) + Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + } + + if (mustBeBoxed) { + box() + invokeMethodVisitor.checkcast(T_TYPE) } - invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") return } - loadConstant(value, c) + loadConstant(value, boxed) + if (!mustBeBoxed) unbox() + else invokeMethodVisitor.checkcast(T_TYPE) } /** * Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { - visitLoadObjectVar(invokeArgumentsVar) + load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) if (defaultValue != null) { - visitLdcInsn(name) + loadStringConstant(name) loadTConstant(defaultValue) - visitInvokeInterface( - owner = MAP_CLASS, - name = "getOrDefault", - descriptor = "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;" + invokeinterface( + MAP_TYPE.internalName, + "getOrDefault", + Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.visitCheckCast(T_CLASS) + invokeMethodVisitor.checkcast(T_TYPE) return } - visitLdcInsn(name) + loadStringConstant(name) - visitInvokeInterface( - owner = MAP_CLASS, - name = "get", - descriptor = "(L$OBJECT_CLASS;)L$OBJECT_CLASS;" + invokeinterface( + MAP_TYPE.internalName, + "get", + Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.visitCheckCast(T_CLASS) + invokeMethodVisitor.checkcast(T_TYPE) + val expectedType = expectationStack.pop()!! + + if (expectedType.sort == Type.OBJECT) + typeStack.push(T_TYPE) + else { + unbox() + typeStack.push(PRIMITIVE_MASK) + } } /** @@ -310,13 +348,10 @@ internal class AsmBuilder internal constructor( internal fun loadAlgebra() { loadThis() - invokeMethodVisitor.visitGetField( - owner = FUNCTIONAL_COMPILED_EXPRESSION_CLASS, - name = "algebra", - descriptor = "L$ALGEBRA_CLASS;" - ) - - invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS) + invokeMethodVisitor.run { + getfield(ASM_COMPILED_EXPRESSION_TYPE.internalName, "algebra", ALGEBRA_TYPE.descriptor) + checkcast(T_ALGEBRA_TYPE) + } } /** @@ -330,29 +365,70 @@ internal class AsmBuilder internal constructor( owner: String, method: String, descriptor: String, + tArity: Int, opcode: Int = Opcodes.INVOKEINTERFACE ) { - invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, opcode == Opcodes.INVOKEINTERFACE) - invokeMethodVisitor.visitCheckCast(T_CLASS) + repeat(tArity) { if (!typeStack.empty()) typeStack.pop() } + + invokeMethodVisitor.visitMethodInsn( + opcode, + owner, + method, + descriptor, + opcode == Opcodes.INVOKEINTERFACE + ) + + invokeMethodVisitor.checkcast(T_TYPE) + val isLastExpr = expectationStack.size == 1 + val expectedType = expectationStack.pop()!! + + if (expectedType.sort == Type.OBJECT || isLastExpr) + typeStack.push(T_TYPE) + else { + unbox() + typeStack.push(PRIMITIVE_MASK) + } } /** * Writes a LDC Instruction with string constant provided. */ - internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) + internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string) internal companion object { /** * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. */ - private val SIGNATURE_LETTERS: Map, String> by lazy { + private val SIGNATURE_LETTERS: Map, Type> by lazy { hashMapOf( - java.lang.Byte::class.java to "B", - java.lang.Short::class.java to "S", - java.lang.Integer::class.java to "I", - java.lang.Long::class.java to "J", - java.lang.Float::class.java to "F", - java.lang.Double::class.java to "D" + java.lang.Byte::class.java to Type.BYTE_TYPE, + java.lang.Short::class.java to Type.SHORT_TYPE, + java.lang.Integer::class.java to Type.INT_TYPE, + java.lang.Long::class.java to Type.LONG_TYPE, + java.lang.Float::class.java to Type.FLOAT_TYPE, + java.lang.Double::class.java to Type.DOUBLE_TYPE + ) + } + + private val BOXED_TO_PRIMITIVES: Map by lazy { + hashMapOf( + java.lang.Byte::class.asm to Type.BYTE_TYPE, + java.lang.Short::class.asm to Type.SHORT_TYPE, + java.lang.Integer::class.asm to Type.INT_TYPE, + java.lang.Long::class.asm to Type.LONG_TYPE, + java.lang.Float::class.asm to Type.FLOAT_TYPE, + java.lang.Double::class.asm to Type.DOUBLE_TYPE + ) + } + + private val NUMBER_CONVERTER_METHODS: Map by lazy { + hashMapOf( + Type.BYTE_TYPE to "byteValue", + Type.SHORT_TYPE to "shortValue", + Type.INT_TYPE to "intValue", + Type.LONG_TYPE to "longValue", + Type.FLOAT_TYPE to "floatValue", + Type.DOUBLE_TYPE to "doubleValue" ) } @@ -360,13 +436,14 @@ internal class AsmBuilder internal constructor( * Provides boxed number types values of which can be stored in JVM bytecode constant pool. */ private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + internal val ASM_COMPILED_EXPRESSION_TYPE: Type = AsmCompiledExpression::class.asm + internal val NUMBER_TYPE: Type = java.lang.Number::class.asm + internal val MAP_TYPE: Type = java.util.Map::class.asm + internal val OBJECT_TYPE: Type = java.lang.Object::class.asm - internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = - "scientifik/kmath/asm/internal/AsmCompiledExpression" - - internal const val MAP_CLASS = "java/util/Map" - internal const val OBJECT_CLASS = "java/lang/Object" - internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" - internal const val STRING_CLASS = "java/lang/String" + @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") + internal val OBJECT_ARRAY_TYPE: Type = Array::class.asm + internal val ALGEBRA_TYPE: Type = Algebra::class.asm + internal val STRING_TYPE: Type = java.lang.String::class.asm } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt index a5236927c..7c4a9fc99 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt @@ -16,4 +16,3 @@ internal abstract class AsmCompiledExpression internal constructor( ) : Expression { abstract override fun invoke(arguments: Map): T } - diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt new file mode 100644 index 000000000..dc0b35531 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt @@ -0,0 +1,7 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.Type +import kotlin.reflect.KClass + +internal val KClass<*>.asm: Type + get() = Type.getType(java) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt index 6ecce8833..7b0d346b7 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt @@ -1,60 +1,9 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Opcodes.* +import org.objectweb.asm.commons.InstructionAdapter -internal fun MethodVisitor.visitLdcOrIntConstant(value: Int): Unit = when (value) { - -1 -> visitInsn(ICONST_M1) - 0 -> visitInsn(ICONST_0) - 1 -> visitInsn(ICONST_1) - 2 -> visitInsn(ICONST_2) - 3 -> visitInsn(ICONST_3) - 4 -> visitInsn(ICONST_4) - 5 -> visitInsn(ICONST_5) - in -128..127 -> visitIntInsn(BIPUSH, value) - in -32768..32767 -> visitIntInsn(SIPUSH, value) - else -> visitLdcInsn(value) -} +fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) -internal fun MethodVisitor.visitLdcOrDoubleConstant(value: Double): Unit = when (value) { - 0.0 -> visitInsn(DCONST_0) - 1.0 -> visitInsn(DCONST_1) - else -> visitLdcInsn(value) -} - -internal fun MethodVisitor.visitLdcOrLongConstant(value: Long): Unit = when (value) { - 0L -> visitInsn(LCONST_0) - 1L -> visitInsn(LCONST_1) - else -> visitLdcInsn(value) -} - -internal fun MethodVisitor.visitLdcOrFloatConstant(value: Float): Unit = when (value) { - 0f -> visitInsn(FCONST_0) - 1f -> visitInsn(FCONST_1) - 2f -> visitInsn(FCONST_2) - else -> visitLdcInsn(value) -} - -internal fun MethodVisitor.visitInvokeInterface(owner: String, name: String, descriptor: String): Unit = - visitMethodInsn(INVOKEINTERFACE, owner, name, descriptor, true) - -internal fun MethodVisitor.visitInvokeVirtual(owner: String, name: String, descriptor: String): Unit = - visitMethodInsn(INVOKEVIRTUAL, owner, name, descriptor, false) - -internal fun MethodVisitor.visitInvokeStatic(owner: String, name: String, descriptor: String): Unit = - visitMethodInsn(INVOKESTATIC, owner, name, descriptor, false) - -internal fun MethodVisitor.visitInvokeSpecial(owner: String, name: String, descriptor: String): Unit = - visitMethodInsn(INVOKESPECIAL, owner, name, descriptor, false) - -internal fun MethodVisitor.visitCheckCast(type: String): Unit = visitTypeInsn(CHECKCAST, type) - -internal fun MethodVisitor.visitGetField(owner: String, name: String, descriptor: String): Unit = - visitFieldInsn(GETFIELD, owner, name, descriptor) - -internal fun MethodVisitor.visitLoadObjectVar(`var`: Int): Unit = visitVarInsn(ALOAD, `var`) - -internal fun MethodVisitor.visitGetObjectArrayElement(): Unit = visitInsn(AALOAD) - -internal fun MethodVisitor.visitReturn(): Unit = visitInsn(RETURN) -internal fun MethodVisitor.visitReturnObject(): Unit = visitInsn(ARETURN) +fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = + instructionAdapter().apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt similarity index 53% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index efad97763..2e15a1a93 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -1,6 +1,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.Opcodes +import org.objectweb.asm.Type import scientifik.kmath.operations.Algebra private val methodNameAdapters: Map by lazy { @@ -12,17 +13,19 @@ private val methodNameAdapters: Map by lazy { } /** - * Checks if the target [context] for code generation contains a method with needed [name] and [arity]. + * Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds + * type expectation stack for needed arity. * * @return `true` if contains, else `false`. */ -internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { +internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name - context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } - ?: return false + val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null + val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else T_TYPE + repeat(arity) { expectationStack.push(t) } - return true + return hasSpecific } /** @@ -34,22 +37,23 @@ internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boo internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name - context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } - ?: return false + val method = + context.javaClass.methods.find { + var suitableSignature = it.name == aName && it.parameters.size == arity + + if (primitiveMode && it.isBridge) + suitableSignature = false + + suitableSignature + } ?: return false val owner = context::class.java.name.replace('.', '/') - val sig = buildString { - append('(') - repeat(arity) { append("L${AsmBuilder.OBJECT_CLASS};") } - append(')') - append("L${AsmBuilder.OBJECT_CLASS};") - } - invokeAlgebraOperation( owner = owner, method = aName, - descriptor = sig, + descriptor = Type.getMethodDescriptor(PRIMITIVE_MASK_BOXED, *Array(arity) { PRIMITIVE_MASK }), + tArity = arity, opcode = Opcodes.INVOKEVIRTUAL ) diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index a6cb1e247..4c2be811e 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -1,66 +1,110 @@ package scietifik.kmath.asm -// -//class TestAsmAlgebras { -// @Test -// fun space() { -// val res = ByteRing.asmInRing { -// binaryOperation( -// "+", -// -// unaryOperation( -// "+", -// 3.toByte() - (2.toByte() + (multiply( -// add(number(1), number(1)), -// 2 -// ) + 1.toByte()) * 3.toByte() - 1.toByte()) -// ), -// -// number(1) -// ) + symbol("x") + zero -// }("x" to 2.toByte()) -// -// assertEquals(16, res) -// } -// -// @Test -// fun ring() { -// val res = ByteRing.asmInRing { -// binaryOperation( -// "+", -// -// unaryOperation( -// "+", -// (3.toByte() - (2.toByte() + (multiply( -// add(const(1), const(1)), -// 2 -// ) + 1.toByte()))) * 3.0 - 1.toByte() -// ), -// -// number(1) -// ) * const(2) -// }() -// -// assertEquals(24, res) -// } -// -// @Test -// fun field() { -// val res = RealField.asmInField { -// +(3 - 2 + 2*(number(1)+1.0) -// -// unaryOperation( -// "+", -// (3.0 - (2.0 + (multiply( -// add((1.0), const(1.0)), -// 2 -// ) + 1.0))) * 3 - 1.0 -// )+ -// -// number(1) -// ) / 2, const(2.0)) * one -// }() -// -// assertEquals(3.0, res) -// } -//} +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInField +import scientifik.kmath.ast.mstInRing +import scientifik.kmath.ast.mstInSpace +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +class TestAsmAlgebras { + @Test + fun space() { + val res1 = ByteRing.mstInSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }("x" to 2.toByte()) + + val res2 = ByteRing.mstInSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }.compile()("x" to 2.toByte()) + + assertEquals(res1, res2) + } + + @Test + fun ring() { + val res1 = ByteRing.mstInRing { + binaryOperation( + "+", + + unaryOperation( + "+", + (symbol("x") - (2.toByte() + (multiply( + add(number(1), number(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + + number(1) + ) * number(2) + }("x" to 3.toByte()) + + val res2 = ByteRing.mstInRing { + binaryOperation( + "+", + + unaryOperation( + "+", + (symbol("x") - (2.toByte() + (multiply( + add(number(1), number(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + + number(1) + ) * number(2) + }.compile()("x" to 3.toByte()) + + assertEquals(res1, res2) + } + + @Test + fun field() { + val res1 = RealField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( + "+", + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + 1 / 2 + number(2.0) * one + ) + }("x" to 2.0) + + val res2 = RealField.mstInField { + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( + "+", + (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + + number(1), + 1 / 2 + number(2.0) * one + ) + }.compile()("x" to 2.0) + + assertEquals(res1, res2) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt index 82af1a927..824201aa7 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -16,6 +16,13 @@ class TestAsmExpressions { assertEquals(-2.0, res) } + @Test + fun testBinaryOperationInvocation() { + val expression = RealField.mstInSpace { -symbol("x") + number(1.0) }.compile() + val res = expression("x" to 2.0) + assertEquals(-1.0, res) + } + @Test fun testConstProductInvocation() { val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt index f0f9a8bc1..08d7fff47 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -1,7 +1,10 @@ package scietifik.kmath.ast -import scientifik.kmath.ast.evaluate +import scientifik.kmath.asm.compile +import scientifik.kmath.asm.expression +import scientifik.kmath.ast.mstInField import scientifik.kmath.ast.parseMath +import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField import kotlin.test.Test @@ -9,9 +12,15 @@ import kotlin.test.assertEquals class AsmTest { @Test - fun parsedExpression() { + fun `compile MST`() { val mst = "2+2*(2+2)".parseMath() - val res = ComplexField.evaluate(mst) + val res = ComplexField.expression(mst)() + assertEquals(Complex(10.0, 0.0), res) + } + + @Test + fun `compile MSTExpression`() { + val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }.compile()() assertEquals(Complex(10.0, 0.0), res) } } diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt index 06546302e..5394a4b00 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt @@ -1,17 +1,25 @@ package scietifik.kmath.ast import scientifik.kmath.ast.evaluate +import scientifik.kmath.ast.mstInField import scientifik.kmath.ast.parseMath +import scientifik.kmath.expressions.invoke import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField -import kotlin.test.assertEquals import kotlin.test.Test +import kotlin.test.assertEquals internal class ParserTest { @Test - fun parsedExpression() { + fun `evaluate MST`() { val mst = "2+2*(2+2)".parseMath() val res = ComplexField.evaluate(mst) assertEquals(Complex(10.0, 0.0), res) } + + @Test + fun `evaluate MSTExpression`() { + val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }() + assertEquals(Complex(10.0, 0.0), res) + } }