Merge remote-tracking branch 'origin/adv-expr' into adv-expr

This commit is contained in:
Alexander Nozik 2020-06-16 10:33:16 +03:00
commit d73f564fb4
7 changed files with 249 additions and 221 deletions

View File

@ -1,6 +1,7 @@
package scientifik.kmath.asm package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.buildName
import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.ast.MST import scientifik.kmath.ast.MST
@ -10,44 +11,30 @@ import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra import scientifik.kmath.operations.NumericAlgebra
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
* Compile given MST to an Expression using AST compiler * Compile given MST to an Expression using AST compiler
*/ */
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> { fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST) {
fun buildName(mst: MST, collision: Int = 0): String {
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
try {
Class.forName(name)
} catch (ignored: ClassNotFoundException) {
return name
}
return buildName(mst, collision + 1)
}
fun AsmBuilder<T>.visit(node: MST): Unit {
when (node) { when (node) {
is MST.Symbolic -> visitLoadFromVariables(node.value) is MST.Symbolic -> loadVariable(node.value)
is MST.Numeric -> { is MST.Numeric -> {
val constant = if (algebra is NumericAlgebra<T>) { val constant = if (algebra is NumericAlgebra<T>) {
algebra.number(node.value) algebra.number(node.value)
} else { } else {
error("Number literals are not supported in $algebra") error("Number literals are not supported in $algebra")
} }
visitLoadFromConstants(constant) loadTConstant(constant)
} }
is MST.Unary -> { is MST.Unary -> {
visitLoadAlgebra() loadAlgebra()
if (!hasSpecific(algebra, node.operation, 1)) visitStringConstant(node.operation) if (!hasSpecific(algebra, node.operation, 1)) loadStringConstant(node.operation)
visit(node.value) visit(node.value)
if (!tryInvokeSpecific(algebra, node.operation, 1)) { if (!tryInvokeSpecific(algebra, node.operation, 1)) {
visitAlgebraOperation( invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS, owner = AsmBuilder.ALGEBRA_CLASS,
method = "unaryOperation", method = "unaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" + descriptor = "(L${AsmBuilder.STRING_CLASS};" +
@ -57,17 +44,17 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
} }
} }
is MST.Binary -> { is MST.Binary -> {
visitLoadAlgebra() loadAlgebra()
if (!hasSpecific(algebra, node.operation, 2)) if (!hasSpecific(algebra, node.operation, 2))
visitStringConstant(node.operation) loadStringConstant(node.operation)
visit(node.left) visit(node.left)
visit(node.right) visit(node.right)
if (!tryInvokeSpecific(algebra, node.operation, 2)) { if (!tryInvokeSpecific(algebra, node.operation, 2)) {
visitAlgebraOperation( invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS, owner = AsmBuilder.ALGEBRA_CLASS,
method = "binaryOperation", method = "binaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" + descriptor = "(L${AsmBuilder.STRING_CLASS};" +
@ -80,9 +67,7 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
} }
} }
val builder = AsmBuilder(type.java, algebra, buildName(this)) return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
builder.visit(this)
return builder.generate()
} }
inline fun <reified T : Any> Algebra<T>.compile(mst: MST): Expression<T> = mst.compileWith(T::class, this) inline fun <reified T : Any> Algebra<T>.compile(mst: MST): Expression<T> = mst.compileWith(T::class, this)

View File

@ -4,41 +4,29 @@ import org.objectweb.asm.ClassWriter
import org.objectweb.asm.Label import org.objectweb.asm.Label
import org.objectweb.asm.MethodVisitor import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.internal.AsmBuilder.AsmClassLoader import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
private abstract class FunctionalCompiledExpression<T> internal constructor(
@JvmField protected val algebra: Algebra<T>,
@JvmField protected val constants: Array<Any>
) : Expression<T> {
abstract override fun invoke(arguments: Map<String, T>): T
}
/** /**
* AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression.
* expression. This class uses [AsmClassLoader] for loading the generated class, then it is able to instantiate the new * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class.
* class.
* *
* @param T the type of AsmExpression to unwrap. * @param T the type of AsmExpression to unwrap.
* @param algebra the algebra the applied AsmExpressions use. * @param algebra the algebra the applied AsmExpressions use.
* @param className the unique class name of new loaded class. * @param className the unique class name of new loaded class.
*
* @author [Iaroslav Postovalov](https://github.com/CommanderTvis)
*/ */
internal class AsmBuilder<T>( internal class AsmBuilder<T> internal constructor(
private val classOfT: Class<*>, private val classOfT: Class<*>,
private val algebra: Algebra<T>, private val algebra: Algebra<T>,
private val className: String private val className: String,
private val invokeLabel0Visitor: AsmBuilder<T>.() -> Unit
) { ) {
private class AsmClassLoader(parent: ClassLoader) : ClassLoader(parent) { private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
} }
private val classLoader: AsmClassLoader = private val classLoader: ClassLoader =
AsmClassLoader(javaClass.classLoader) ClassLoader(javaClass.classLoader)
@Suppress("PrivatePropertyName") @Suppress("PrivatePropertyName")
private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/')
@ -49,15 +37,16 @@ internal class AsmBuilder<T>(
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/')
private val invokeThisVar: Int = 0 private val invokeThisVar: Int = 0
private val invokeArgumentsVar: Int = 1 private val invokeArgumentsVar: Int = 1
private var maxStack: Int = 0
private val constants: MutableList<Any> = mutableListOf() private val constants: MutableList<Any> = mutableListOf()
private val asmCompiledClassWriter: ClassWriter = ClassWriter(0) private lateinit var invokeMethodVisitor: MethodVisitor
private val invokeMethodVisitor: MethodVisitor private var generatedInstance: AsmCompiledExpression<T>? = null
private val invokeL0: Label
private lateinit var invokeL1: Label
init { @Suppress("UNCHECKED_CAST")
asmCompiledClassWriter.visit( fun getInstance(): AsmCompiledExpression<T> {
generatedInstance?.let { return it }
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit(
Opcodes.V1_8, Opcodes.V1_8,
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
slashesClassName, slashesClassName,
@ -66,28 +55,31 @@ internal class AsmBuilder<T>(
arrayOf() arrayOf()
) )
asmCompiledClassWriter.run { visitMethod(
visitMethod(Opcodes.ACC_PUBLIC, "<init>", "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", null, null).run { access = Opcodes.ACC_PUBLIC,
name = "<init>",
descriptor = "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V",
signature = null,
exceptions = null
) {
val thisVar = 0 val thisVar = 0
val algebraVar = 1 val algebraVar = 1
val constantsVar = 2 val constantsVar = 2
val l0 = Label() val l0 = Label()
visitLabel(l0) visitLabel(l0)
visitVarInsn(Opcodes.ALOAD, thisVar) visitLoadObjectVar(thisVar)
visitVarInsn(Opcodes.ALOAD, algebraVar) visitLoadObjectVar(algebraVar)
visitVarInsn(Opcodes.ALOAD, constantsVar) visitLoadObjectVar(constantsVar)
visitMethodInsn( visitInvokeSpecial(
Opcodes.INVOKESPECIAL,
FUNCTIONAL_COMPILED_EXPRESSION_CLASS, FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
"<init>", "<init>",
"(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V"
false
) )
val l1 = Label() val l1 = Label()
visitLabel(l1) visitLabel(l1)
visitInsn(Opcodes.RETURN) visitReturn()
val l2 = Label() val l2 = Label()
visitLabel(l2) visitLabel(l2)
visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar)
@ -102,40 +94,32 @@ internal class AsmBuilder<T>(
) )
visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar) visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar)
visitMaxs(3, 3) visitMaxs(0, 3)
visitEnd() visitEnd()
} }
invokeMethodVisitor = visitMethod( visitMethod(
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
"invoke", name = "invoke",
"(L$MAP_CLASS;)L$T_CLASS;", descriptor = "(L$MAP_CLASS;)L$T_CLASS;",
"(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;", signature = "(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;",
null exceptions = null
) ) {
invokeMethodVisitor = this
invokeMethodVisitor.run {
visitCode() visitCode()
invokeL0 = Label() val l0 = Label()
visitLabel(invokeL0) visitLabel(l0)
} invokeLabel0Visitor()
} visitReturnObject()
} val l1 = Label()
visitLabel(l1)
@Suppress("UNCHECKED_CAST")
fun generate(): Expression<T> {
invokeMethodVisitor.run {
visitInsn(Opcodes.ARETURN)
invokeL1 = Label()
visitLabel(invokeL1)
visitLocalVariable( visitLocalVariable(
"this", "this",
"L$slashesClassName;", "L$slashesClassName;",
T_CLASS, null,
invokeL0, l0,
invokeL1, l1,
invokeThisVar invokeThisVar
) )
@ -143,30 +127,31 @@ internal class AsmBuilder<T>(
"arguments", "arguments",
"L$MAP_CLASS;", "L$MAP_CLASS;",
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;", "L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;",
invokeL0, l0,
invokeL1, l1,
invokeArgumentsVar invokeArgumentsVar
) )
visitMaxs(maxStack + 1, 2) visitMaxs(0, 2)
visitEnd() visitEnd()
} }
asmCompiledClassWriter.visitMethod( visitMethod(
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, access = Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
"invoke", name = "invoke",
"(L$MAP_CLASS;)L$OBJECT_CLASS;", descriptor = "(L$MAP_CLASS;)L$OBJECT_CLASS;",
null, signature = null,
null exceptions = null
).run { ) {
val thisVar = 0 val thisVar = 0
val argumentsVar = 1
visitCode() visitCode()
val l0 = Label() val l0 = Label()
visitLabel(l0) visitLabel(l0)
visitVarInsn(Opcodes.ALOAD, 0) visitLoadObjectVar(thisVar)
visitVarInsn(Opcodes.ALOAD, 1) visitLoadObjectVar(argumentsVar)
visitMethodInsn(Opcodes.INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false) visitInvokeVirtual(owner = slashesClassName, name = "invoke", descriptor = "(L$MAP_CLASS;)L$T_CLASS;")
visitInsn(Opcodes.ARETURN) visitReturnObject()
val l1 = Label() val l1 = Label()
visitLabel(l1) visitLabel(l1)
@ -179,127 +164,123 @@ internal class AsmBuilder<T>(
thisVar thisVar
) )
visitMaxs(2, 2) visitMaxs(0, 2)
visitEnd() visitEnd()
} }
asmCompiledClassWriter.visitEnd() visitEnd()
}
val new = classLoader val new = classLoader
.defineClass(className, asmCompiledClassWriter.toByteArray()) .defineClass(className, classWriter.toByteArray())
.constructors .constructors
.first() .first()
.newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression<T> .newInstance(algebra, constants.toTypedArray()) as AsmCompiledExpression<T>
generatedInstance = new
return new return new
} }
fun visitLoadFromConstants(value: T) { internal fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) { if (classOfT in INLINABLE_NUMBERS) {
visitNumberConstant(value as Number) loadNumberConstant(value as Number)
visitCastToT() invokeMethodVisitor.visitCheckCast(T_CLASS)
return return
} }
visitLoadAnyFromConstants(value as Any, T_CLASS) loadConstant(value as Any, T_CLASS)
} }
private fun visitLoadAnyFromConstants(value: Any, type: String) { private fun loadConstant(value: Any, type: String) {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
maxStack++
invokeMethodVisitor.run { invokeMethodVisitor.run {
visitLoadThis() loadThis()
visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "[L$OBJECT_CLASS;") visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;")
visitLdcOrIConstInsn(idx) visitLdcOrIntConstant(idx)
visitInsn(Opcodes.AALOAD) visitGetObjectArrayElement()
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type) invokeMethodVisitor.visitCheckCast(type)
} }
} }
private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar)
fun visitNumberConstant(value: Number) { internal fun loadNumberConstant(value: Number) {
maxStack++
val clazz = value.javaClass val clazz = value.javaClass
val c = clazz.name.replace('.', '/') val c = clazz.name.replace('.', '/')
val sigLetter = SIGNATURE_LETTERS[clazz] val sigLetter = SIGNATURE_LETTERS[clazz]
if (sigLetter != null) { if (sigLetter != null) {
when (value) { when (value) {
is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value) is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value)
is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value) is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value)
is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value) is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value)
else -> invokeMethodVisitor.visitLdcInsn(value) else -> invokeMethodVisitor.visitLdcInsn(value)
} }
invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};")
return return
} }
visitLoadAnyFromConstants(value, c) loadConstant(value, c)
} }
fun visitLoadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
maxStack += 2 visitLoadObjectVar(invokeArgumentsVar)
visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar)
if (defaultValue != null) { if (defaultValue != null) {
visitLdcInsn(name) visitLdcInsn(name)
visitLoadFromConstants(defaultValue) loadTConstant(defaultValue)
visitMethodInsn( visitInvokeInterface(
Opcodes.INVOKEINTERFACE, owner = MAP_CLASS,
MAP_CLASS, name = "getOrDefault",
"getOrDefault", descriptor = "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;"
"(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;",
true
) )
visitCastToT() invokeMethodVisitor.visitCheckCast(T_CLASS)
return return
} }
visitLdcInsn(name) visitLdcInsn(name)
visitMethodInsn(
Opcodes.INVOKEINTERFACE,
MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true
)
visitCastToT()
}
fun visitLoadAlgebra() { visitInvokeInterface(
maxStack++ owner = MAP_CLASS,
invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) name = "get",
descriptor = "(L$OBJECT_CLASS;)L$OBJECT_CLASS;"
invokeMethodVisitor.visitFieldInsn(
Opcodes.GETFIELD,
FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;"
) )
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS) invokeMethodVisitor.visitCheckCast(T_CLASS)
} }
fun visitAlgebraOperation( internal fun loadAlgebra() {
loadThis()
invokeMethodVisitor.visitGetField(
owner = FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
name = "algebra",
descriptor = "L$ALGEBRA_CLASS;"
)
invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS)
}
internal fun invokeAlgebraOperation(
owner: String, owner: String,
method: String, method: String,
descriptor: String, descriptor: String,
opcode: Int = Opcodes.INVOKEINTERFACE, opcode: Int = Opcodes.INVOKEINTERFACE,
isInterface: Boolean = true isInterface: Boolean = true
) { ) {
maxStack++
invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface)
visitCastToT() invokeMethodVisitor.visitCheckCast(T_CLASS)
} }
private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS) internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string)
fun visitStringConstant(string: String) { internal companion object {
invokeMethodVisitor.visitLdcInsn(string) private val SIGNATURE_LETTERS: Map<Class<out Any>, String> by lazy {
} mapOf(
companion object {
private val SIGNATURE_LETTERS = mapOf(
java.lang.Byte::class.java to "B", java.lang.Byte::class.java to "B",
java.lang.Short::class.java to "S", java.lang.Short::class.java to "S",
java.lang.Integer::class.java to "I", java.lang.Integer::class.java to "I",
@ -307,15 +288,16 @@ internal class AsmBuilder<T>(
java.lang.Float::class.java to "F", java.lang.Float::class.java to "F",
java.lang.Double::class.java to "D" java.lang.Double::class.java to "D"
) )
}
private val INLINABLE_NUMBERS = SIGNATURE_LETTERS.keys private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/asm/FunctionalCompiledExpression" internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =
const val MAP_CLASS = "java/util/Map" "scientifik/kmath/asm/internal/AsmCompiledExpression"
const val OBJECT_CLASS = "java/lang/Object"
const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" internal const val MAP_CLASS = "java/util/Map"
const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" internal const val OBJECT_CLASS = "java/lang/Object"
const val STRING_CLASS = "java/lang/String" internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
const val NUMBER_CLASS = "java/lang/Number" internal const val STRING_CLASS = "java/lang/String"
} }
} }

