From d6e7eb8143029541d2ff6c5ef320c627ccb66973 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Thu, 18 Jun 2020 11:35:20 +0700 Subject: [PATCH 1/3] Add advanced specialization for primitive non-bridge methods --- .../kotlin/scientifik/kmath/asm/asm.kt | 7 +- .../kmath/asm/internal/AsmBuilder.kt | 65 ++++++++++++++++++- .../{optimization.kt => specialization.kt} | 4 +- .../scietifik/kmath/asm/TestAsmAlgebras.kt | 35 ++++++++-- 4 files changed, 98 insertions(+), 13 deletions(-) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{optimization.kt => specialization.kt} (90%) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index bb091732e..2345757f1 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -19,11 +19,10 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< when (node) { is MST.Symbolic -> loadVariable(node.value) is MST.Numeric -> { - val constant = if (algebra is NumericAlgebra) { - algebra.number(node.value) - } else { + val constant = if (algebra is NumericAlgebra) + algebra.number(node.value) else error("Number literals are not supported in $algebra") - } + loadTConstant(constant) } is MST.Unary -> { diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 337219029..351d5b754 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -6,6 +6,7 @@ import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.operations.Algebra +import java.io.File /** * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. @@ -39,12 +40,23 @@ internal class AsmBuilder internal constructor( private val invokeArgumentsVar: Int = 1 private val constants: MutableList = mutableListOf() private lateinit var invokeMethodVisitor: MethodVisitor + private var primitiveMode = false + internal var primitiveType = OBJECT_CLASS + internal var primitiveTypeSig = "L$OBJECT_CLASS;" + internal var primitiveTypeReturnSig = "L$OBJECT_CLASS;" private var generatedInstance: AsmCompiledExpression? = null @Suppress("UNCHECKED_CAST") fun getInstance(): AsmCompiledExpression { generatedInstance?.let { return it } + if (classOfT in SIGNATURE_LETTERS) { + primitiveMode = true + primitiveTypeSig = SIGNATURE_LETTERS.getValue(classOfT) + primitiveType = SIGNATURE_LETTERS.getValue(classOfT) + primitiveTypeReturnSig = "L$T_CLASS;" + } + val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( Opcodes.V1_8, @@ -110,6 +122,10 @@ internal class AsmBuilder internal constructor( val l0 = Label() visitLabel(l0) invokeLabel0Visitor() + + if (primitiveMode) + boxPrimitiveToNumber() + visitReturnObject() val l1 = Label() visitLabel(l1) @@ -171,6 +187,8 @@ internal class AsmBuilder internal constructor( visitEnd() } + File("dump.class").writeBytes(classWriter.toByteArray()) + val new = classLoader .defineClass(className, classWriter.toByteArray()) .constructors @@ -182,6 +200,11 @@ internal class AsmBuilder internal constructor( } internal fun loadTConstant(value: T) { + if (primitiveMode) { + loadNumberConstant(value as Number) + return + } + if (classOfT in INLINABLE_NUMBERS) { loadNumberConstant(value as Number) invokeMethodVisitor.visitCheckCast(T_CLASS) @@ -191,6 +214,25 @@ internal class AsmBuilder internal constructor( loadConstant(value as Any, T_CLASS) } + private fun unboxInPrimitiveMode() { + if (!primitiveMode) + return + + unboxNumberToPrimitive() + } + + private fun boxPrimitiveToNumber() { + invokeMethodVisitor.visitInvokeStatic(T_CLASS, "valueOf", "($primitiveTypeSig)L$T_CLASS;") + } + + private fun unboxNumberToPrimitive() { + invokeMethodVisitor.visitInvokeVirtual( + owner = NUMBER_CLASS, + name = NUMBER_CONVERTER_METHODS.getValue(primitiveType), + descriptor = "()${primitiveTypeSig}" + ) + } + private fun loadConstant(value: Any, type: String) { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex @@ -199,13 +241,14 @@ internal class AsmBuilder internal constructor( visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;") visitLdcOrIntConstant(idx) visitGetObjectArrayElement() - invokeMethodVisitor.visitCheckCast(type) + visitCheckCast(type) + unboxInPrimitiveMode() } } private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) - internal fun loadNumberConstant(value: Number) { + private fun loadNumberConstant(value: Number) { val clazz = value.javaClass val c = clazz.name.replace('.', '/') val sigLetter = SIGNATURE_LETTERS[clazz] @@ -218,7 +261,9 @@ internal class AsmBuilder internal constructor( else -> invokeMethodVisitor.visitLdcInsn(value) } - invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") + if (!primitiveMode) + invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") + return } @@ -251,6 +296,7 @@ internal class AsmBuilder internal constructor( ) invokeMethodVisitor.visitCheckCast(T_CLASS) + unboxInPrimitiveMode() } internal fun loadAlgebra() { @@ -274,6 +320,7 @@ internal class AsmBuilder internal constructor( ) { invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) invokeMethodVisitor.visitCheckCast(T_CLASS) + unboxInPrimitiveMode() } internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) @@ -290,11 +337,23 @@ internal class AsmBuilder internal constructor( ) } + private val NUMBER_CONVERTER_METHODS by lazy { + mapOf( + "B" to "byteValue", + "S" to "shortValue", + "I" to "intValue", + "J" to "longValue", + "F" to "floatValue", + "D" to "doubleValue" + ) + } + private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/asm/internal/AsmCompiledExpression" + internal const val NUMBER_CLASS = "java/lang/Number" internal const val MAP_CLASS = "java/util/Map" internal const val OBJECT_CLASS = "java/lang/Object" internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt similarity index 90% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index a39dba6b1..f279c94f4 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -24,9 +24,9 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri val sig = buildString { append('(') - repeat(arity) { append("L${AsmBuilder.OBJECT_CLASS};") } + repeat(arity) { append(primitiveTypeSig) } append(')') - append("L${AsmBuilder.OBJECT_CLASS};") + append(primitiveTypeReturnSig) } invokeAlgebraOperation( diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index a6cb1e247..294da32e9 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -1,10 +1,37 @@ package scietifik.kmath.asm -// -//class TestAsmAlgebras { +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInSpace +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import kotlin.test.Test +import kotlin.test.assertEquals + +class TestAsmAlgebras { + @Test + fun space() { + val res = ByteRing.mstInSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }.compile()("x" to 2.toByte()) + + assertEquals(16, res) + } + // @Test // fun space() { -// val res = ByteRing.asmInRing { +// val res = ByteRing.asm { // binaryOperation( // "+", // @@ -63,4 +90,4 @@ package scietifik.kmath.asm // // assertEquals(3.0, res) // } -//} +} From 9a3709624d5b415f5ae0bc9c7c3208a5dccb651e Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Wed, 24 Jun 2020 15:54:17 +0700 Subject: [PATCH 2/3] Use hashMap instead of map --- .../kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 77a4ee9e9..95c84ca51 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -387,7 +387,7 @@ internal class AsmBuilder internal constructor( * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. */ private val SIGNATURE_LETTERS: Map, String> by lazy { - mapOf( + hashMapOf( java.lang.Byte::class.java to "B", java.lang.Short::class.java to "S", java.lang.Integer::class.java to "I", @@ -397,8 +397,8 @@ internal class AsmBuilder internal constructor( ) } - private val NUMBER_CONVERTER_METHODS by lazy { - mapOf( + private val NUMBER_CONVERTER_METHODS: Map by lazy { + hashMapOf( "B" to "byteValue", "S" to "shortValue", "I" to "intValue", From 02f42ee56a8047af9ab7846b5e0b271084c9192a Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Wed, 24 Jun 2020 20:55:48 +0700 Subject: [PATCH 3/3] Eliminate bridging --- .../kotlin/scientifik/kmath/asm/asm.kt | 63 ++- .../kmath/asm/internal/AsmBuilder.kt | 366 ++++++++++-------- .../asm/internal/AsmCompiledExpression.kt | 1 - .../scientifik/kmath/asm/internal/classes.kt | 7 + .../kmath/asm/internal/methodVisitors.kt | 59 +-- .../kmath/asm/internal/specialization.kt | 34 +- .../scietifik/kmath/asm/TestAsmAlgebras.kt | 20 +- 7 files changed, 274 insertions(+), 276 deletions(-) create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index a09a7ebe3..a3af80ccd 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -1,8 +1,9 @@ package scientifik.kmath.asm +import org.objectweb.asm.Type import scientifik.kmath.asm.internal.AsmBuilder +import scientifik.kmath.asm.internal.buildExpectationStack import scientifik.kmath.asm.internal.buildName -import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.ast.MST import scientifik.kmath.ast.MSTExpression @@ -18,6 +19,7 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< fun AsmBuilder.visit(node: MST) { when (node) { is MST.Symbolic -> loadVariable(node.value) + is MST.Numeric -> { val constant = if (algebra is NumericAlgebra) algebra.number(node.value) @@ -26,48 +28,49 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< loadTConstant(constant) } + is MST.Unary -> { loadAlgebra() - - if (!hasSpecific(algebra, node.operation, 1)) loadStringConstant(node.operation) - + if (!buildExpectationStack(algebra, node.operation, 1)) loadStringConstant(node.operation) visit(node.value) - if (!tryInvokeSpecific(algebra, node.operation, 1)) { - invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_CLASS, - method = "unaryOperation", - descriptor = "(L${AsmBuilder.STRING_CLASS};" + - "L${AsmBuilder.OBJECT_CLASS};)" + - "L${AsmBuilder.OBJECT_CLASS};" - ) - } + if (!tryInvokeSpecific(algebra, node.operation, 1)) invokeAlgebraOperation( + owner = AsmBuilder.ALGEBRA_TYPE.internalName, + method = "unaryOperation", + + descriptor = Type.getMethodDescriptor( + AsmBuilder.OBJECT_TYPE, + AsmBuilder.STRING_TYPE, + AsmBuilder.OBJECT_TYPE + ), + + tArity = 1 + ) } is MST.Binary -> { loadAlgebra() - - if (!hasSpecific(algebra, node.operation, 2)) - loadStringConstant(node.operation) - + if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation) visit(node.left) visit(node.right) - if (!tryInvokeSpecific(algebra, node.operation, 2)) { + if (!tryInvokeSpecific(algebra, node.operation, 2)) invokeAlgebraOperation( + owner = AsmBuilder.ALGEBRA_TYPE.internalName, + method = "binaryOperation", - invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_CLASS, - method = "binaryOperation", - descriptor = "(L${AsmBuilder.STRING_CLASS};" + - "L${AsmBuilder.OBJECT_CLASS};" + - "L${AsmBuilder.OBJECT_CLASS};)" + - "L${AsmBuilder.OBJECT_CLASS};" - ) - } + descriptor = Type.getMethodDescriptor( + AsmBuilder.OBJECT_TYPE, + AsmBuilder.STRING_TYPE, + AsmBuilder.OBJECT_TYPE, + AsmBuilder.OBJECT_TYPE + ), + + tArity = 2 + ) } } } - return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() + return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() } /** @@ -79,7 +82,3 @@ inline fun Algebra.expression(mst: MST): Expression = ms * Optimize performance of an [MSTExpression] using ASM codegen */ inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) - -fun main() { - RealField.mstInField { symbol("x") + 2 }.compile() -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 95c84ca51..2c9cd48d9 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -1,14 +1,17 @@ package scientifik.kmath.asm.internal -import org.objectweb.asm.ClassWriter -import org.objectweb.asm.Label -import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Opcodes +import org.objectweb.asm.* +import org.objectweb.asm.Opcodes.AALOAD +import org.objectweb.asm.Opcodes.RETURN +import org.objectweb.asm.commons.InstructionAdapter import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader +import scientifik.kmath.ast.MST import scientifik.kmath.operations.Algebra +import java.util.* +import kotlin.reflect.KClass /** - * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. + * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * * @param T the type of AsmExpression to unwrap. @@ -16,7 +19,7 @@ import scientifik.kmath.operations.Algebra * @param className the unique class name of new loaded class. */ internal class AsmBuilder internal constructor( - private val classOfT: Class<*>, + private val classOfT: KClass<*>, private val algebra: Algebra, private val className: String, private val invokeLabel0Visitor: AsmBuilder.() -> Unit @@ -31,16 +34,16 @@ internal class AsmBuilder internal constructor( /** * The instance of [ClassLoader] used by this builder. */ - private val classLoader: ClassLoader = - ClassLoader(javaClass.classLoader) + private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) @Suppress("PrivatePropertyName") - private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') + private val T_ALGEBRA_TYPE: Type = algebra::class.asm @Suppress("PrivatePropertyName") - private val T_CLASS: String = classOfT.name.replace('.', '/') + internal val T_TYPE: Type = classOfT.asm - private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') + @Suppress("PrivatePropertyName") + private val CLASS_TYPE: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! /** * Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass. @@ -60,11 +63,16 @@ internal class AsmBuilder internal constructor( /** * Method visitor of `invoke` method of [AsmCompiledExpression] subclass. */ - private lateinit var invokeMethodVisitor: MethodVisitor - private var primitiveMode = false - internal var primitiveType = OBJECT_CLASS - internal var primitiveTypeSig = "L$OBJECT_CLASS;" - internal var primitiveTypeReturnSig = "L$OBJECT_CLASS;" + private lateinit var invokeMethodVisitor: InstructionAdapter + internal var primitiveMode = false + + @Suppress("PropertyName") + internal var PRIMITIVE_MASK: Type = OBJECT_TYPE + + @Suppress("PropertyName") + internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE + private val typeStack = Stack() + internal val expectationStack = Stack().apply { push(T_TYPE) } /** * The cache of [AsmCompiledExpression] subclass built by this builder. @@ -80,89 +88,85 @@ internal class AsmBuilder internal constructor( fun getInstance(): AsmCompiledExpression { generatedInstance?.let { return it } - if (classOfT in SIGNATURE_LETTERS) { + if (SIGNATURE_LETTERS.containsKey(classOfT.java)) { primitiveMode = true - primitiveTypeSig = SIGNATURE_LETTERS.getValue(classOfT) - primitiveType = SIGNATURE_LETTERS.getValue(classOfT) - primitiveTypeReturnSig = "L$T_CLASS;" + PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT.java) + PRIMITIVE_MASK_BOXED = T_TYPE } val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( Opcodes.V1_8, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, - slashesClassName, - "L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS;", - FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + CLASS_TYPE.internalName, + "L${ASM_COMPILED_EXPRESSION_TYPE.internalName}<${T_TYPE.descriptor}>;", + ASM_COMPILED_EXPRESSION_TYPE.internalName, arrayOf() ) visitMethod( - access = Opcodes.ACC_PUBLIC, - name = "", - descriptor = "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", - signature = null, - exceptions = null - ) { + Opcodes.ACC_PUBLIC, + "", + Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), + null, + null + ).instructionAdapter { val thisVar = 0 val algebraVar = 1 val constantsVar = 2 val l0 = Label() visitLabel(l0) - visitLoadObjectVar(thisVar) - visitLoadObjectVar(algebraVar) - visitLoadObjectVar(constantsVar) + load(thisVar, CLASS_TYPE) + load(algebraVar, ALGEBRA_TYPE) + load(constantsVar, OBJECT_ARRAY_TYPE) - visitInvokeSpecial( - FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + invokespecial( + ASM_COMPILED_EXPRESSION_TYPE.internalName, "", - "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V" + Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), + false ) val l1 = Label() visitLabel(l1) - visitReturn() + visitInsn(RETURN) val l2 = Label() visitLabel(l2) - visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) + visitLocalVariable("this", CLASS_TYPE.descriptor, null, l0, l2, thisVar) visitLocalVariable( "algebra", - "L$ALGEBRA_CLASS;", - "L$ALGEBRA_CLASS;", + ALGEBRA_TYPE.descriptor, + "L${ALGEBRA_TYPE.internalName}<${T_TYPE.descriptor}>;", l0, l2, algebraVar ) - visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l2, constantsVar) visitMaxs(0, 3) visitEnd() } visitMethod( - access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, - name = "invoke", - descriptor = "(L$MAP_CLASS;)L$T_CLASS;", - signature = "(L$MAP_CLASS;)L$T_CLASS;", - exceptions = null - ) { + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, + "invoke", + Type.getMethodDescriptor(T_TYPE, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;)${T_TYPE.descriptor}", + null + ).instructionAdapter { invokeMethodVisitor = this visitCode() val l0 = Label() visitLabel(l0) invokeLabel0Visitor() - - if (primitiveMode) - boxPrimitiveToNumber() - - visitReturnObject() + areturn(T_TYPE) val l1 = Label() visitLabel(l1) visitLocalVariable( "this", - "L$slashesClassName;", + CLASS_TYPE.descriptor, null, l0, l1, @@ -171,8 +175,8 @@ internal class AsmBuilder internal constructor( visitLocalVariable( "arguments", - "L$MAP_CLASS;", - "L$MAP_CLASS;", + MAP_TYPE.descriptor, + "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;", l0, l1, invokeArgumentsVar @@ -183,28 +187,28 @@ internal class AsmBuilder internal constructor( } visitMethod( - access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, - name = "invoke", - descriptor = "(L$MAP_CLASS;)L$OBJECT_CLASS;", - signature = null, - exceptions = null - ) { + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + "invoke", + Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), + null, + null + ).instructionAdapter { val thisVar = 0 val argumentsVar = 1 visitCode() val l0 = Label() visitLabel(l0) - visitLoadObjectVar(thisVar) - visitLoadObjectVar(argumentsVar) - visitInvokeVirtual(owner = slashesClassName, name = "invoke", descriptor = "(L$MAP_CLASS;)L$T_CLASS;") - visitReturnObject() + load(thisVar, OBJECT_TYPE) + load(argumentsVar, MAP_TYPE) + invokevirtual(CLASS_TYPE.internalName, "invoke", Type.getMethodDescriptor(T_TYPE, MAP_TYPE), false) + areturn(T_TYPE) val l1 = Label() visitLabel(l1) visitLocalVariable( "this", - "L$slashesClassName;", - T_CLASS, + CLASS_TYPE.descriptor, + null, l0, l1, thisVar @@ -231,116 +235,111 @@ internal class AsmBuilder internal constructor( * Loads a constant from */ internal fun loadTConstant(value: T) { - if (primitiveMode) { - loadNumberConstant(value as Number) + if (classOfT.java in INLINABLE_NUMBERS) { + val expectedType = expectationStack.pop()!! + val mustBeBoxed = expectedType.sort == Type.OBJECT + loadNumberConstant(value as Number, mustBeBoxed) + if (mustBeBoxed) typeStack.push(T_TYPE) else typeStack.push(PRIMITIVE_MASK) return } - if (classOfT in INLINABLE_NUMBERS) { - loadNumberConstant(value as Number) - invokeMethodVisitor.visitCheckCast(T_CLASS) - return - } - - loadConstant(value as Any, T_CLASS) + loadConstant(value as Any, T_TYPE) } - /** - * Loads an object constant [value] stored in [AsmCompiledExpression.constants] and casts it to [type]. - */ - private fun unboxInPrimitiveMode() { - if (!primitiveMode) - return + private fun box(): Unit = invokeMethodVisitor.invokestatic( + T_TYPE.internalName, + "valueOf", + Type.getMethodDescriptor(T_TYPE, PRIMITIVE_MASK), + false + ) - unboxNumberToPrimitive() - } + private fun unbox(): Unit = invokeMethodVisitor.invokevirtual( + NUMBER_TYPE.internalName, + NUMBER_CONVERTER_METHODS.getValue(PRIMITIVE_MASK), + Type.getMethodDescriptor(PRIMITIVE_MASK), + false + ) - private fun boxPrimitiveToNumber() { - invokeMethodVisitor.visitInvokeStatic(T_CLASS, "valueOf", "($primitiveTypeSig)L$T_CLASS;") - } - - private fun unboxNumberToPrimitive() { - invokeMethodVisitor.visitInvokeVirtual( - owner = NUMBER_CLASS, - name = NUMBER_CONVERTER_METHODS.getValue(primitiveType), - descriptor = "()${primitiveTypeSig}" - ) - } - - private fun loadConstant(value: Any, type: String) { + private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex - - invokeMethodVisitor.run { - loadThis() - visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;") - visitLdcOrIntConstant(idx) - visitGetObjectArrayElement() - invokeMethodVisitor.visitCheckCast(type) - } + loadThis() + getfield(CLASS_TYPE.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + iconst(idx) + visitInsn(AALOAD) + checkcast(type) } - private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) + private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, CLASS_TYPE) /** - * Either loads a numeric constant [value] from [AsmCompiledExpression.constants] field or boxes a primitive + * Either loads a numeric constant [value] from [AsmCompiledExpression] constants field or boxes a primitive * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded * from it). */ - private fun loadNumberConstant(value: Number) { - val clazz = value.javaClass - val c = clazz.name.replace('.', '/') - val sigLetter = SIGNATURE_LETTERS[clazz] + private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) { + val boxed = value::class.asm + val primitive = BOXED_TO_PRIMITIVES[boxed] - if (sigLetter != null) { - when (value) { - is Byte -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) - is Short -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) - is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value) - is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value) - is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value) - is Long -> invokeMethodVisitor.visitLdcOrLongConstant(value) - else -> invokeMethodVisitor.visitLdcInsn(value) + if (primitive != null) { + when (primitive) { + Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) + Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) + Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) + Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) } - if (!primitiveMode) - invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") + if (mustBeBoxed) { + box() + invokeMethodVisitor.checkcast(T_TYPE) + } return } - loadConstant(value, c) + loadConstant(value, boxed) + if (!mustBeBoxed) unbox() + else invokeMethodVisitor.checkcast(T_TYPE) } /** * Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { - visitLoadObjectVar(invokeArgumentsVar) + load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) if (defaultValue != null) { - visitLdcInsn(name) + loadStringConstant(name) loadTConstant(defaultValue) - visitInvokeInterface( - owner = MAP_CLASS, - name = "getOrDefault", - descriptor = "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;" + invokeinterface( + MAP_TYPE.internalName, + "getOrDefault", + Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.visitCheckCast(T_CLASS) + invokeMethodVisitor.checkcast(T_TYPE) return } - visitLdcInsn(name) + loadStringConstant(name) - visitInvokeInterface( - owner = MAP_CLASS, - name = "get", - descriptor = "(L$OBJECT_CLASS;)L$OBJECT_CLASS;" + invokeinterface( + MAP_TYPE.internalName, + "get", + Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.visitCheckCast(T_CLASS) - unboxInPrimitiveMode() + invokeMethodVisitor.checkcast(T_TYPE) + val expectedType = expectationStack.pop()!! + + if (expectedType.sort == Type.OBJECT) + typeStack.push(T_TYPE) + else { + unbox() + typeStack.push(PRIMITIVE_MASK) + } } /** @@ -349,13 +348,10 @@ internal class AsmBuilder internal constructor( internal fun loadAlgebra() { loadThis() - invokeMethodVisitor.visitGetField( - owner = FUNCTIONAL_COMPILED_EXPRESSION_CLASS, - name = "algebra", - descriptor = "L$ALGEBRA_CLASS;" - ) - - invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS) + invokeMethodVisitor.run { + getfield(ASM_COMPILED_EXPRESSION_TYPE.internalName, "algebra", ALGEBRA_TYPE.descriptor) + checkcast(T_ALGEBRA_TYPE) + } } /** @@ -369,42 +365,70 @@ internal class AsmBuilder internal constructor( owner: String, method: String, descriptor: String, - opcode: Int = Opcodes.INVOKEINTERFACE, - isInterface: Boolean = true + tArity: Int, + opcode: Int = Opcodes.INVOKEINTERFACE ) { - invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) - invokeMethodVisitor.visitCheckCast(T_CLASS) - unboxInPrimitiveMode() + repeat(tArity) { typeStack.pop() } + + invokeMethodVisitor.visitMethodInsn( + opcode, + owner, + method, + descriptor, + opcode == Opcodes.INVOKEINTERFACE + ) + + invokeMethodVisitor.checkcast(T_TYPE) + val isLastExpr = expectationStack.size == 1 + val expectedType = expectationStack.pop()!! + + if (expectedType.sort == Type.OBJECT || isLastExpr) + typeStack.push(T_TYPE) + else { + unbox() + typeStack.push(PRIMITIVE_MASK) + } } /** * Writes a LDC Instruction with string constant provided. */ - internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) + internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string) internal companion object { /** * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. */ - private val SIGNATURE_LETTERS: Map, String> by lazy { + private val SIGNATURE_LETTERS: Map, Type> by lazy { hashMapOf( - java.lang.Byte::class.java to "B", - java.lang.Short::class.java to "S", - java.lang.Integer::class.java to "I", - java.lang.Long::class.java to "J", - java.lang.Float::class.java to "F", - java.lang.Double::class.java to "D" + java.lang.Byte::class.java to Type.BYTE_TYPE, + java.lang.Short::class.java to Type.SHORT_TYPE, + java.lang.Integer::class.java to Type.INT_TYPE, + java.lang.Long::class.java to Type.LONG_TYPE, + java.lang.Float::class.java to Type.FLOAT_TYPE, + java.lang.Double::class.java to Type.DOUBLE_TYPE ) } - private val NUMBER_CONVERTER_METHODS: Map by lazy { + private val BOXED_TO_PRIMITIVES: Map by lazy { hashMapOf( - "B" to "byteValue", - "S" to "shortValue", - "I" to "intValue", - "J" to "longValue", - "F" to "floatValue", - "D" to "doubleValue" + java.lang.Byte::class.asm to Type.BYTE_TYPE, + java.lang.Short::class.asm to Type.SHORT_TYPE, + java.lang.Integer::class.asm to Type.INT_TYPE, + java.lang.Long::class.asm to Type.LONG_TYPE, + java.lang.Float::class.asm to Type.FLOAT_TYPE, + java.lang.Double::class.asm to Type.DOUBLE_TYPE + ) + } + + private val NUMBER_CONVERTER_METHODS: Map by lazy { + hashMapOf( + Type.BYTE_TYPE to "byteValue", + Type.SHORT_TYPE to "shortValue", + Type.INT_TYPE to "intValue", + Type.LONG_TYPE to "longValue", + Type.FLOAT_TYPE to "floatValue", + Type.DOUBLE_TYPE to "doubleValue" ) } @@ -412,14 +436,14 @@ internal class AsmBuilder internal constructor( * Provides boxed number types values of which can be stored in JVM bytecode constant pool. */ private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + internal val ASM_COMPILED_EXPRESSION_TYPE: Type = AsmCompiledExpression::class.asm + internal val NUMBER_TYPE: Type = java.lang.Number::class.asm + internal val MAP_TYPE: Type = java.util.Map::class.asm + internal val OBJECT_TYPE: Type = java.lang.Object::class.asm - internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = - "scientifik/kmath/asm/internal/AsmCompiledExpression" - - internal const val NUMBER_CLASS = "java/lang/Number" - internal const val MAP_CLASS = "java/util/Map" - internal const val OBJECT_CLASS = "java/lang/Object" - internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" - internal const val STRING_CLASS = "java/lang/String" + @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") + internal val OBJECT_ARRAY_TYPE: Type = Array::class.asm + internal val ALGEBRA_TYPE: Type = Algebra::class.asm + internal val STRING_TYPE: Type = java.lang.String::class.asm } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt index a5236927c..7c4a9fc99 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt @@ -16,4 +16,3 @@ internal abstract class AsmCompiledExpression internal constructor( ) : Expression { abstract override fun invoke(arguments: Map): T } - diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt new file mode 100644 index 000000000..dc0b35531 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classes.kt @@ -0,0 +1,7 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.Type +import kotlin.reflect.KClass + +internal val KClass<*>.asm: Type + get() = Type.getType(java) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt index 6ecce8833..7b0d346b7 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt @@ -1,60 +1,9 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Opcodes.* +import org.objectweb.asm.commons.InstructionAdapter -internal fun MethodVisitor.visitLdcOrIntConstant(value: Int): Unit = when (value) { - -1 -> visitInsn(ICONST_M1) - 0 -> visitInsn(ICONST_0) - 1 -> visitInsn(ICONST_1) - 2 -> visitInsn(ICONST_2) - 3 -> visitInsn(ICONST_3) - 4 -> visitInsn(ICONST_4) - 5 -> visitInsn(ICONST_5) - in -128..127 -> visitIntInsn(BIPUSH, value) - in -32768..32767 -> visitIntInsn(SIPUSH, value) - else -> visitLdcInsn(value) -} +fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) -internal fun MethodVisitor.visitLdcOrDoubleConstant(value: Double): Unit = when (value) { - 0.0 -> visitInsn(DCONST_0) - 1.0 -> visitInsn(DCONST_1) - else -> visitLdcInsn(value) -} - -internal fun MethodVisitor.visitLdcOrLongConstant(value: Long): Unit = when (value) { - 0L -> visitInsn(LCONST_0) - 1L -> visitInsn(LCONST_1) - else -> visitLdcInsn(value) -} - -internal fun MethodVisitor.visitLdcOrFloatConstant(value: Float): Unit = when (value) { - 0f -> visitInsn(FCONST_0) - 1f -> visitInsn(FCONST_1) - 2f -> visitInsn(FCONST_2) - else -> visitLdcInsn(value) -} - -internal fun MethodVisitor.visitInvokeInterface(owner: String, name: String, descriptor: String): Unit = - visitMethodInsn(INVOKEINTERFACE, owner, name, descriptor, true) - -internal fun MethodVisitor.visitInvokeVirtual(owner: String, name: String, descriptor: String): Unit = - visitMethodInsn(INVOKEVIRTUAL, owner, name, descriptor, false) - -internal fun MethodVisitor.visitInvokeStatic(owner: String, name: String, descriptor: String): Unit = - visitMethodInsn(INVOKESTATIC, owner, name, descriptor, false) - -internal fun MethodVisitor.visitInvokeSpecial(owner: String, name: String, descriptor: String): Unit = - visitMethodInsn(INVOKESPECIAL, owner, name, descriptor, false) - -internal fun MethodVisitor.visitCheckCast(type: String): Unit = visitTypeInsn(CHECKCAST, type) - -internal fun MethodVisitor.visitGetField(owner: String, name: String, descriptor: String): Unit = - visitFieldInsn(GETFIELD, owner, name, descriptor) - -internal fun MethodVisitor.visitLoadObjectVar(`var`: Int): Unit = visitVarInsn(ALOAD, `var`) - -internal fun MethodVisitor.visitGetObjectArrayElement(): Unit = visitInsn(AALOAD) - -internal fun MethodVisitor.visitReturn(): Unit = visitInsn(RETURN) -internal fun MethodVisitor.visitReturnObject(): Unit = visitInsn(ARETURN) +fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = + instructionAdapter().apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index 51d056f8c..2e15a1a93 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -1,6 +1,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.Opcodes +import org.objectweb.asm.Type import scientifik.kmath.operations.Algebra private val methodNameAdapters: Map by lazy { @@ -12,17 +13,19 @@ private val methodNameAdapters: Map by lazy { } /** - * Checks if the target [context] for code generation contains a method with needed [name] and [arity]. + * Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds + * type expectation stack for needed arity. * * @return `true` if contains, else `false`. */ -internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { +internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name - context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } - ?: return false + val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null + val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else T_TYPE + repeat(arity) { expectationStack.push(t) } - return true + return hasSpecific } /** @@ -34,22 +37,23 @@ internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boo internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name - context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } - ?: return false + val method = + context.javaClass.methods.find { + var suitableSignature = it.name == aName && it.parameters.size == arity + + if (primitiveMode && it.isBridge) + suitableSignature = false + + suitableSignature + } ?: return false val owner = context::class.java.name.replace('.', '/') - val sig = buildString { - append('(') - repeat(arity) { append(primitiveTypeSig) } - append(')') - append(primitiveTypeReturnSig) - } - invokeAlgebraOperation( owner = owner, method = aName, - descriptor = sig, + descriptor = Type.getMethodDescriptor(PRIMITIVE_MASK_BOXED, *Array(arity) { PRIMITIVE_MASK }), + tArity = arity, opcode = Opcodes.INVOKEVIRTUAL ) diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 294da32e9..ce5b7cbf9 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -10,7 +10,23 @@ import kotlin.test.assertEquals class TestAsmAlgebras { @Test fun space() { - val res = ByteRing.mstInSpace { + val res1 = ByteRing.mstInSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }("x" to 2.toByte()) + + val res2 = ByteRing.mstInSpace { binaryOperation( "+", @@ -26,7 +42,7 @@ class TestAsmAlgebras { ) + symbol("x") + zero }.compile()("x" to 2.toByte()) - assertEquals(16, res) + assertEquals(res1, res2) } // @Test