Eliminate bridging

This commit is contained in:
Iaroslav 2020-06-24 20:55:48 +07:00
parent 9a3709624d
commit 02f42ee56a
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
7 changed files with 274 additions and 276 deletions

View File

@ -1,8 +1,9 @@
package scientifik.kmath.asm package scientifik.kmath.asm
import org.objectweb.asm.Type
import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.buildExpectationStack
import scientifik.kmath.asm.internal.buildName import scientifik.kmath.asm.internal.buildName
import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.ast.MST import scientifik.kmath.ast.MST
import scientifik.kmath.ast.MSTExpression import scientifik.kmath.ast.MSTExpression
@ -18,6 +19,7 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
fun AsmBuilder<T>.visit(node: MST) { fun AsmBuilder<T>.visit(node: MST) {
when (node) { when (node) {
is MST.Symbolic -> loadVariable(node.value) is MST.Symbolic -> loadVariable(node.value)
is MST.Numeric -> { is MST.Numeric -> {
val constant = if (algebra is NumericAlgebra<T>) val constant = if (algebra is NumericAlgebra<T>)
algebra.number(node.value) algebra.number(node.value)
@ -26,48 +28,49 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
loadTConstant(constant) loadTConstant(constant)
} }
is MST.Unary -> { is MST.Unary -> {
loadAlgebra() loadAlgebra()
if (!buildExpectationStack(algebra, node.operation, 1)) loadStringConstant(node.operation)
if (!hasSpecific(algebra, node.operation, 1)) loadStringConstant(node.operation)
visit(node.value) visit(node.value)
if (!tryInvokeSpecific(algebra, node.operation, 1)) { if (!tryInvokeSpecific(algebra, node.operation, 1)) invokeAlgebraOperation(
invokeAlgebraOperation( owner = AsmBuilder.ALGEBRA_TYPE.internalName,
owner = AsmBuilder.ALGEBRA_CLASS,
method = "unaryOperation", method = "unaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" + descriptor = Type.getMethodDescriptor(
"L${AsmBuilder.OBJECT_CLASS};" AsmBuilder.OBJECT_TYPE,
AsmBuilder.STRING_TYPE,
AsmBuilder.OBJECT_TYPE
),
tArity = 1
) )
} }
}
is MST.Binary -> { is MST.Binary -> {
loadAlgebra() loadAlgebra()
if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation)
if (!hasSpecific(algebra, node.operation, 2))
loadStringConstant(node.operation)
visit(node.left) visit(node.left)
visit(node.right) visit(node.right)
if (!tryInvokeSpecific(algebra, node.operation, 2)) { if (!tryInvokeSpecific(algebra, node.operation, 2)) invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "binaryOperation", method = "binaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};" + descriptor = Type.getMethodDescriptor(
"L${AsmBuilder.OBJECT_CLASS};)" + AsmBuilder.OBJECT_TYPE,
"L${AsmBuilder.OBJECT_CLASS};" AsmBuilder.STRING_TYPE,
AsmBuilder.OBJECT_TYPE,
AsmBuilder.OBJECT_TYPE
),
tArity = 2
) )
} }
} }
} }
}
return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
} }
/** /**
@ -79,7 +82,3 @@ inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = ms
* Optimize performance of an [MSTExpression] using ASM codegen * Optimize performance of an [MSTExpression] using ASM codegen
*/ */
inline fun <reified T : Any> MSTExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra) inline fun <reified T : Any> MSTExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
fun main() {
RealField.mstInField { symbol("x") + 2 }.compile()
}

View File