View File

@ -0,0 +1,12 @@
package scientifik.kmath.asm.internal
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra
internal abstract class AsmCompiledExpression<T> internal constructor(
@JvmField protected val algebra: Algebra<T>,
@JvmField protected val constants: Array<Any>
) : Expression<T> {
abstract override fun invoke(arguments: Map<String, T>): T
}

View File

@ -0,0 +1,15 @@
package scientifik.kmath.asm.internal
import scientifik.kmath.ast.MST
internal fun buildName(mst: MST, collision: Int = 0): String {
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
try {
Class.forName(name)
} catch (ignored: ClassNotFoundException) {
return name
}
return buildName(mst, collision + 1)
}

View File

@ -0,0 +1,15 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.MethodVisitor
internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block)
internal inline fun ClassWriter.visitMethod(
access: Int,
name: String,
descriptor: String,
signature: String?,
exceptions: Array<String>?,
block: MethodVisitor.() -> Unit
): MethodVisitor = visitMethod(access, name, descriptor, signature, exceptions).apply(block)

View File

@ -3,7 +3,7 @@ package scientifik.kmath.asm.internal
import org.objectweb.asm.MethodVisitor import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Opcodes.*
internal fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) { internal fun MethodVisitor.visitLdcOrIntConstant(value: Int): Unit = when (value) {
-1 -> visitInsn(ICONST_M1) -1 -> visitInsn(ICONST_M1)
0 -> visitInsn(ICONST_0) 0 -> visitInsn(ICONST_0)
1 -> visitInsn(ICONST_1) 1 -> visitInsn(ICONST_1)
@ -14,15 +14,39 @@ internal fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) {
else -> visitLdcInsn(value) else -> visitLdcInsn(value)
} }
internal fun MethodVisitor.visitLdcOrDConstInsn(value: Double) = when (value) { internal fun MethodVisitor.visitLdcOrDoubleConstant(value: Double): Unit = when (value) {
0.0 -> visitInsn(DCONST_0) 0.0 -> visitInsn(DCONST_0)
1.0 -> visitInsn(DCONST_1) 1.0 -> visitInsn(DCONST_1)
else -> visitLdcInsn(value) else -> visitLdcInsn(value)
} }
internal fun MethodVisitor.visitLdcOrFConstInsn(value: Float) = when (value) { internal fun MethodVisitor.visitLdcOrFloatConstant(value: Float): Unit = when (value) {
0f -> visitInsn(FCONST_0) 0f -> visitInsn(FCONST_0)
1f -> visitInsn(FCONST_1) 1f -> visitInsn(FCONST_1)
2f -> visitInsn(FCONST_2) 2f -> visitInsn(FCONST_2)
else -> visitLdcInsn(value) else -> visitLdcInsn(value)
} }
internal fun MethodVisitor.visitInvokeInterface(owner: String, name: String, descriptor: String): Unit =
visitMethodInsn(INVOKEINTERFACE, owner, name, descriptor, true)
internal fun MethodVisitor.visitInvokeVirtual(owner: String, name: String, descriptor: String): Unit =
visitMethodInsn(INVOKEVIRTUAL, owner, name, descriptor, false)
internal fun MethodVisitor.visitInvokeStatic(owner: String, name: String, descriptor: String): Unit =
visitMethodInsn(INVOKESTATIC, owner, name, descriptor, false)
internal fun MethodVisitor.visitInvokeSpecial(owner: String, name: String, descriptor: String): Unit =
visitMethodInsn(INVOKESPECIAL, owner, name, descriptor, false)
internal fun MethodVisitor.visitCheckCast(type: String): Unit = visitTypeInsn(CHECKCAST, type)
internal fun MethodVisitor.visitGetField(owner: String, name: String, descriptor: String): Unit =
visitFieldInsn(GETFIELD, owner, name, descriptor)
internal fun MethodVisitor.visitLoadObjectVar(`var`: Int): Unit = visitVarInsn(ALOAD, `var`)
internal fun MethodVisitor.visitGetObjectArrayElement(): Unit = visitInsn(AALOAD)
internal fun MethodVisitor.visitReturn(): Unit = visitInsn(RETURN)
internal fun MethodVisitor.visitReturnObject(): Unit = visitInsn(ARETURN)

View File

@ -29,7 +29,7 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
append("L${AsmBuilder.OBJECT_CLASS};") append("L${AsmBuilder.OBJECT_CLASS};")
} }
visitAlgebraOperation( invokeAlgebraOperation(
owner = owner, owner = owner,
method = aName, method = aName,
descriptor = sig, descriptor = sig,
@ -39,8 +39,3 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
return true return true
} }
//
//internal fun <T : Any> AsmExpression<T>.optimize(): AsmExpression<T> {
// val a = tryEvaluate()
// return if (a == null) this else AsmConstantExpression(type, algebra, a)
//}