forked from kscience/kmath
Merge pull request #105 from CommanderTvis/adv-expr-refactor-agc
Refactor ASM generation code
This commit is contained in:
commit
7c769bf74a
@ -1,6 +1,7 @@
|
||||
package scientifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.asm.internal.AsmBuilder
|
||||
import scientifik.kmath.asm.internal.buildName
|
||||
import scientifik.kmath.asm.internal.hasSpecific
|
||||
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
||||
import scientifik.kmath.ast.MST
|
||||
@ -11,44 +12,30 @@ import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.operations.*
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
|
||||
/**
|
||||
* Compile given MST to an Expression using AST compiler
|
||||
*/
|
||||
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
||||
|
||||
fun buildName(mst: MST, collision: Int = 0): String {
|
||||
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
|
||||
|
||||
try {
|
||||
Class.forName(name)
|
||||
} catch (ignored: ClassNotFoundException) {
|
||||
return name
|
||||
}
|
||||
|
||||
return buildName(mst, collision + 1)
|
||||
}
|
||||
|
||||
fun AsmBuilder<T>.visit(node: MST): Unit {
|
||||
fun AsmBuilder<T>.visit(node: MST) {
|
||||
when (node) {
|
||||
is MST.Symbolic -> visitLoadFromVariables(node.value)
|
||||
is MST.Symbolic -> loadVariable(node.value)
|
||||
is MST.Numeric -> {
|
||||
val constant = if (algebra is NumericAlgebra<T>) {
|
||||
algebra.number(node.value)
|
||||
} else {
|
||||
error("Number literals are not supported in $algebra")
|
||||
}
|
||||
visitLoadFromConstants(constant)
|
||||
loadTConstant(constant)
|
||||
}
|
||||
is MST.Unary -> {
|
||||
visitLoadAlgebra()
|
||||
loadAlgebra()
|
||||
|
||||
if (!hasSpecific(algebra, node.operation, 1)) visitStringConstant(node.operation)
|
||||
if (!hasSpecific(algebra, node.operation, 1)) loadStringConstant(node.operation)
|
||||
|
||||
visit(node.value)
|
||||
|
||||
if (!tryInvokeSpecific(algebra, node.operation, 1)) {
|
||||
visitAlgebraOperation(
|
||||
invokeAlgebraOperation(
|
||||
owner = AsmBuilder.ALGEBRA_CLASS,
|
||||
method = "unaryOperation",
|
||||
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
||||
@ -58,17 +45,17 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
||||
}
|
||||
}
|
||||
is MST.Binary -> {
|
||||
visitLoadAlgebra()
|
||||
loadAlgebra()
|
||||
|
||||
if (!hasSpecific(algebra, node.operation, 2))
|
||||
visitStringConstant(node.operation)
|
||||
loadStringConstant(node.operation)
|
||||
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
|
||||
if (!tryInvokeSpecific(algebra, node.operation, 2)) {
|
||||
|
||||
visitAlgebraOperation(
|
||||
invokeAlgebraOperation(
|
||||
owner = AsmBuilder.ALGEBRA_CLASS,
|
||||
method = "binaryOperation",
|
||||
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
||||
@ -81,9 +68,7 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
||||
}
|
||||
}
|
||||
|
||||
val builder = AsmBuilder(type.java, algebra, buildName(this))
|
||||
builder.visit(this)
|
||||
return builder.generate()
|
||||
return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
|
||||
}
|
||||
|
||||
inline fun <reified T : Any> Algebra<T>.compile(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
||||
|
@ -4,41 +4,29 @@ import org.objectweb.asm.ClassWriter
|
||||
import org.objectweb.asm.Label
|
||||
import org.objectweb.asm.MethodVisitor
|
||||
import org.objectweb.asm.Opcodes
|
||||
import scientifik.kmath.asm.internal.AsmBuilder.AsmClassLoader
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||
import scientifik.kmath.operations.Algebra
|
||||
|
||||
|
||||
private abstract class FunctionalCompiledExpression<T> internal constructor(
|
||||
@JvmField protected val algebra: Algebra<T>,
|
||||
@JvmField protected val constants: Array<Any>
|
||||
) : Expression<T> {
|
||||
abstract override fun invoke(arguments: Map<String, T>): T
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java
|
||||
* expression. This class uses [AsmClassLoader] for loading the generated class, then it is able to instantiate the new
|
||||
* class.
|
||||
* ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression.
|
||||
* This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
|
||||
*
|
||||
* @param T the type of AsmExpression to unwrap.
|
||||
* @param algebra the algebra the applied AsmExpressions use.
|
||||
* @param className the unique class name of new loaded class.
|
||||
*
|
||||
* @author [Iaroslav Postovalov](https://github.com/CommanderTvis)
|
||||
*/
|
||||
internal class AsmBuilder<T>(
|
||||
internal class AsmBuilder<T> internal constructor(
|
||||
private val classOfT: Class<*>,
|
||||
private val algebra: Algebra<T>,
|
||||
private val className: String
|
||||
private val className: String,
|
||||
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
|
||||
) {
|
||||
private class AsmClassLoader(parent: ClassLoader) : ClassLoader(parent) {
|
||||
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
||||
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
||||
}
|
||||
|
||||
private val classLoader: AsmClassLoader =
|
||||
AsmClassLoader(javaClass.classLoader)
|
||||
private val classLoader: ClassLoader =
|
||||
ClassLoader(javaClass.classLoader)
|
||||
|
||||
@Suppress("PrivatePropertyName")
|
||||
private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/')
|
||||
@ -49,15 +37,16 @@ internal class AsmBuilder<T>(
|
||||
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/')
|
||||
private val invokeThisVar: Int = 0
|
||||
private val invokeArgumentsVar: Int = 1
|
||||
private var maxStack: Int = 0
|
||||
private val constants: MutableList<Any> = mutableListOf()
|
||||
private val asmCompiledClassWriter: ClassWriter = ClassWriter(0)
|
||||
private val invokeMethodVisitor: MethodVisitor
|
||||
private val invokeL0: Label
|
||||
private lateinit var invokeL1: Label
|
||||
private lateinit var invokeMethodVisitor: MethodVisitor
|
||||
private var generatedInstance: AsmCompiledExpression<T>? = null
|
||||
|
||||
init {
|
||||
asmCompiledClassWriter.visit(
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun getInstance(): AsmCompiledExpression<T> {
|
||||
generatedInstance?.let { return it }
|
||||
|
||||
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
||||
visit(
|
||||
Opcodes.V1_8,
|
||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
|
||||
slashesClassName,
|
||||
@ -66,28 +55,31 @@ internal class AsmBuilder<T>(
|
||||
arrayOf()
|
||||
)
|
||||
|
||||
asmCompiledClassWriter.run {
|
||||
visitMethod(Opcodes.ACC_PUBLIC, "<init>", "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", null, null).run {
|
||||
visitMethod(
|
||||
access = Opcodes.ACC_PUBLIC,
|
||||
name = "<init>",
|
||||
descriptor = "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V",
|
||||
signature = null,
|
||||
exceptions = null
|
||||
) {
|
||||
val thisVar = 0
|
||||
val algebraVar = 1
|
||||
val constantsVar = 2
|
||||
val l0 = Label()
|
||||
visitLabel(l0)
|
||||
visitVarInsn(Opcodes.ALOAD, thisVar)
|
||||
visitVarInsn(Opcodes.ALOAD, algebraVar)
|
||||
visitVarInsn(Opcodes.ALOAD, constantsVar)
|
||||
visitLoadObjectVar(thisVar)
|
||||
visitLoadObjectVar(algebraVar)
|
||||
visitLoadObjectVar(constantsVar)
|
||||
|
||||
visitMethodInsn(
|
||||
Opcodes.INVOKESPECIAL,
|
||||
visitInvokeSpecial(
|
||||
FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
|
||||
"<init>",
|
||||
"(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V",
|
||||
false
|
||||
"(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V"
|
||||
)
|
||||
|
||||
val l1 = Label()
|
||||
visitLabel(l1)
|
||||
visitInsn(Opcodes.RETURN)
|
||||
visitReturn()
|
||||
val l2 = Label()
|
||||
visitLabel(l2)
|
||||
visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar)
|
||||
@ -102,40 +94,32 @@ internal class AsmBuilder<T>(
|
||||
)
|
||||
|
||||
visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar)
|
||||
visitMaxs(3, 3)
|
||||
visitMaxs(0, 3)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
invokeMethodVisitor = visitMethod(
|
||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
|
||||
"invoke",
|
||||
"(L$MAP_CLASS;)L$T_CLASS;",
|
||||
"(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;",
|
||||
null
|
||||
)
|
||||
|
||||
invokeMethodVisitor.run {
|
||||
visitMethod(
|
||||
access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
|
||||
name = "invoke",
|
||||
descriptor = "(L$MAP_CLASS;)L$T_CLASS;",
|
||||
signature = "(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;",
|
||||
exceptions = null
|
||||
) {
|
||||
invokeMethodVisitor = this
|
||||
visitCode()
|
||||
invokeL0 = Label()
|
||||
visitLabel(invokeL0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun generate(): Expression<T> {
|
||||
|
||||
invokeMethodVisitor.run {
|
||||
visitInsn(Opcodes.ARETURN)
|
||||
invokeL1 = Label()
|
||||
visitLabel(invokeL1)
|
||||
val l0 = Label()
|
||||
visitLabel(l0)
|
||||
invokeLabel0Visitor()
|
||||
visitReturnObject()
|
||||
val l1 = Label()
|
||||
visitLabel(l1)
|
||||
|
||||
visitLocalVariable(
|
||||
"this",
|
||||
"L$slashesClassName;",
|
||||
T_CLASS,
|
||||
invokeL0,
|
||||
invokeL1,
|
||||
null,
|
||||
l0,
|
||||
l1,
|
||||
invokeThisVar
|
||||
)
|
||||
|
||||
@ -143,30 +127,31 @@ internal class AsmBuilder<T>(
|
||||
"arguments",
|
||||
"L$MAP_CLASS;",
|
||||
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;",
|
||||
invokeL0,
|
||||
invokeL1,
|
||||
l0,
|
||||
l1,
|
||||
invokeArgumentsVar
|
||||
)
|
||||
|
||||
visitMaxs(maxStack + 1, 2)
|
||||
visitMaxs(0, 2)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
asmCompiledClassWriter.visitMethod(
|
||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
|
||||
"invoke",
|
||||
"(L$MAP_CLASS;)L$OBJECT_CLASS;",
|
||||
null,
|
||||
null
|
||||
).run {
|
||||
visitMethod(
|
||||
access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
|
||||
name = "invoke",
|
||||
descriptor = "(L$MAP_CLASS;)L$OBJECT_CLASS;",
|
||||
signature = null,
|
||||
exceptions = null
|
||||
) {
|
||||
val thisVar = 0
|
||||
val argumentsVar = 1
|
||||
visitCode()
|
||||
val l0 = Label()
|
||||
visitLabel(l0)
|
||||
visitVarInsn(Opcodes.ALOAD, 0)
|
||||
visitVarInsn(Opcodes.ALOAD, 1)
|
||||
visitMethodInsn(Opcodes.INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false)
|
||||
visitInsn(Opcodes.ARETURN)
|
||||
visitLoadObjectVar(thisVar)
|
||||
visitLoadObjectVar(argumentsVar)
|
||||
visitInvokeVirtual(owner = slashesClassName, name = "invoke", descriptor = "(L$MAP_CLASS;)L$T_CLASS;")
|
||||
visitReturnObject()
|
||||
val l1 = Label()
|
||||
visitLabel(l1)
|
||||
|
||||
@ -179,127 +164,123 @@ internal class AsmBuilder<T>(
|
||||
thisVar
|
||||
)
|
||||
|
||||
visitMaxs(2, 2)
|
||||
visitMaxs(0, 2)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
asmCompiledClassWriter.visitEnd()
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
val new = classLoader
|
||||
.defineClass(className, asmCompiledClassWriter.toByteArray())
|
||||
.defineClass(className, classWriter.toByteArray())
|
||||
.constructors
|
||||
.first()
|
||||
.newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression<T>
|
||||
.newInstance(algebra, constants.toTypedArray()) as AsmCompiledExpression<T>
|
||||
|
||||
generatedInstance = new
|
||||
return new
|
||||
}
|
||||
|
||||
fun visitLoadFromConstants(value: T) {
|
||||
internal fun loadTConstant(value: T) {
|
||||
if (classOfT in INLINABLE_NUMBERS) {
|
||||
visitNumberConstant(value as Number)
|
||||
visitCastToT()
|
||||
loadNumberConstant(value as Number)
|
||||
invokeMethodVisitor.visitCheckCast(T_CLASS)
|
||||
return
|
||||
}
|
||||
|
||||
visitLoadAnyFromConstants(value as Any, T_CLASS)
|
||||
loadConstant(value as Any, T_CLASS)
|
||||
}
|
||||
|
||||
private fun visitLoadAnyFromConstants(value: Any, type: String) {
|
||||
private fun loadConstant(value: Any, type: String) {
|
||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||
maxStack++
|
||||
|
||||
invokeMethodVisitor.run {
|
||||
visitLoadThis()
|
||||
visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "[L$OBJECT_CLASS;")
|
||||
visitLdcOrIConstInsn(idx)
|
||||
visitInsn(Opcodes.AALOAD)
|
||||
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type)
|
||||
loadThis()
|
||||
visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;")
|
||||
visitLdcOrIntConstant(idx)
|
||||
visitGetObjectArrayElement()
|
||||
invokeMethodVisitor.visitCheckCast(type)
|
||||
}
|
||||
}
|
||||
|
||||
private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar)
|
||||
private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar)
|
||||
|
||||
fun visitNumberConstant(value: Number) {
|
||||
maxStack++
|
||||
internal fun loadNumberConstant(value: Number) {
|
||||
val clazz = value.javaClass
|
||||
val c = clazz.name.replace('.', '/')
|
||||
val sigLetter = SIGNATURE_LETTERS[clazz]
|
||||
|
||||
if (sigLetter != null) {
|
||||
when (value) {
|
||||
is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value)
|
||||
is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value)
|
||||
is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value)
|
||||
is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value)
|
||||
is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value)
|
||||
is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value)
|
||||
else -> invokeMethodVisitor.visitLdcInsn(value)
|
||||
}
|
||||
|
||||
invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false)
|
||||
invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};")
|
||||
return
|
||||
}
|
||||
|
||||
visitLoadAnyFromConstants(value, c)
|
||||
loadConstant(value, c)
|
||||
}
|
||||
|
||||
fun visitLoadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
||||
maxStack += 2
|
||||
visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar)
|
||||
internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
||||
visitLoadObjectVar(invokeArgumentsVar)
|
||||
|
||||
if (defaultValue != null) {
|
||||
visitLdcInsn(name)
|
||||
visitLoadFromConstants(defaultValue)
|
||||
loadTConstant(defaultValue)
|
||||
|
||||
visitMethodInsn(
|
||||
Opcodes.INVOKEINTERFACE,
|
||||
MAP_CLASS,
|
||||
"getOrDefault",
|
||||
"(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;",
|
||||
true
|
||||
visitInvokeInterface(
|
||||
owner = MAP_CLASS,
|
||||
name = "getOrDefault",
|
||||
descriptor = "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;"
|
||||
)
|
||||
|
||||
visitCastToT()
|
||||
invokeMethodVisitor.visitCheckCast(T_CLASS)
|
||||
return
|
||||
}
|
||||
|
||||
visitLdcInsn(name)
|
||||
visitMethodInsn(
|
||||
Opcodes.INVOKEINTERFACE,
|
||||
MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true
|
||||
)
|
||||
visitCastToT()
|
||||
}
|
||||
|
||||
fun visitLoadAlgebra() {
|
||||
maxStack++
|
||||
invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar)
|
||||
|
||||
invokeMethodVisitor.visitFieldInsn(
|
||||
Opcodes.GETFIELD,
|
||||
FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;"
|
||||
visitInvokeInterface(
|
||||
owner = MAP_CLASS,
|
||||
name = "get",
|
||||
descriptor = "(L$OBJECT_CLASS;)L$OBJECT_CLASS;"
|
||||
)
|
||||
|
||||
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS)
|
||||
invokeMethodVisitor.visitCheckCast(T_CLASS)
|
||||
}
|
||||
|
||||
fun visitAlgebraOperation(
|
||||
internal fun loadAlgebra() {
|
||||
loadThis()
|
||||
|
||||
invokeMethodVisitor.visitGetField(
|
||||
owner = FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
|
||||
name = "algebra",
|
||||
descriptor = "L$ALGEBRA_CLASS;"
|
||||
)
|
||||
|
||||
invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS)
|
||||
}
|
||||
|
||||
internal fun invokeAlgebraOperation(
|
||||
owner: String,
|
||||
method: String,
|
||||
descriptor: String,
|
||||
opcode: Int = Opcodes.INVOKEINTERFACE,
|
||||
isInterface: Boolean = true
|
||||
) {
|
||||
maxStack++
|
||||
invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface)
|
||||
visitCastToT()
|
||||
invokeMethodVisitor.visitCheckCast(T_CLASS)
|
||||
}
|
||||
|
||||
private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS)
|
||||
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string)
|
||||
|
||||
fun visitStringConstant(string: String) {
|
||||
invokeMethodVisitor.visitLdcInsn(string)
|
||||
}
|
||||
|
||||
companion object {
|
||||
private val SIGNATURE_LETTERS = mapOf(
|
||||
internal companion object {
|
||||
private val SIGNATURE_LETTERS: Map<Class<out Any>, String> by lazy {
|
||||
mapOf(
|
||||
java.lang.Byte::class.java to "B",
|
||||
java.lang.Short::class.java to "S",
|
||||
java.lang.Integer::class.java to "I",
|
||||
@ -307,15 +288,16 @@ internal class AsmBuilder<T>(
|
||||
java.lang.Float::class.java to "F",
|
||||
java.lang.Double::class.java to "D"
|
||||
)
|
||||
}
|
||||
|
||||
private val INLINABLE_NUMBERS = SIGNATURE_LETTERS.keys
|
||||
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
||||
|
||||
const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/asm/FunctionalCompiledExpression"
|
||||
const val MAP_CLASS = "java/util/Map"
|
||||
const val OBJECT_CLASS = "java/lang/Object"
|
||||
const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||
const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
||||
const val STRING_CLASS = "java/lang/String"
|
||||
const val NUMBER_CLASS = "java/lang/Number"
|
||||
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =
|
||||
"scientifik/kmath/asm/internal/AsmCompiledExpression"
|
||||
|
||||
internal const val MAP_CLASS = "java/util/Map"
|
||||
internal const val OBJECT_CLASS = "java/lang/Object"
|
||||
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||
internal const val STRING_CLASS = "java/lang/String"
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,12 @@
|
||||
package scientifik.kmath.asm.internal
|
||||
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.operations.Algebra
|
||||
|
||||
internal abstract class AsmCompiledExpression<T> internal constructor(
|
||||
@JvmField protected val algebra: Algebra<T>,
|
||||
@JvmField protected val constants: Array<Any>
|
||||
) : Expression<T> {
|
||||
abstract override fun invoke(arguments: Map<String, T>): T
|
||||
}
|
||||
|
@ -0,0 +1,15 @@
|
||||
package scientifik.kmath.asm.internal
|
||||
|
||||
import scientifik.kmath.ast.MST
|
||||
|
||||
internal fun buildName(mst: MST, collision: Int = 0): String {
|
||||
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
|
||||
|
||||
try {
|
||||
Class.forName(name)
|
||||
} catch (ignored: ClassNotFoundException) {
|
||||
return name
|
||||
}
|
||||
|
||||
return buildName(mst, collision + 1)
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
package scientifik.kmath.asm.internal
|
||||
|
||||
import org.objectweb.asm.ClassWriter
|
||||
import org.objectweb.asm.MethodVisitor
|
||||
|
||||
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block)
|
||||
|
||||
internal inline fun ClassWriter.visitMethod(
|
||||
access: Int,
|
||||
name: String,
|
||||
descriptor: String,
|
||||
signature: String?,
|
||||
exceptions: Array<String>?,
|
||||
block: MethodVisitor.() -> Unit
|
||||
): MethodVisitor = visitMethod(access, name, descriptor, signature, exceptions).apply(block)
|
@ -3,7 +3,7 @@ package scientifik.kmath.asm.internal
|
||||
import org.objectweb.asm.MethodVisitor
|
||||
import org.objectweb.asm.Opcodes.*
|
||||
|
||||
internal fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) {
|
||||
internal fun MethodVisitor.visitLdcOrIntConstant(value: Int): Unit = when (value) {
|
||||
-1 -> visitInsn(ICONST_M1)
|
||||
0 -> visitInsn(ICONST_0)
|
||||
1 -> visitInsn(ICONST_1)
|
||||
@ -14,15 +14,39 @@ internal fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) {
|
||||
else -> visitLdcInsn(value)
|
||||
}
|
||||
|
||||
internal fun MethodVisitor.visitLdcOrDConstInsn(value: Double) = when (value) {
|
||||
internal fun MethodVisitor.visitLdcOrDoubleConstant(value: Double): Unit = when (value) {
|
||||
0.0 -> visitInsn(DCONST_0)
|
||||
1.0 -> visitInsn(DCONST_1)
|
||||
else -> visitLdcInsn(value)
|
||||
}
|
||||
|
||||
internal fun MethodVisitor.visitLdcOrFConstInsn(value: Float) = when (value) {
|
||||
internal fun MethodVisitor.visitLdcOrFloatConstant(value: Float): Unit = when (value) {
|
||||
0f -> visitInsn(FCONST_0)
|
||||
1f -> visitInsn(FCONST_1)
|
||||
2f -> visitInsn(FCONST_2)
|
||||
else -> visitLdcInsn(value)
|
||||
}
|
||||
|
||||
internal fun MethodVisitor.visitInvokeInterface(owner: String, name: String, descriptor: String): Unit =
|
||||
visitMethodInsn(INVOKEINTERFACE, owner, name, descriptor, true)
|
||||
|
||||
internal fun MethodVisitor.visitInvokeVirtual(owner: String, name: String, descriptor: String): Unit =
|
||||
visitMethodInsn(INVOKEVIRTUAL, owner, name, descriptor, false)
|
||||
|
||||
internal fun MethodVisitor.visitInvokeStatic(owner: String, name: String, descriptor: String): Unit =
|
||||
visitMethodInsn(INVOKESTATIC, owner, name, descriptor, false)
|
||||
|
||||
internal fun MethodVisitor.visitInvokeSpecial(owner: String, name: String, descriptor: String): Unit =
|
||||
visitMethodInsn(INVOKESPECIAL, owner, name, descriptor, false)
|
||||
|
||||
internal fun MethodVisitor.visitCheckCast(type: String): Unit = visitTypeInsn(CHECKCAST, type)
|
||||
|
||||
internal fun MethodVisitor.visitGetField(owner: String, name: String, descriptor: String): Unit =
|
||||
visitFieldInsn(GETFIELD, owner, name, descriptor)
|
||||
|
||||
internal fun MethodVisitor.visitLoadObjectVar(`var`: Int): Unit = visitVarInsn(ALOAD, `var`)
|
||||
|
||||
internal fun MethodVisitor.visitGetObjectArrayElement(): Unit = visitInsn(AALOAD)
|
||||
|
||||
internal fun MethodVisitor.visitReturn(): Unit = visitInsn(RETURN)
|
||||
internal fun MethodVisitor.visitReturnObject(): Unit = visitInsn(ARETURN)
|
||||
|
@ -29,7 +29,7 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
|
||||
append("L${AsmBuilder.OBJECT_CLASS};")
|
||||
}
|
||||
|
||||
visitAlgebraOperation(
|
||||
invokeAlgebraOperation(
|
||||
owner = owner,
|
||||
method = aName,
|
||||
descriptor = sig,
|
||||
@ -39,8 +39,3 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
|
||||
|
||||
return true
|
||||
}
|
||||
//
|
||||
//internal fun <T : Any> AsmExpression<T>.optimize(): AsmExpression<T> {
|
||||
// val a = tryEvaluate()
|
||||
// return if (a == null) this else AsmConstantExpression(type, algebra, a)
|
||||
//}
|
||||
|
Loading…
Reference in New Issue
Block a user