Dev #127

Merged
altavir merged 214 commits from dev into master 2020-08-11 08:33:21 +03:00
10 changed files with 408 additions and 305 deletions
Showing only changes of commit 7ddab0224a - Show all commits

View File

@ -1,14 +1,15 @@
package scientifik.kmath.asm package scientifik.kmath.asm
import org.objectweb.asm.Type
import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.buildExpectationStack
import scientifik.kmath.asm.internal.buildName import scientifik.kmath.asm.internal.buildName
import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.ast.* import scientifik.kmath.ast.MST
import scientifik.kmath.ast.MSTExpression
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra import scientifik.kmath.operations.NumericAlgebra
import scientifik.kmath.operations.RealField
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
@ -18,6 +19,7 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
fun AsmBuilder<T>.visit(node: MST) { fun AsmBuilder<T>.visit(node: MST) {
when (node) { when (node) {
is MST.Symbolic -> loadVariable(node.value) is MST.Symbolic -> loadVariable(node.value)
is MST.Numeric -> { is MST.Numeric -> {
val constant = if (algebra is NumericAlgebra<T>) val constant = if (algebra is NumericAlgebra<T>)
algebra.number(node.value) algebra.number(node.value)
@ -26,48 +28,49 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
loadTConstant(constant) loadTConstant(constant)
} }
is MST.Unary -> { is MST.Unary -> {
loadAlgebra() loadAlgebra()
if (!buildExpectationStack(algebra, node.operation, 1)) loadStringConstant(node.operation)
if (!hasSpecific(algebra, node.operation, 1)) loadStringConstant(node.operation)
visit(node.value) visit(node.value)
if (!tryInvokeSpecific(algebra, node.operation, 1)) { if (!tryInvokeSpecific(algebra, node.operation, 1)) invokeAlgebraOperation(
invokeAlgebraOperation( owner = AsmBuilder.ALGEBRA_TYPE.internalName,
owner = AsmBuilder.ALGEBRA_CLASS,
method = "unaryOperation", method = "unaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" + descriptor = Type.getMethodDescriptor(
"L${AsmBuilder.OBJECT_CLASS};" AsmBuilder.OBJECT_TYPE,
AsmBuilder.STRING_TYPE,
AsmBuilder.OBJECT_TYPE
),
tArity = 1
) )
} }
}
is MST.Binary -> { is MST.Binary -> {
loadAlgebra() loadAlgebra()
if (!buildExpectationStack(algebra, node.operation, 2)) loadStringConstant(node.operation)
if (!hasSpecific(algebra, node.operation, 2))
loadStringConstant(node.operation)
visit(node.left) visit(node.left)
visit(node.right) visit(node.right)
if (!tryInvokeSpecific(algebra, node.operation, 2)) { if (!tryInvokeSpecific(algebra, node.operation, 2)) invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_TYPE.internalName,
invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "binaryOperation", method = "binaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};" + descriptor = Type.getMethodDescriptor(
"L${AsmBuilder.OBJECT_CLASS};)" + AsmBuilder.OBJECT_TYPE,
"L${AsmBuilder.OBJECT_CLASS};" AsmBuilder.STRING_TYPE,
AsmBuilder.OBJECT_TYPE,
AsmBuilder.OBJECT_TYPE
),
tArity = 2
) )
} }
} }
} }
}
return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
} }
/** /**
@ -79,7 +82,3 @@ inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> = ms
* Optimize performance of an [MSTExpression] using ASM codegen * Optimize performance of an [MSTExpression] using ASM codegen
*/ */
inline fun <reified T : Any> MSTExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra) inline fun <reified T : Any> MSTExpression<T>.compile(): Expression<T> = mst.compileWith(T::class, algebra)
fun main() {
RealField.mstInField { symbol("x") + 2 }.compile()
}

View File