@ -1,14 +1,17 @@
package scientifik.kmath.asm.internal package scientifik.kmath.asm.internal
import org.objectweb.asm.ClassWriter import org.objectweb.asm.*
import org.objectweb.asm.Label import org.objectweb.asm.Opcodes.AALOAD
import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.RETURN
import org.objectweb.asm.Opcodes import org.objectweb.asm.commons.InstructionAdapter
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
import scientifik.kmath.ast.MST
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import java.util.*
import kotlin.reflect.KClass
/** /**
* ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
* *
* @param T the type of AsmExpression to unwrap. * @param T the type of AsmExpression to unwrap.
@ -16,7 +19,7 @@ import scientifik.kmath.operations.Algebra
* @param className the unique class name of new loaded class. * @param className the unique class name of new loaded class.
*/ */
internal class AsmBuilder<T> internal constructor( internal class AsmBuilder<T> internal constructor(
private val classOfT: Class<*>, private val classOfT: KClass<*>,
private val algebra: Algebra<T>, private val algebra: Algebra<T>,
private val className: String, private val className: String,
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
@ -31,16 +34,16 @@ internal class AsmBuilder<T> internal constructor(
/** /**
* The instance of [ClassLoader] used by this builder. * The instance of [ClassLoader] used by this builder.
*/ */
private val classLoader: ClassLoader = private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
ClassLoader(javaClass.classLoader)
@Suppress("PrivatePropertyName") @Suppress("PrivatePropertyName")
private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') private val T_ALGEBRA_TYPE: Type = algebra::class.asm
@Suppress("PrivatePropertyName") @Suppress("PrivatePropertyName")
private val T_CLASS: String = classOfT.name.replace('.', '/') internal val T_TYPE: Type = classOfT.asm
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') @Suppress("PrivatePropertyName")
private val CLASS_TYPE: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
/** /**
* Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass. * Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass.
@ -60,11 +63,16 @@ internal class AsmBuilder<T> internal constructor(
/** /**
* Method visitor of `invoke` method of [AsmCompiledExpression] subclass. * Method visitor of `invoke` method of [AsmCompiledExpression] subclass.
*/ */
private lateinit var invokeMethodVisitor: MethodVisitor private lateinit var invokeMethodVisitor: InstructionAdapter
private var primitiveMode = false internal var primitiveMode = false
internal var primitiveType = OBJECT_CLASS
internal var primitiveTypeSig = "L$OBJECT_CLASS;" @Suppress("PropertyName")
internal var primitiveTypeReturnSig = "L$OBJECT_CLASS;" internal var PRIMITIVE_MASK: Type = OBJECT_TYPE
@Suppress("PropertyName")
internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE
private val typeStack = Stack<Type>()
internal val expectationStack = Stack<Type>().apply { push(T_TYPE) }
/** /**
* The cache of [AsmCompiledExpression] subclass built by this builder. * The cache of [AsmCompiledExpression] subclass built by this builder.
@ -80,89 +88,85 @@ internal class AsmBuilder<T> internal constructor(
fun getInstance(): AsmCompiledExpression<T> { fun getInstance(): AsmCompiledExpression<T> {
generatedInstance?.let { return it } generatedInstance?.let { return it }
if (classOfT in SIGNATURE_LETTERS) { if (SIGNATURE_LETTERS.containsKey(classOfT.java)) {
primitiveMode = true primitiveMode = true
primitiveTypeSig = SIGNATURE_LETTERS.getValue(classOfT) PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT.java)
primitiveType = SIGNATURE_LETTERS.getValue(classOfT) PRIMITIVE_MASK_BOXED = T_TYPE
primitiveTypeReturnSig = "L$T_CLASS;"
} }
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit( visit(
Opcodes.V1_8, Opcodes.V1_8,
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
slashesClassName, CLASS_TYPE.internalName,
"L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS<L$T_CLASS;>;", "L${ASM_COMPILED_EXPRESSION_TYPE.internalName}<${T_TYPE.descriptor}>;",
FUNCTIONAL_COMPILED_EXPRESSION_CLASS, ASM_COMPILED_EXPRESSION_TYPE.internalName,
arrayOf() arrayOf()
) )
visitMethod( visitMethod(
access = Opcodes.ACC_PUBLIC, Opcodes.ACC_PUBLIC,
name = "<init>", "<init>",
descriptor = "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE),
signature = null, null,
exceptions = null null
) { ).instructionAdapter {
val thisVar = 0 val thisVar = 0
val algebraVar = 1 val algebraVar = 1
val constantsVar = 2 val constantsVar = 2
val l0 = Label() val l0 = Label()
visitLabel(l0) visitLabel(l0)
visitLoadObjectVar(thisVar) load(thisVar, CLASS_TYPE)
visitLoadObjectVar(algebraVar) load(algebraVar, ALGEBRA_TYPE)
visitLoadObjectVar(constantsVar) load(constantsVar, OBJECT_ARRAY_TYPE)
visitInvokeSpecial( invokespecial(
FUNCTIONAL_COMPILED_EXPRESSION_CLASS, ASM_COMPILED_EXPRESSION_TYPE.internalName,
"<init>", "<init>",
"(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V" Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE),
false
) )
val l1 = Label() val l1 = Label()
visitLabel(l1) visitLabel(l1)
visitReturn() visitInsn(RETURN)
val l2 = Label() val l2 = Label()
visitLabel(l2) visitLabel(l2)
visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) visitLocalVariable("this", CLASS_TYPE.descriptor, null, l0, l2, thisVar)
visitLocalVariable( visitLocalVariable(
"algebra", "algebra",
"L$ALGEBRA_CLASS;", ALGEBRA_TYPE.descriptor,
"L$ALGEBRA_CLASS<L$T_CLASS;>;", "L${ALGEBRA_TYPE.internalName}<${T_TYPE.descriptor}>;",
l0, l0,
l2, l2,
algebraVar algebraVar
) )
visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar) visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l2, constantsVar)
visitMaxs(0, 3) visitMaxs(0, 3)
visitEnd() visitEnd()
} }
visitMethod( visitMethod(
access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
name = "invoke", "invoke",
descriptor = "(L$MAP_CLASS;)L$T_CLASS;", Type.getMethodDescriptor(T_TYPE, MAP_TYPE),
signature = "(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;", "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;)${T_TYPE.descriptor}",
exceptions = null null
) { ).instructionAdapter {
invokeMethodVisitor = this invokeMethodVisitor = this
visitCode() visitCode()
val l0 = Label() val l0 = Label()
visitLabel(l0) visitLabel(l0)
invokeLabel0Visitor() invokeLabel0Visitor()
areturn(T_TYPE)
if (primitiveMode)
boxPrimitiveToNumber()
visitReturnObject()
val l1 = Label() val l1 = Label()
visitLabel(l1) visitLabel(l1)
visitLocalVariable( visitLocalVariable(
"this", "this",
"L$slashesClassName;", CLASS_TYPE.descriptor,
null, null,
l0, l0,
l1, l1,
@ -171,8 +175,8 @@ internal class AsmBuilder<T> internal constructor(
visitLocalVariable( visitLocalVariable(
"arguments", "arguments",
"L$MAP_CLASS;", MAP_TYPE.descriptor,
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;", "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;",
l0, l0,
l1, l1,
invokeArgumentsVar invokeArgumentsVar
@ -183,28 +187,28 @@ internal class AsmBuilder<T> internal constructor(
} }
visitMethod( visitMethod(
access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
name = "invoke", "invoke",
descriptor = "(L$MAP_CLASS;)L$OBJECT_CLASS;", Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
signature = null, null,
exceptions = null null
) { ).instructionAdapter {
val thisVar = 0 val thisVar = 0
val argumentsVar = 1 val argumentsVar = 1
visitCode() visitCode()
val l0 = Label() val l0 = Label()
visitLabel(l0) visitLabel(l0)
visitLoadObjectVar(thisVar) load(thisVar, OBJECT_TYPE)
visitLoadObjectVar(argumentsVar) load(argumentsVar, MAP_TYPE)
visitInvokeVirtual(owner = slashesClassName, name = "invoke", descriptor = "(L$MAP_CLASS;)L$T_CLASS;") invokevirtual(CLASS_TYPE.internalName, "invoke", Type.getMethodDescriptor(T_TYPE, MAP_TYPE), false)
visitReturnObject() areturn(T_TYPE)
val l1 = Label() val l1 = Label()
visitLabel(l1) visitLabel(l1)
visitLocalVariable( visitLocalVariable(
"this", "this",
"L$slashesClassName;", CLASS_TYPE.descriptor,
T_CLASS, null,
l0, l0,
l1, l1,
thisVar thisVar
@ -231,116 +235,111 @@ internal class AsmBuilder<T> internal constructor(
* Loads a constant from * Loads a constant from
*/ */
internal fun loadTConstant(value: T) { internal fun loadTConstant(value: T) {
if (primitiveMode) { if (classOfT.java in INLINABLE_NUMBERS) {
loadNumberConstant(value as Number) val expectedType = expectationStack.pop()!!
val mustBeBoxed = expectedType.sort == Type.OBJECT
loadNumberConstant(value as Number, mustBeBoxed)
if (mustBeBoxed) typeStack.push(T_TYPE) else typeStack.push(PRIMITIVE_MASK)
return return
} }
if (classOfT in INLINABLE_NUMBERS) { loadConstant(value as Any, T_TYPE)
loadNumberConstant(value as Number)
invokeMethodVisitor.visitCheckCast(T_CLASS)
return
} }
loadConstant(value as Any, T_CLASS) private fun box(): Unit = invokeMethodVisitor.invokestatic(
} T_TYPE.internalName,
"valueOf",
/** Type.getMethodDescriptor(T_TYPE, PRIMITIVE_MASK),
* Loads an object constant [value] stored in [AsmCompiledExpression.constants] and casts it to [type]. false
*/
private fun unboxInPrimitiveMode() {
if (!primitiveMode)
return
unboxNumberToPrimitive()
}
private fun boxPrimitiveToNumber() {
invokeMethodVisitor.visitInvokeStatic(T_CLASS, "valueOf", "($primitiveTypeSig)L$T_CLASS;")
}
private fun unboxNumberToPrimitive() {
invokeMethodVisitor.visitInvokeVirtual(
owner = NUMBER_CLASS,
name = NUMBER_CONVERTER_METHODS.getValue(primitiveType),
descriptor = "()${primitiveTypeSig}"
) )
}
private fun loadConstant(value: Any, type: String) { private fun unbox(): Unit = invokeMethodVisitor.invokevirtual(
NUMBER_TYPE.internalName,
NUMBER_CONVERTER_METHODS.getValue(PRIMITIVE_MASK),
Type.getMethodDescriptor(PRIMITIVE_MASK),
false
)
private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
invokeMethodVisitor.run {
loadThis() loadThis()
visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;") getfield(CLASS_TYPE.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
visitLdcOrIntConstant(idx) iconst(idx)
visitGetObjectArrayElement() visitInsn(AALOAD)
invokeMethodVisitor.visitCheckCast(type) checkcast(type)
}
} }
private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, CLASS_TYPE)
/** /**
* Either loads a numeric constant [value] from [AsmCompiledExpression.constants] field or boxes a primitive * Either loads a numeric constant [value] from [AsmCompiledExpression] constants field or boxes a primitive
* constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded
* from it). * from it).
*/ */
private fun loadNumberConstant(value: Number) { private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
val clazz = value.javaClass val boxed = value::class.asm
val c = clazz.name.replace('.', '/') val primitive = BOXED_TO_PRIMITIVES[boxed]
val sigLetter = SIGNATURE_LETTERS[clazz]
if (sigLetter != null) { if (primitive != null) {
when (value) { when (primitive) {
is Byte -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
is Short -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value) Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value) Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value) Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
is Long -> invokeMethodVisitor.visitLdcOrLongConstant(value) Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
else -> invokeMethodVisitor.visitLdcInsn(value)
} }
if (!primitiveMode) if (mustBeBoxed) {
invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") box()
invokeMethodVisitor.checkcast(T_TYPE)
}
return return
} }
loadConstant(value, c) loadConstant(value, boxed)
if (!mustBeBoxed) unbox()
else invokeMethodVisitor.checkcast(T_TYPE)
} }
/** /**
* Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided. * Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided.
*/ */
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
visitLoadObjectVar(invokeArgumentsVar) load(invokeArgumentsVar, OBJECT_ARRAY_TYPE)
if (defaultValue != null) { if (defaultValue != null) {
visitLdcInsn(name) loadStringConstant(name)
loadTConstant(defaultValue) loadTConstant(defaultValue)
visitInvokeInterface( invokeinterface(
owner = MAP_CLASS, MAP_TYPE.internalName,
name = "getOrDefault", "getOrDefault",
descriptor = "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;" Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE)
) )
invokeMethodVisitor.visitCheckCast(T_CLASS) invokeMethodVisitor.checkcast(T_TYPE)
return return
} }
visitLdcInsn(name) loadStringConstant(name)
visitInvokeInterface( invokeinterface(
owner = MAP_CLASS, MAP_TYPE.internalName,
name = "get", "get",
descriptor = "(L$OBJECT_CLASS;)L$OBJECT_CLASS;" Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE)
) )
invokeMethodVisitor.visitCheckCast(T_CLASS) invokeMethodVisitor.checkcast(T_TYPE)
unboxInPrimitiveMode() val expectedType = expectationStack.pop()!!
if (expectedType.sort == Type.OBJECT)
typeStack.push(T_TYPE)
else {
unbox()
typeStack.push(PRIMITIVE_MASK)
}
} }
/** /**
@ -349,13 +348,10 @@ internal class AsmBuilder<T> internal constructor(
internal fun loadAlgebra() { internal fun loadAlgebra() {
loadThis() loadThis()
invokeMethodVisitor.visitGetField( invokeMethodVisitor.run {
owner = FUNCTIONAL_COMPILED_EXPRESSION_CLASS, getfield(ASM_COMPILED_EXPRESSION_TYPE.internalName, "algebra", ALGEBRA_TYPE.descriptor)
name = "algebra", checkcast(T_ALGEBRA_TYPE)
descriptor = "L$ALGEBRA_CLASS;" }
)
invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS)
} }
/** /**
@ -369,42 +365,70 @@ internal class AsmBuilder<T> internal constructor(
owner: String, owner: String,
method: String, method: String,
descriptor: String, descriptor: String,
opcode: Int = Opcodes.INVOKEINTERFACE, tArity: Int,
isInterface: Boolean = true opcode: Int = Opcodes.INVOKEINTERFACE
) { ) {
invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) repeat(tArity) { typeStack.pop() }
invokeMethodVisitor.visitCheckCast(T_CLASS)
unboxInPrimitiveMode() invokeMethodVisitor.visitMethodInsn(
opcode,
owner,
method,
descriptor,
opcode == Opcodes.INVOKEINTERFACE
)
invokeMethodVisitor.checkcast(T_TYPE)
val isLastExpr = expectationStack.size == 1
val expectedType = expectationStack.pop()!!
if (expectedType.sort == Type.OBJECT || isLastExpr)
typeStack.push(T_TYPE)
else {
unbox()
typeStack.push(PRIMITIVE_MASK)
}
} }
/** /**
* Writes a LDC Instruction with string constant provided. * Writes a LDC Instruction with string constant provided.
*/ */
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
internal companion object { internal companion object {
/** /**
* Maps JVM primitive numbers boxed types to their letters of JVM signature convention. * Maps JVM primitive numbers boxed types to their letters of JVM signature convention.
*/ */
private val SIGNATURE_LETTERS: Map<Class<out Any>, String> by lazy { private val SIGNATURE_LETTERS: Map<Class<out Any>, Type> by lazy {
hashMapOf( hashMapOf(
java.lang.Byte::class.java to "B", java.lang.Byte::class.java to Type.BYTE_TYPE,
java.lang.Short::class.java to "S", java.lang.Short::class.java to Type.SHORT_TYPE,
java.lang.Integer::class.java to "I", java.lang.Integer::class.java to Type.INT_TYPE,
java.lang.Long::class.java to "J", java.lang.Long::class.java to Type.LONG_TYPE,
java.lang.Float::class.java to "F", java.lang.Float::class.java to Type.FLOAT_TYPE,
java.lang.Double::class.java to "D" java.lang.Double::class.java to Type.DOUBLE_TYPE
) )
} }
private val NUMBER_CONVERTER_METHODS: Map<String, String> by lazy { private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy {
hashMapOf( hashMapOf(
"B" to "byteValue", java.lang.Byte::class.asm to Type.BYTE_TYPE,
"S" to "shortValue", java.lang.Short::class.asm to Type.SHORT_TYPE,
"I" to "intValue", java.lang.Integer::class.asm to Type.INT_TYPE,
"J" to "longValue", java.lang.Long::class.asm to Type.LONG_TYPE,
"F" to "floatValue", java.lang.Float::class.asm to Type.FLOAT_TYPE,
"D" to "doubleValue" java.lang.Double::class.asm to Type.DOUBLE_TYPE
)
}
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
hashMapOf(
Type.BYTE_TYPE to "byteValue",
Type.SHORT_TYPE to "shortValue",
Type.INT_TYPE to "intValue",
Type.LONG_TYPE to "longValue",
Type.FLOAT_TYPE to "floatValue",
Type.DOUBLE_TYPE to "doubleValue"
) )
} }
@ -412,14 +436,14 @@ internal class AsmBuilder<T> internal constructor(
* Provides boxed number types values of which can be stored in JVM bytecode constant pool. * Provides boxed number types values of which can be stored in JVM bytecode constant pool.
*/ */
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys } private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
internal val ASM_COMPILED_EXPRESSION_TYPE: Type = AsmCompiledExpression::class.asm
internal val NUMBER_TYPE: Type = java.lang.Number::class.asm
internal val MAP_TYPE: Type = java.util.Map::class.asm
internal val OBJECT_TYPE: Type = java.lang.Object::class.asm
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
"scientifik/kmath/asm/internal/AsmCompiledExpression" internal val OBJECT_ARRAY_TYPE: Type = Array<java.lang.Object>::class.asm
internal val ALGEBRA_TYPE: Type = Algebra::class.asm
internal const val NUMBER_CLASS = "java/lang/Number" internal val STRING_TYPE: Type = java.lang.String::class.asm
internal const val MAP_CLASS = "java/util/Map"
internal const val OBJECT_CLASS = "java/lang/Object"
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
internal const val STRING_CLASS = "java/lang/String"
} }
} }