@ -1,25 +1,25 @@
package scientifik.kmath.asm.internal package scientifik.kmath.asm.internal
import org.objectweb.asm.ClassWriter import org.objectweb.asm.*
import org.objectweb.asm.Label import org.objectweb.asm.Opcodes.AALOAD
import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.RETURN
import org.objectweb.asm.Opcodes import org.objectweb.asm.commons.InstructionAdapter
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
import scientifik.kmath.ast.MST
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import java.util.*
import kotlin.reflect.KClass
/** /**
* ASM Builder is a structure that abstracts building a class for unwrapping [MST] to plain Java expression. * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression.
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
* *
* @param T the type generated [AsmCompiledExpression] operates. * @param T the type of AsmExpression to unwrap.
* @property classOfT the [Class] of T. * @param algebra the algebra the applied AsmExpressions use.
* @property algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class.
* @property className the unique class name of new loaded class.
* @property invokeLabel0Visitor the function applied to this object when the L0 label of invoke implementation is
* being written.
*/ */
internal class AsmBuilder<T> internal constructor( internal class AsmBuilder<T> internal constructor(
private val classOfT: Class<*>, private val classOfT: KClass<*>,
private val algebra: Algebra<T>, private val algebra: Algebra<T>,
private val className: String, private val className: String,
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
@ -34,16 +34,16 @@ internal class AsmBuilder<T> internal constructor(
/** /**
* The instance of [ClassLoader] used by this builder. * The instance of [ClassLoader] used by this builder.
*/ */
private val classLoader: ClassLoader = private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
ClassLoader(javaClass.classLoader)
@Suppress("PrivatePropertyName") @Suppress("PrivatePropertyName")
private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') private val T_ALGEBRA_TYPE: Type = algebra::class.asm
@Suppress("PrivatePropertyName") @Suppress("PrivatePropertyName")
private val T_CLASS: String = classOfT.name.replace('.', '/') internal val T_TYPE: Type = classOfT.asm
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') @Suppress("PrivatePropertyName")
private val CLASS_TYPE: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
/** /**
* Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass. * Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass.
@ -63,7 +63,16 @@ internal class AsmBuilder<T> internal constructor(
/** /**
* Method visitor of `invoke` method of [AsmCompiledExpression] subclass. * Method visitor of `invoke` method of [AsmCompiledExpression] subclass.
*/ */
private lateinit var invokeMethodVisitor: MethodVisitor private lateinit var invokeMethodVisitor: InstructionAdapter
internal var primitiveMode = false
@Suppress("PropertyName")
internal var PRIMITIVE_MASK: Type = OBJECT_TYPE
@Suppress("PropertyName")
internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE
private val typeStack = Stack<Type>()
internal val expectationStack = Stack<Type>().apply { push(T_TYPE) }
/** /**
* The cache of [AsmCompiledExpression] subclass built by this builder. * The cache of [AsmCompiledExpression] subclass built by this builder.
@ -76,81 +85,88 @@ internal class AsmBuilder<T> internal constructor(
* The built instance is cached. * The built instance is cached.
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
internal fun getInstance(): AsmCompiledExpression<T> { fun getInstance(): AsmCompiledExpression<T> {
generatedInstance?.let { return it } generatedInstance?.let { return it }
if (SIGNATURE_LETTERS.containsKey(classOfT.java)) {
primitiveMode = true
PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT.java)
PRIMITIVE_MASK_BOXED = T_TYPE
}
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit( visit(
Opcodes.V1_8, Opcodes.V1_8,
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
slashesClassName, CLASS_TYPE.internalName,
"L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS<L$T_CLASS;>;", "L${ASM_COMPILED_EXPRESSION_TYPE.internalName}<${T_TYPE.descriptor}>;",
FUNCTIONAL_COMPILED_EXPRESSION_CLASS, ASM_COMPILED_EXPRESSION_TYPE.internalName,
arrayOf() arrayOf()
) )
visitMethod( visitMethod(
access = Opcodes.ACC_PUBLIC, Opcodes.ACC_PUBLIC,
name = "<init>", "<init>",
descriptor = "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE),
signature = null, null,
exceptions = null null
) { ).instructionAdapter {
val thisVar = 0 val thisVar = 0
val algebraVar = 1 val algebraVar = 1
val constantsVar = 2 val constantsVar = 2
val l0 = Label() val l0 = Label()
visitLabel(l0) visitLabel(l0)
visitLoadObjectVar(thisVar) load(thisVar, CLASS_TYPE)
visitLoadObjectVar(algebraVar) load(algebraVar, ALGEBRA_TYPE)
visitLoadObjectVar(constantsVar) load(constantsVar, OBJECT_ARRAY_TYPE)
visitInvokeSpecial( invokespecial(
FUNCTIONAL_COMPILED_EXPRESSION_CLASS, ASM_COMPILED_EXPRESSION_TYPE.internalName,
"<init>", "<init>",
"(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V" Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE),
false
) )
val l1 = Label() val l1 = Label()
visitLabel(l1) visitLabel(l1)
visitReturn() visitInsn(RETURN)
val l2 = Label() val l2 = Label()
visitLabel(l2) visitLabel(l2)
visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) visitLocalVariable("this", CLASS_TYPE.descriptor, null, l0, l2, thisVar)
visitLocalVariable( visitLocalVariable(
"algebra", "algebra",
"L$ALGEBRA_CLASS;", ALGEBRA_TYPE.descriptor,
"L$ALGEBRA_CLASS<L$T_CLASS;>;", "L${ALGEBRA_TYPE.internalName}<${T_TYPE.descriptor}>;",
l0, l0,
l2, l2,
algebraVar algebraVar
) )
visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar) visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l2, constantsVar)
visitMaxs(0, 3) visitMaxs(0, 3)
visitEnd() visitEnd()
} }
visitMethod( visitMethod(
access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
name = "invoke", "invoke",
descriptor = "(L$MAP_CLASS;)L$T_CLASS;", Type.getMethodDescriptor(T_TYPE, MAP_TYPE),
signature = "(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;", "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;)${T_TYPE.descriptor}",
exceptions = null null
) { ).instructionAdapter {
invokeMethodVisitor = this invokeMethodVisitor = this
visitCode() visitCode()
val l0 = Label() val l0 = Label()
visitLabel(l0) visitLabel(l0)
invokeLabel0Visitor() invokeLabel0Visitor()
visitReturnObject() areturn(T_TYPE)
val l1 = Label() val l1 = Label()
visitLabel(l1) visitLabel(l1)
visitLocalVariable( visitLocalVariable(
"this", "this",
"L$slashesClassName;", CLASS_TYPE.descriptor,
null, null,
l0, l0,
l1, l1,
@ -159,8 +175,8 @@ internal class AsmBuilder<T> internal constructor(
visitLocalVariable( visitLocalVariable(
"arguments", "arguments",
"L$MAP_CLASS;", MAP_TYPE.descriptor,
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;", "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;",
l0, l0,
l1, l1,
invokeArgumentsVar invokeArgumentsVar
@ -171,28 +187,28 @@ internal class AsmBuilder<T> internal constructor(
} }
visitMethod( visitMethod(
access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
name = "invoke", "invoke",
descriptor = "(L$MAP_CLASS;)L$OBJECT_CLASS;", Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
signature = null, null,
exceptions = null null
) { ).instructionAdapter {
val thisVar = 0 val thisVar = 0
val argumentsVar = 1 val argumentsVar = 1
visitCode() visitCode()
val l0 = Label() val l0 = Label()
visitLabel(l0) visitLabel(l0)
visitLoadObjectVar(thisVar) load(thisVar, OBJECT_TYPE)
visitLoadObjectVar(argumentsVar) load(argumentsVar, MAP_TYPE)
visitInvokeVirtual(owner = slashesClassName, name = "invoke", descriptor = "(L$MAP_CLASS;)L$T_CLASS;") invokevirtual(CLASS_TYPE.internalName, "invoke", Type.getMethodDescriptor(T_TYPE, MAP_TYPE), false)
visitReturnObject() areturn(T_TYPE)
val l1 = Label() val l1 = Label()
visitLabel(l1) visitLabel(l1)
visitLocalVariable( visitLocalVariable(
"this", "this",
"L$slashesClassName;", CLASS_TYPE.descriptor,
T_CLASS, null,
l0, l0,
l1, l1,
thisVar thisVar
@ -219,89 +235,111 @@ internal class AsmBuilder<T> internal constructor(
* Loads a constant from * Loads a constant from
*/ */
internal fun loadTConstant(value: T) { internal fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) { if (classOfT.java in INLINABLE_NUMBERS) {
loadNumberConstant(value as Number) val expectedType = expectationStack.pop()!!
invokeMethodVisitor.visitCheckCast(T_CLASS) val mustBeBoxed = expectedType.sort == Type.OBJECT
loadNumberConstant(value as Number, mustBeBoxed)
if (mustBeBoxed) typeStack.push(T_TYPE) else typeStack.push(PRIMITIVE_MASK)
return return
} }
loadConstant(value as Any, T_CLASS) loadConstant(value as Any, T_TYPE)
} }
/** private fun box(): Unit = invokeMethodVisitor.invokestatic(
* Loads an object constant [value] stored in [AsmCompiledExpression.constants] and casts it to [type]. T_TYPE.internalName,
*/ "valueOf",
private fun loadConstant(value: Any, type: String) { Type.getMethodDescriptor(T_TYPE, PRIMITIVE_MASK),
false
)
private fun unbox(): Unit = invokeMethodVisitor.invokevirtual(
NUMBER_TYPE.internalName,
NUMBER_CONVERTER_METHODS.getValue(PRIMITIVE_MASK),
Type.getMethodDescriptor(PRIMITIVE_MASK),
false
)
private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
invokeMethodVisitor.run {
loadThis() loadThis()
visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;") getfield(CLASS_TYPE.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
visitLdcOrIntConstant(idx) iconst(idx)
visitGetObjectArrayElement() visitInsn(AALOAD)
invokeMethodVisitor.visitCheckCast(type) checkcast(type)
}
} }
private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, CLASS_TYPE)
/** /**
* Either loads a numeric constant [value] from [AsmCompiledExpression.constants] field or boxes a primitive * Either loads a numeric constant [value] from [AsmCompiledExpression] constants field or boxes a primitive
* constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded
* from it). * from it).
*/ */
private fun loadNumberConstant(value: Number) { private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) {
val clazz = value.javaClass val boxed = value::class.asm
val c = clazz.name.replace('.', '/') val primitive = BOXED_TO_PRIMITIVES[boxed]
val sigLetter = SIGNATURE_LETTERS[clazz]
if (sigLetter != null) { if (primitive != null) {
when (value) { when (primitive) {
is Byte -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
is Short -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value) Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat())
is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value) Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong())
is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value) Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
is Long -> invokeMethodVisitor.visitLdcOrLongConstant(value) Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt())
else -> invokeMethodVisitor.visitLdcInsn(value) }
if (mustBeBoxed) {
box()
invokeMethodVisitor.checkcast(T_TYPE)
} }
invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};")
return return
} }
loadConstant(value, c) loadConstant(value, boxed)
if (!mustBeBoxed) unbox()
else invokeMethodVisitor.checkcast(T_TYPE)
} }
/** /**
* Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided. * Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided.
*/ */
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
visitLoadObjectVar(invokeArgumentsVar) load(invokeArgumentsVar, OBJECT_ARRAY_TYPE)
if (defaultValue != null) { if (defaultValue != null) {
visitLdcInsn(name) loadStringConstant(name)
loadTConstant(defaultValue) loadTConstant(defaultValue)
visitInvokeInterface( invokeinterface(
owner = MAP_CLASS, MAP_TYPE.internalName,
name = "getOrDefault", "getOrDefault",
descriptor = "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;" Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE)
) )
invokeMethodVisitor.visitCheckCast(T_CLASS) invokeMethodVisitor.checkcast(T_TYPE)
return return
} }
visitLdcInsn(name) loadStringConstant(name)
visitInvokeInterface( invokeinterface(
owner = MAP_CLASS, MAP_TYPE.internalName,
name = "get", "get",
descriptor = "(L$OBJECT_CLASS;)L$OBJECT_CLASS;" Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE)
) )
invokeMethodVisitor.visitCheckCast(T_CLASS) invokeMethodVisitor.checkcast(T_TYPE)
val expectedType = expectationStack.pop()!!
if (expectedType.sort == Type.OBJECT)
typeStack.push(T_TYPE)
else {
unbox()
typeStack.push(PRIMITIVE_MASK)
}
} }
/** /**
@ -310,13 +348,10 @@ internal class AsmBuilder<T> internal constructor(
internal fun loadAlgebra() { internal fun loadAlgebra() {
loadThis() loadThis()
invokeMethodVisitor.visitGetField( invokeMethodVisitor.run {
owner = FUNCTIONAL_COMPILED_EXPRESSION_CLASS, getfield(ASM_COMPILED_EXPRESSION_TYPE.internalName, "algebra", ALGEBRA_TYPE.descriptor)
name = "algebra", checkcast(T_ALGEBRA_TYPE)
descriptor = "L$ALGEBRA_CLASS;" }
)
invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS)
} }
/** /**
@ -330,29 +365,70 @@ internal class AsmBuilder<T> internal constructor(
owner: String, owner: String,
method: String, method: String,
descriptor: String, descriptor: String,
tArity: Int,
opcode: Int = Opcodes.INVOKEINTERFACE opcode: Int = Opcodes.INVOKEINTERFACE
) { ) {
invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, opcode == Opcodes.INVOKEINTERFACE) repeat(tArity) { if (!typeStack.empty()) typeStack.pop() }
invokeMethodVisitor.visitCheckCast(T_CLASS)
invokeMethodVisitor.visitMethodInsn(
opcode,
owner,
method,
descriptor,
opcode == Opcodes.INVOKEINTERFACE
)
invokeMethodVisitor.checkcast(T_TYPE)
val isLastExpr = expectationStack.size == 1
val expectedType = expectationStack.pop()!!
if (expectedType.sort == Type.OBJECT || isLastExpr)
typeStack.push(T_TYPE)
else {
unbox()
typeStack.push(PRIMITIVE_MASK)
}
} }
/** /**
* Writes a LDC Instruction with string constant provided. * Writes a LDC Instruction with string constant provided.
*/ */
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string)
internal companion object { internal companion object {
/** /**
* Maps JVM primitive numbers boxed types to their letters of JVM signature convention. * Maps JVM primitive numbers boxed types to their letters of JVM signature convention.
*/ */
private val SIGNATURE_LETTERS: Map<Class<out Any>, String> by lazy { private val SIGNATURE_LETTERS: Map<Class<out Any>, Type> by lazy {
hashMapOf( hashMapOf(
java.lang.Byte::class.java to "B", java.lang.Byte::class.java to Type.BYTE_TYPE,
java.lang.Short::class.java to "S", java.lang.Short::class.java to Type.SHORT_TYPE,
java.lang.Integer::class.java to "I", java.lang.Integer::class.java to Type.INT_TYPE,
java.lang.Long::class.java to "J", java.lang.Long::class.java to Type.LONG_TYPE,
java.lang.Float::class.java to "F", java.lang.Float::class.java to Type.FLOAT_TYPE,
java.lang.Double::class.java to "D" java.lang.Double::class.java to Type.DOUBLE_TYPE
)
}
private val BOXED_TO_PRIMITIVES: Map<Type, Type> by lazy {
hashMapOf(
java.lang.Byte::class.asm to Type.BYTE_TYPE,
java.lang.Short::class.asm to Type.SHORT_TYPE,
java.lang.Integer::class.asm to Type.INT_TYPE,
java.lang.Long::class.asm to Type.LONG_TYPE,
java.lang.Float::class.asm to Type.FLOAT_TYPE,
java.lang.Double::class.asm to Type.DOUBLE_TYPE
)
}
private val NUMBER_CONVERTER_METHODS: Map<Type, String> by lazy {
hashMapOf(
Type.BYTE_TYPE to "byteValue",
Type.SHORT_TYPE to "shortValue",
Type.INT_TYPE to "intValue",
Type.LONG_TYPE to "longValue",
Type.FLOAT_TYPE to "floatValue",
Type.DOUBLE_TYPE to "doubleValue"
) )
} }
@ -360,13 +436,14 @@ internal class AsmBuilder<T> internal constructor(
* Provides boxed number types values of which can be stored in JVM bytecode constant pool. * Provides boxed number types values of which can be stored in JVM bytecode constant pool.
*/ */
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys } private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
internal val ASM_COMPILED_EXPRESSION_TYPE: Type = AsmCompiledExpression::class.asm
internal val NUMBER_TYPE: Type = java.lang.Number::class.asm
internal val MAP_TYPE: Type = java.util.Map::class.asm
internal val OBJECT_TYPE: Type = java.lang.Object::class.asm
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName")
"scientifik/kmath/asm/internal/AsmCompiledExpression" internal val OBJECT_ARRAY_TYPE: Type = Array<java.lang.Object>::class.asm
internal val ALGEBRA_TYPE: Type = Algebra::class.asm
internal const val MAP_CLASS = "java/util/Map" internal val STRING_TYPE: Type = java.lang.String::class.asm
internal const val OBJECT_CLASS = "java/lang/Object"
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
internal const val STRING_CLASS = "java/lang/String"
} }
} }

View File

@ -16,4 +16,3 @@ internal abstract class AsmCompiledExpression<T> internal constructor(
) : Expression<T> { ) : Expression<T> {
abstract override fun invoke(arguments: Map<String, T>): T abstract override fun invoke(arguments: Map<String, T>): T
} }

View File

@ -0,0 +1,7 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.Type
import kotlin.reflect.KClass
internal val KClass<*>.asm: Type
get() = Type.getType(java)

View File

@ -1,60 +1,9 @@
package scientifik.kmath.asm.internal package scientifik.kmath.asm.internal
import org.objectweb.asm.MethodVisitor import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.* import org.objectweb.asm.commons.InstructionAdapter
internal fun MethodVisitor.visitLdcOrIntConstant(value: Int): Unit = when (value) { fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this)
-1 -> visitInsn(ICONST_M1)
0 -> visitInsn(ICONST_0)
1 -> visitInsn(ICONST_1)
2 -> visitInsn(ICONST_2)
3 -> visitInsn(ICONST_3)
4 -> visitInsn(ICONST_4)
5 -> visitInsn(ICONST_5)
in -128..127 -> visitIntInsn(BIPUSH, value)
in -32768..32767 -> visitIntInsn(SIPUSH, value)
else -> visitLdcInsn(value)
}
internal fun MethodVisitor.visitLdcOrDoubleConstant(value: Double): Unit = when (value) { fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter =
0.0 -> visitInsn(DCONST_0) instructionAdapter().apply(block)
1.0 -> visitInsn(DCONST_1)
else -> visitLdcInsn(value)
}
internal fun MethodVisitor.visitLdcOrLongConstant(value: Long): Unit = when (value) {
0L -> visitInsn(LCONST_0)
1L -> visitInsn(LCONST_1)
else -> visitLdcInsn(value)
}
internal fun MethodVisitor.visitLdcOrFloatConstant(value: Float): Unit = when (value) {
0f -> visitInsn(FCONST_0)
1f -> visitInsn(FCONST_1)
2f -> visitInsn(FCONST_2)
else -> visitLdcInsn(value)
}
internal fun MethodVisitor.visitInvokeInterface(owner: String, name: String, descriptor: String): Unit =
visitMethodInsn(INVOKEINTERFACE, owner, name, descriptor, true)
internal fun MethodVisitor.visitInvokeVirtual(owner: String, name: String, descriptor: String): Unit =
visitMethodInsn(INVOKEVIRTUAL, owner, name, descriptor, false)
internal fun MethodVisitor.visitInvokeStatic(owner: String, name: String, descriptor: String): Unit =
visitMethodInsn(INVOKESTATIC, owner, name, descriptor, false)
internal fun MethodVisitor.visitInvokeSpecial(owner: String, name: String, descriptor: String): Unit =
visitMethodInsn(INVOKESPECIAL, owner, name, descriptor, false)
internal fun MethodVisitor.visitCheckCast(type: String): Unit = visitTypeInsn(CHECKCAST, type)
internal fun MethodVisitor.visitGetField(owner: String, name: String, descriptor: String): Unit =
visitFieldInsn(GETFIELD, owner, name, descriptor)
internal fun MethodVisitor.visitLoadObjectVar(`var`: Int): Unit = visitVarInsn(ALOAD, `var`)
internal fun MethodVisitor.visitGetObjectArrayElement(): Unit = visitInsn(AALOAD)
internal fun MethodVisitor.visitReturn(): Unit = visitInsn(RETURN)
internal fun MethodVisitor.visitReturnObject(): Unit = visitInsn(ARETURN)