View File

@ -16,4 +16,3 @@ internal abstract class AsmCompiledExpression<T> internal constructor(
) : Expression<T> { ) : Expression<T> {
abstract override fun invoke(arguments: Map<String, T>): T abstract override fun invoke(arguments: Map<String, T>): T
} }

View File

@ -0,0 +1,7 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.Type
import kotlin.reflect.KClass
internal val KClass<*>.asm: Type
get() = Type.getType(java)

View File

@ -1,60 +1,9 @@
package scientifik.kmath.asm.internal package scientifik.kmath.asm.internal
import org.objectweb.asm.MethodVisitor import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.* import org.objectweb.asm.commons.InstructionAdapter
internal fun MethodVisitor.visitLdcOrIntConstant(value: Int): Unit = when (value) { fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this)
-1 -> visitInsn(ICONST_M1)
0 -> visitInsn(ICONST_0)
1 -> visitInsn(ICONST_1)
2 -> visitInsn(ICONST_2)
3 -> visitInsn(ICONST_3)
4 -> visitInsn(ICONST_4)
5 -> visitInsn(ICONST_5)
in -128..127 -> visitIntInsn(BIPUSH, value)
in -32768..32767 -> visitIntInsn(SIPUSH, value)
else -> visitLdcInsn(value)
}
internal fun MethodVisitor.visitLdcOrDoubleConstant(value: Double): Unit = when (value) { fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter =
0.0 -> visitInsn(DCONST_0) instructionAdapter().apply(block)
1.0 -> visitInsn(DCONST_1)
else -> visitLdcInsn(value)
}
internal fun MethodVisitor.visitLdcOrLongConstant(value: Long): Unit = when (value) {
0L -> visitInsn(LCONST_0)
1L -> visitInsn(LCONST_1)
else -> visitLdcInsn(value)
}
internal fun MethodVisitor.visitLdcOrFloatConstant(value: Float): Unit = when (value) {
0f -> visitInsn(FCONST_0)
1f -> visitInsn(FCONST_1)
2f -> visitInsn(FCONST_2)
else -> visitLdcInsn(value)
}
internal fun MethodVisitor.visitInvokeInterface(owner: String, name: String, descriptor: String): Unit =
visitMethodInsn(INVOKEINTERFACE, owner, name, descriptor, true)
internal fun MethodVisitor.visitInvokeVirtual(owner: String, name: String, descriptor: String): Unit =
visitMethodInsn(INVOKEVIRTUAL, owner, name, descriptor, false)
internal fun MethodVisitor.visitInvokeStatic(owner: String, name: String, descriptor: String): Unit =
visitMethodInsn(INVOKESTATIC, owner, name, descriptor, false)
internal fun MethodVisitor.visitInvokeSpecial(owner: String, name: String, descriptor: String): Unit =
visitMethodInsn(INVOKESPECIAL, owner, name, descriptor, false)
internal fun MethodVisitor.visitCheckCast(type: String): Unit = visitTypeInsn(CHECKCAST, type)
internal fun MethodVisitor.visitGetField(owner: String, name: String, descriptor: String): Unit =
visitFieldInsn(GETFIELD, owner, name, descriptor)
internal fun MethodVisitor.visitLoadObjectVar(`var`: Int): Unit = visitVarInsn(ALOAD, `var`)
internal fun MethodVisitor.visitGetObjectArrayElement(): Unit = visitInsn(AALOAD)
internal fun MethodVisitor.visitReturn(): Unit = visitInsn(RETURN)
internal fun MethodVisitor.visitReturnObject(): Unit = visitInsn(ARETURN)