View File

@ -1,6 +1,7 @@
package scientifik.kmath.asm.internal package scientifik.kmath.asm.internal
import org.objectweb.asm.Opcodes import org.objectweb.asm.Opcodes
import org.objectweb.asm.Type
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
private val methodNameAdapters: Map<String, String> by lazy { private val methodNameAdapters: Map<String, String> by lazy {
@ -12,17 +13,19 @@ private val methodNameAdapters: Map<String, String> by lazy {
} }
/** /**
* Checks if the target [context] for code generation contains a method with needed [name] and [arity]. * Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds
* type expectation stack for needed arity.
* *
* @return `true` if contains, else `false`. * @return `true` if contains, else `false`.
*/ */
internal fun <T> hasSpecific(context: Algebra<T>, name: String, arity: Int): Boolean { internal fun <T> AsmBuilder<T>.buildExpectationStack(context: Algebra<T>, name: String, arity: Int): Boolean {
val aName = methodNameAdapters[name] ?: name val aName = methodNameAdapters[name] ?: name
context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null
?: return false val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else T_TYPE
repeat(arity) { expectationStack.push(t) }
return true return hasSpecific
} }
/** /**
@ -34,22 +37,23 @@ internal fun <T> hasSpecific(context: Algebra<T>, name: String, arity: Int): Boo
internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean { internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
val aName = methodNameAdapters[name] ?: name val aName = methodNameAdapters[name] ?: name
context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } val method =
?: return false context.javaClass.methods.find {
var suitableSignature = it.name == aName && it.parameters.size == arity
if (primitiveMode && it.isBridge)
suitableSignature = false
suitableSignature
} ?: return false
val owner = context::class.java.name.replace('.', '/') val owner = context::class.java.name.replace('.', '/')
val sig = buildString {
append('(')
repeat(arity) { append("L${AsmBuilder.OBJECT_CLASS};") }
append(')')
append("L${AsmBuilder.OBJECT_CLASS};")
}
invokeAlgebraOperation( invokeAlgebraOperation(
owner = owner, owner = owner,
method = aName, method = aName,
descriptor = sig, descriptor = Type.getMethodDescriptor(PRIMITIVE_MASK_BOXED, *Array(arity) { PRIMITIVE_MASK }),
tArity = arity,
opcode = Opcodes.INVOKEVIRTUAL opcode = Opcodes.INVOKEVIRTUAL
) )

View File

@ -1,66 +1,110 @@
package scietifik.kmath.asm package scietifik.kmath.asm
// import scientifik.kmath.asm.compile
//class TestAsmAlgebras { import scientifik.kmath.ast.mstInField
// @Test import scientifik.kmath.ast.mstInRing
// fun space() { import scientifik.kmath.ast.mstInSpace
// val res = ByteRing.asmInRing { import scientifik.kmath.expressions.invoke
// binaryOperation( import scientifik.kmath.operations.ByteRing
// "+", import scientifik.kmath.operations.RealField
// import kotlin.test.Test
// unaryOperation( import kotlin.test.assertEquals
// "+",
// 3.toByte() - (2.toByte() + (multiply( class TestAsmAlgebras {
// add(number(1), number(1)), @Test
// 2 fun space() {
// ) + 1.toByte()) * 3.toByte() - 1.toByte()) val res1 = ByteRing.mstInSpace {
// ), binaryOperation(
// "+",
// number(1)
// ) + symbol("x") + zero unaryOperation(
// }("x" to 2.toByte()) "+",
// number(3.toByte()) - (number(2.toByte()) + (multiply(
// assertEquals(16, res) add(number(1), number(1)),
// } 2
// ) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
// @Test ),
// fun ring() {
// val res = ByteRing.asmInRing { number(1)
// binaryOperation( ) + symbol("x") + zero
// "+", }("x" to 2.toByte())
//
// unaryOperation( val res2 = ByteRing.mstInSpace {
// "+", binaryOperation(
// (3.toByte() - (2.toByte() + (multiply( "+",
// add(const(1), const(1)),
// 2 unaryOperation(
// ) + 1.toByte()))) * 3.0 - 1.toByte() "+",
// ), number(3.toByte()) - (number(2.toByte()) + (multiply(
// add(number(1), number(1)),
// number(1) 2
// ) * const(2) ) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
// }() ),
//
// assertEquals(24, res) number(1)
// } ) + symbol("x") + zero
// }.compile()("x" to 2.toByte())
// @Test
// fun field() { assertEquals(res1, res2)
// val res = RealField.asmInField { }
// +(3 - 2 + 2*(number(1)+1.0)
// @Test
// unaryOperation( fun ring() {
// "+", val res1 = ByteRing.mstInRing {
// (3.0 - (2.0 + (multiply( binaryOperation(
// add((1.0), const(1.0)), "+",
// 2
// ) + 1.0))) * 3 - 1.0 unaryOperation(
// )+ "+",
// (symbol("x") - (2.toByte() + (multiply(
// number(1) add(number(1), number(1)),
// ) / 2, const(2.0)) * one 2
// }() ) + 1.toByte()))) * 3.0 - 1.toByte()
// ),
// assertEquals(3.0, res)
// } number(1)
//} ) * number(2)
}("x" to 3.toByte())
val res2 = ByteRing.mstInRing {
binaryOperation(
"+",
unaryOperation(
"+",
(symbol("x") - (2.toByte() + (multiply(
add(number(1), number(1)),
2
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}.compile()("x" to 3.toByte())
assertEquals(res1, res2)
}
@Test
fun field() {
val res1 = RealField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
"+",
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
+ number(1),
1 / 2 + number(2.0) * one
)
}("x" to 2.0)
val res2 = RealField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperation(
"+",
(3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0
+ number(1),
1 / 2 + number(2.0) * one
)
}.compile()("x" to 2.0)
assertEquals(res1, res2)
}
}

View File

@ -16,6 +16,13 @@ class TestAsmExpressions {
assertEquals(-2.0, res) assertEquals(-2.0, res)
} }
@Test
fun testBinaryOperationInvocation() {
val expression = RealField.mstInSpace { -symbol("x") + number(1.0) }.compile()
val res = expression("x" to 2.0)
assertEquals(-1.0, res)
}
@Test @Test
fun testConstProductInvocation() { fun testConstProductInvocation() {
val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0)

View File

@ -1,7 +1,10 @@
package scietifik.kmath.ast package scietifik.kmath.ast
import scientifik.kmath.ast.evaluate import scientifik.kmath.asm.compile
import scientifik.kmath.asm.expression
import scientifik.kmath.ast.mstInField
import scientifik.kmath.ast.parseMath import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import kotlin.test.Test import kotlin.test.Test
@ -9,9 +12,15 @@ import kotlin.test.assertEquals
class AsmTest { class AsmTest {
@Test @Test
fun parsedExpression() { fun `compile MST`() {
val mst = "2+2*(2+2)".parseMath() val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.evaluate(mst) val res = ComplexField.expression(mst)()
assertEquals(Complex(10.0, 0.0), res)
}
@Test
fun `compile MSTExpression`() {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }.compile()()
assertEquals(Complex(10.0, 0.0), res) assertEquals(Complex(10.0, 0.0), res)
} }
} }

View File

@ -1,17 +1,25 @@
package scietifik.kmath.ast package scietifik.kmath.ast
import scientifik.kmath.ast.evaluate import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.mstInField
import scientifik.kmath.ast.parseMath import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import kotlin.test.assertEquals
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals
internal class ParserTest { internal class ParserTest {
@Test @Test
fun parsedExpression() { fun `evaluate MST`() {
val mst = "2+2*(2+2)".parseMath() val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.evaluate(mst) val res = ComplexField.evaluate(mst)
assertEquals(Complex(10.0, 0.0), res) assertEquals(Complex(10.0, 0.0), res)
} }
@Test
fun `evaluate MSTExpression`() {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
assertEquals(Complex(10.0, 0.0), res)
}
} }