View File

@ -1,6 +1,7 @@
package scientifik.kmath.asm.internal package scientifik.kmath.asm.internal
import org.objectweb.asm.Opcodes import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
private val methodNameAdapters: Map<String, String> by lazy { private val methodNameAdapters: Map<String, String> by lazy {
@ -12,17 +13,19 @@ private val methodNameAdapters: Map<String, String> by lazy {
} }
/** /**
* Checks if the target [context] for code generation contains a method with needed [name] and [arity]. * Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds
* type expectation stack for needed arity.
* *
* @return `true` if contains, else `false`. * @return `true` if contains, else `false`.
*/ */
internal fun <T> hasSpecific(context: Algebra<T>, name: String, arity: Int): Boolean { internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
val aName = methodNameAdapters[name] ?: name val aName = methodNameAdapters[name] ?: name
context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null
?: return false val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else T_TYPE
repeat(arity) { expectationStack.push(t) }
return true return hasSpecific
} }
/** /**
@ -34,22 +37,23 @@ internal fun <T> hasSpecific(context: Algebra<T>, name: String, arity: Int): Boo
internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean { internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
val aName = methodNameAdapters[name] ?: name val aName = methodNameAdapters[name] ?: name
context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } val method =
?: return false context.javaClass.methods.find {
var suitableSignature = it.name == aName && it.parameters.size == arity
if (primitiveMode && it.isBridge)
suitableSignature = false
suitableSignature
} ?: return false
val owner = context::class.java.name.replace('.', '/') val owner = context::class.java.name.replace('.', '/')
val sig = buildString {
append('(')
repeat(arity) { append(primitiveTypeSig) }
append(')')
append(primitiveTypeReturnSig)
}
invokeAlgebraOperation( invokeAlgebraOperation(
owner = owner, owner = owner,
method = aName, method = aName,
descriptor = sig, descriptor = Type.getMethodDescriptor(PRIMITIVE_MASK_BOXED, *Array(arity) { PRIMITIVE_MASK }),
tArity = arity,
opcode = Opcodes.INVOKEVIRTUAL opcode = Opcodes.INVOKEVIRTUAL
) )

View File

@ -10,7 +10,23 @@ import kotlin.test.assertEquals
class TestAsmAlgebras { class TestAsmAlgebras {
@Test @Test
fun space() { fun space() {
val res = ByteRing.mstInSpace { val res1 = ByteRing.mstInSpace {
binaryOperation(
"+",
unaryOperation(
"+",
number(3.toByte()) - (number(2.toByte()) + (multiply(
add(number(1), number(1)),
2
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + symbol("x") + zero
}("x" to 2.toByte())
val res2 = ByteRing.mstInSpace {
binaryOperation( binaryOperation(
"+", "+",
@ -26,7 +42,7 @@ class TestAsmAlgebras {
) + symbol("x") + zero ) + symbol("x") + zero
}.compile()("x" to 2.toByte()) }.compile()("x" to 2.toByte())
assertEquals(16, res) assertEquals(res1, res2)
} }
// @Test // @Test