Make AsmCompiledExpression fields private, add builder functions not to expose AsmGenerationContext to public API, refactor code
This commit is contained in:
parent
6ac0297530
commit
013030951e
@ -0,0 +1,31 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Space
|
||||
|
||||
open class AsmExpressionSpace<T>(space: Space<T>) : Space<AsmExpression<T>>,
|
||||
ExpressionSpace<T, AsmExpression<T>> {
|
||||
override val zero: AsmExpression<T> = AsmConstantExpression(space.zero)
|
||||
override fun const(value: T): AsmExpression<T> = AsmConstantExpression(value)
|
||||
override fun variable(name: String, default: T?): AsmExpression<T> = AsmVariableExpression(name, default)
|
||||
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> = AsmSumExpression(a, b)
|
||||
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(a, k)
|
||||
operator fun AsmExpression<T>.plus(arg: T): AsmExpression<T> = this + const(arg)
|
||||
operator fun AsmExpression<T>.minus(arg: T): AsmExpression<T> = this - const(arg)
|
||||
operator fun T.plus(arg: AsmExpression<T>): AsmExpression<T> = arg + this
|
||||
operator fun T.minus(arg: AsmExpression<T>): AsmExpression<T> = arg - this
|
||||
}
|
||||
|
||||
class AsmExpressionField<T>(private val field: Field<T>) : ExpressionField<T, AsmExpression<T>>,
|
||||
AsmExpressionSpace<T>(field) {
|
||||
override val one: AsmExpression<T>
|
||||
get() = const(this.field.one)
|
||||
|
||||
override fun number(value: Number): AsmExpression<T> = const(field.run { one * value })
|
||||
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> = AsmProductExpression(a, b)
|
||||
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> = AsmDivExpression(a, b)
|
||||
operator fun AsmExpression<T>.times(arg: T): AsmExpression<T> = this * const(arg)
|
||||
operator fun AsmExpression<T>.div(arg: T): AsmExpression<T> = this / const(arg)
|
||||
operator fun T.times(arg: AsmExpression<T>): AsmExpression<T> = arg * this
|
||||
operator fun T.div(arg: AsmExpression<T>): AsmExpression<T> = arg / this
|
||||
}
|
@ -1,307 +1,14 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import org.objectweb.asm.ClassWriter
|
||||
import org.objectweb.asm.Label
|
||||
import org.objectweb.asm.MethodVisitor
|
||||
import org.objectweb.asm.Opcodes.*
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Space
|
||||
|
||||
abstract class AsmCompiledExpression<T> internal constructor(
|
||||
@JvmField val algebra: Algebra<T>,
|
||||
@JvmField val constants: MutableList<out Any>
|
||||
@JvmField private val algebra: Algebra<T>,
|
||||
@JvmField private val constants: MutableList<out Any>
|
||||
) : Expression<T> {
|
||||
abstract override fun invoke(arguments: Map<String, T>): T
|
||||
}
|
||||
|
||||
internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String {
|
||||
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision"
|
||||
|
||||
try {
|
||||
Class.forName(name)
|
||||
} catch (ignored: ClassNotFoundException) {
|
||||
return name
|
||||
}
|
||||
|
||||
return buildName(expression, collision + 1)
|
||||
}
|
||||
|
||||
class AsmGenerationContext<T>(
|
||||
classOfT: Class<*>,
|
||||
private val algebra: Algebra<T>,
|
||||
private val className: String
|
||||
) {
|
||||
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
||||
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
||||
}
|
||||
|
||||
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
||||
|
||||
@Suppress("PrivatePropertyName")
|
||||
private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/')
|
||||
|
||||
@Suppress("PrivatePropertyName")
|
||||
private val T_CLASS: String = classOfT.name.replace('.', '/')
|
||||
|
||||
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/')
|
||||
private val invokeThisVar: Int = 0
|
||||
private val invokeArgumentsVar: Int = 1
|
||||
private var maxStack: Int = 0
|
||||
private val constants: MutableList<Any> = mutableListOf()
|
||||
private val asmCompiledClassWriter: ClassWriter = ClassWriter(0)
|
||||
private val invokeMethodVisitor: MethodVisitor
|
||||
private val invokeL0: Label
|
||||
private lateinit var invokeL1: Label
|
||||
private var generatedInstance: AsmCompiledExpression<T>? = null
|
||||
|
||||
init {
|
||||
asmCompiledClassWriter.visit(
|
||||
V1_8,
|
||||
ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
|
||||
slashesClassName,
|
||||
"L$ASM_COMPILED_EXPRESSION_CLASS<L$T_CLASS;>;",
|
||||
ASM_COMPILED_EXPRESSION_CLASS,
|
||||
arrayOf()
|
||||
)
|
||||
|
||||
asmCompiledClassWriter.run {
|
||||
visitMethod(ACC_PUBLIC, "<init>", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run {
|
||||
val thisVar = 0
|
||||
val algebraVar = 1
|
||||
val constantsVar = 2
|
||||
val l0 = Label()
|
||||
visitLabel(l0)
|
||||
visitVarInsn(ALOAD, thisVar)
|
||||
visitVarInsn(ALOAD, algebraVar)
|
||||
visitVarInsn(ALOAD, constantsVar)
|
||||
|
||||
visitMethodInsn(
|
||||
INVOKESPECIAL,
|
||||
ASM_COMPILED_EXPRESSION_CLASS,
|
||||
"<init>",
|
||||
"(L$ALGEBRA_CLASS;L$LIST_CLASS;)V",
|
||||
false
|
||||
)
|
||||
|
||||
val l1 = Label()
|
||||
visitLabel(l1)
|
||||
visitInsn(RETURN)
|
||||
val l2 = Label()
|
||||
visitLabel(l2)
|
||||
visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar)
|
||||
|
||||
visitLocalVariable(
|
||||
"algebra",
|
||||
"L$ALGEBRA_CLASS;",
|
||||
"L$ALGEBRA_CLASS<L$T_CLASS;>;",
|
||||
l0,
|
||||
l2,
|
||||
algebraVar
|
||||
)
|
||||
|
||||
visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS<L$T_CLASS;>;", l0, l2, constantsVar)
|
||||
visitMaxs(3, 3)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
invokeMethodVisitor = visitMethod(
|
||||
ACC_PUBLIC or ACC_FINAL,
|
||||
"invoke",
|
||||
"(L$MAP_CLASS;)L$T_CLASS;",
|
||||
"(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;",
|
||||
null
|
||||
)
|
||||
|
||||
invokeMethodVisitor.run {
|
||||
visitCode()
|
||||
invokeL0 = Label()
|
||||
visitLabel(invokeL0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@PublishedApi
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
internal fun generate(): AsmCompiledExpression<T> {
|
||||
generatedInstance?.let { return it }
|
||||
|
||||
invokeMethodVisitor.run {
|
||||
visitInsn(ARETURN)
|
||||
invokeL1 = Label()
|
||||
visitLabel(invokeL1)
|
||||
|
||||
visitLocalVariable(
|
||||
"this",
|
||||
"L$slashesClassName;",
|
||||
T_CLASS,
|
||||
invokeL0,
|
||||
invokeL1,
|
||||
invokeThisVar
|
||||
)
|
||||
|
||||
visitLocalVariable(
|
||||
"arguments",
|
||||
"L$MAP_CLASS;",
|
||||
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;",
|
||||
invokeL0,
|
||||
invokeL1,
|
||||
invokeArgumentsVar
|
||||
)
|
||||
|
||||
visitMaxs(maxStack + 1, 2)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
asmCompiledClassWriter.visitMethod(
|
||||
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
|
||||
"invoke",
|
||||
"(L$MAP_CLASS;)L$OBJECT_CLASS;",
|
||||
null,
|
||||
null
|
||||
).run {
|
||||
val thisVar = 0
|
||||
visitCode()
|
||||
val l0 = Label()
|
||||
visitLabel(l0)
|
||||
visitVarInsn(ALOAD, 0)
|
||||
visitVarInsn(ALOAD, 1)
|
||||
visitMethodInsn(INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false)
|
||||
visitInsn(ARETURN)
|
||||
val l1 = Label()
|
||||
visitLabel(l1)
|
||||
|
||||
visitLocalVariable(
|
||||
"this",
|
||||
"L$slashesClassName;",
|
||||
T_CLASS,
|
||||
l0,
|
||||
l1,
|
||||
thisVar
|
||||
)
|
||||
|
||||
visitMaxs(2, 2)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
asmCompiledClassWriter.visitEnd()
|
||||
|
||||
val new = classLoader
|
||||
.defineClass(className, asmCompiledClassWriter.toByteArray())
|
||||
.constructors
|
||||
.first()
|
||||
.newInstance(algebra, constants) as AsmCompiledExpression<T>
|
||||
|
||||
generatedInstance = new
|
||||
return new
|
||||
}
|
||||
|
||||
internal fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS)
|
||||
|
||||
private fun visitLoadAnyFromConstants(value: Any, type: String) {
|
||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||
maxStack++
|
||||
|
||||
invokeMethodVisitor.run {
|
||||
visitLoadThis()
|
||||
visitFieldInsn(GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;")
|
||||
visitLdcOrIConstInsn(idx)
|
||||
visitMethodInsn(INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true)
|
||||
invokeMethodVisitor.visitTypeInsn(CHECKCAST, type)
|
||||
}
|
||||
}
|
||||
|
||||
private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(ALOAD, invokeThisVar)
|
||||
|
||||
internal fun visitNumberConstant(value: Number) {
|
||||
maxStack++
|
||||
val clazz = value.javaClass
|
||||
val c = clazz.name.replace('.', '/')
|
||||
|
||||
val sigLetter = SIGNATURE_LETTERS[clazz]
|
||||
|
||||
if (sigLetter != null) {
|
||||
when (value) {
|
||||
is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value)
|
||||
is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value)
|
||||
is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value)
|
||||
else -> invokeMethodVisitor.visitLdcInsn(value)
|
||||
}
|
||||
|
||||
invokeMethodVisitor.visitMethodInsn(INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false)
|
||||
return
|
||||
}
|
||||
|
||||
visitLoadAnyFromConstants(value, c)
|
||||
}
|
||||
|
||||
internal fun visitLoadFromVariables(name: String, defaultValue: T? = null) = invokeMethodVisitor.run {
|
||||
maxStack += 2
|
||||
visitVarInsn(ALOAD, invokeArgumentsVar)
|
||||
|
||||
if (defaultValue != null) {
|
||||
visitLdcInsn(name)
|
||||
visitLoadFromConstants(defaultValue)
|
||||
|
||||
visitMethodInsn(
|
||||
INVOKEINTERFACE,
|
||||
MAP_CLASS,
|
||||
"getOrDefault",
|
||||
"(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;",
|
||||
true
|
||||
)
|
||||
|
||||
visitCastToT()
|
||||
return
|
||||
}
|
||||
|
||||
visitLdcInsn(name)
|
||||
visitMethodInsn(INVOKEINTERFACE, MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true)
|
||||
visitCastToT()
|
||||
}
|
||||
|
||||
internal fun visitLoadAlgebra() {
|
||||
invokeMethodVisitor.visitVarInsn(ALOAD, invokeThisVar)
|
||||
|
||||
invokeMethodVisitor.visitFieldInsn(
|
||||
GETFIELD,
|
||||
ASM_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;"
|
||||
)
|
||||
|
||||
invokeMethodVisitor.visitTypeInsn(CHECKCAST, T_ALGEBRA_CLASS)
|
||||
}
|
||||
|
||||
internal fun visitAlgebraOperation(owner: String, method: String, descriptor: String) {
|
||||
maxStack++
|
||||
invokeMethodVisitor.visitMethodInsn(INVOKEINTERFACE, owner, method, descriptor, true)
|
||||
visitCastToT()
|
||||
}
|
||||
|
||||
private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
|
||||
|
||||
internal companion object {
|
||||
private val SIGNATURE_LETTERS = mapOf(
|
||||
java.lang.Byte::class.java to "B",
|
||||
java.lang.Short::class.java to "S",
|
||||
java.lang.Integer::class.java to "I",
|
||||
java.lang.Long::class.java to "J",
|
||||
java.lang.Float::class.java to "F",
|
||||
java.lang.Double::class.java to "D"
|
||||
)
|
||||
|
||||
internal const val ASM_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/AsmCompiledExpression"
|
||||
internal const val LIST_CLASS = "java/util/List"
|
||||
internal const val MAP_CLASS = "java/util/Map"
|
||||
internal const val OBJECT_CLASS = "java/lang/Object"
|
||||
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||
internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
||||
internal const val STRING_CLASS = "java/lang/String"
|
||||
internal const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations"
|
||||
internal const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations"
|
||||
internal const val NUMBER_CLASS = "java/lang/Number"
|
||||
}
|
||||
}
|
||||
|
||||
interface AsmExpression<T> {
|
||||
fun invoke(gen: AsmGenerationContext<T>)
|
||||
}
|
||||
@ -383,35 +90,3 @@ internal class AsmDivExpression<T>(
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
open class AsmExpressionSpace<T>(space: Space<T>) : Space<AsmExpression<T>>,
|
||||
ExpressionSpace<T, AsmExpression<T>> {
|
||||
override val zero: AsmExpression<T> = AsmConstantExpression(space.zero)
|
||||
override fun const(value: T): AsmExpression<T> = AsmConstantExpression(value)
|
||||
override fun variable(name: String, default: T?): AsmExpression<T> = AsmVariableExpression(name, default)
|
||||
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> = AsmSumExpression(a, b)
|
||||
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(a, k)
|
||||
operator fun AsmExpression<T>.plus(arg: T) = this + const(arg)
|
||||
operator fun AsmExpression<T>.minus(arg: T) = this - const(arg)
|
||||
operator fun T.plus(arg: AsmExpression<T>) = arg + this
|
||||
operator fun T.minus(arg: AsmExpression<T>) = arg - this
|
||||
}
|
||||
|
||||
class AsmExpressionField<T>(private val field: Field<T>) : ExpressionField<T, AsmExpression<T>>,
|
||||
AsmExpressionSpace<T>(field) {
|
||||
override val one: AsmExpression<T>
|
||||
get() = const(this.field.one)
|
||||
|
||||
override fun number(value: Number): AsmExpression<T> = const(field.run { one * value })
|
||||
|
||||
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||
AsmProductExpression(a, b)
|
||||
|
||||
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||
AsmDivExpression(a, b)
|
||||
|
||||
operator fun AsmExpression<T>.times(arg: T) = this * const(arg)
|
||||
operator fun AsmExpression<T>.div(arg: T) = this / const(arg)
|
||||
operator fun T.times(arg: AsmExpression<T>) = arg * this
|
||||
operator fun T.div(arg: AsmExpression<T>) = arg / this
|
||||
}
|
||||
|
@ -0,0 +1,283 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import org.objectweb.asm.ClassWriter
|
||||
import org.objectweb.asm.Label
|
||||
import org.objectweb.asm.MethodVisitor
|
||||
import org.objectweb.asm.Opcodes
|
||||
import scientifik.kmath.operations.Algebra
|
||||
|
||||
class AsmGenerationContext<T>(
|
||||
classOfT: Class<*>,
|
||||
private val algebra: Algebra<T>,
|
||||
private val className: String
|
||||
) {
|
||||
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
||||
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
||||
}
|
||||
|
||||
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
||||
|
||||
@Suppress("PrivatePropertyName")
|
||||
private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/')
|
||||
|
||||
@Suppress("PrivatePropertyName")
|
||||
private val T_CLASS: String = classOfT.name.replace('.', '/')
|
||||
|
||||
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/')
|
||||
private val invokeThisVar: Int = 0
|
||||
private val invokeArgumentsVar: Int = 1
|
||||
private var maxStack: Int = 0
|
||||
private val constants: MutableList<Any> = mutableListOf()
|
||||
private val asmCompiledClassWriter: ClassWriter =
|
||||
ClassWriter(0)
|
||||
private val invokeMethodVisitor: MethodVisitor
|
||||
private val invokeL0: Label
|
||||
private lateinit var invokeL1: Label
|
||||
private var generatedInstance: AsmCompiledExpression<T>? = null
|
||||
|
||||
init {
|
||||
asmCompiledClassWriter.visit(
|
||||
Opcodes.V1_8,
|
||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
|
||||
slashesClassName,
|
||||
"L$ASM_COMPILED_EXPRESSION_CLASS<L$T_CLASS;>;",
|
||||
ASM_COMPILED_EXPRESSION_CLASS,
|
||||
arrayOf()
|
||||
)
|
||||
|
||||
asmCompiledClassWriter.run {
|
||||
visitMethod(Opcodes.ACC_PUBLIC, "<init>", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run {
|
||||
val thisVar = 0
|
||||
val algebraVar = 1
|
||||
val constantsVar = 2
|
||||
val l0 = Label()
|
||||
visitLabel(l0)
|
||||
visitVarInsn(Opcodes.ALOAD, thisVar)
|
||||
visitVarInsn(Opcodes.ALOAD, algebraVar)
|
||||
visitVarInsn(Opcodes.ALOAD, constantsVar)
|
||||
|
||||
visitMethodInsn(
|
||||
Opcodes.INVOKESPECIAL,
|
||||
ASM_COMPILED_EXPRESSION_CLASS,
|
||||
"<init>",
|
||||
"(L$ALGEBRA_CLASS;L$LIST_CLASS;)V",
|
||||
false
|
||||
)
|
||||
|
||||
val l1 = Label()
|
||||
visitLabel(l1)
|
||||
visitInsn(Opcodes.RETURN)
|
||||
val l2 = Label()
|
||||
visitLabel(l2)
|
||||
visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar)
|
||||
|
||||
visitLocalVariable(
|
||||
"algebra",
|
||||
"L$ALGEBRA_CLASS;",
|
||||
"L$ALGEBRA_CLASS<L$T_CLASS;>;",
|
||||
l0,
|
||||
l2,
|
||||
algebraVar
|
||||
)
|
||||
|
||||
visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS<L$T_CLASS;>;", l0, l2, constantsVar)
|
||||
visitMaxs(3, 3)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
invokeMethodVisitor = visitMethod(
|
||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
|
||||
"invoke",
|
||||
"(L$MAP_CLASS;)L$T_CLASS;",
|
||||
"(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;",
|
||||
null
|
||||
)
|
||||
|
||||
invokeMethodVisitor.run {
|
||||
visitCode()
|
||||
invokeL0 = Label()
|
||||
visitLabel(invokeL0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@PublishedApi
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
internal fun generate(): AsmCompiledExpression<T> {
|
||||
generatedInstance?.let { return it }
|
||||
|
||||
invokeMethodVisitor.run {
|
||||
visitInsn(Opcodes.ARETURN)
|
||||
invokeL1 = Label()
|
||||
visitLabel(invokeL1)
|
||||
|
||||
visitLocalVariable(
|
||||
"this",
|
||||
"L$slashesClassName;",
|
||||
T_CLASS,
|
||||
invokeL0,
|
||||
invokeL1,
|
||||
invokeThisVar
|
||||
)
|
||||
|
||||
visitLocalVariable(
|
||||
"arguments",
|
||||
"L$MAP_CLASS;",
|
||||
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;",
|
||||
invokeL0,
|
||||
invokeL1,
|
||||
invokeArgumentsVar
|
||||
)
|
||||
|
||||
visitMaxs(maxStack + 1, 2)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
asmCompiledClassWriter.visitMethod(
|
||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
|
||||
"invoke",
|
||||
"(L$MAP_CLASS;)L$OBJECT_CLASS;",
|
||||
null,
|
||||
null
|
||||
).run {
|
||||
val thisVar = 0
|
||||
visitCode()
|
||||
val l0 = Label()
|
||||
visitLabel(l0)
|
||||
visitVarInsn(Opcodes.ALOAD, 0)
|
||||
visitVarInsn(Opcodes.ALOAD, 1)
|
||||
visitMethodInsn(Opcodes.INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false)
|
||||
visitInsn(Opcodes.ARETURN)
|
||||
val l1 = Label()
|
||||
visitLabel(l1)
|
||||
|
||||
visitLocalVariable(
|
||||
"this",
|
||||
"L$slashesClassName;",
|
||||
T_CLASS,
|
||||
l0,
|
||||
l1,
|
||||
thisVar
|
||||
)
|
||||
|
||||
visitMaxs(2, 2)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
asmCompiledClassWriter.visitEnd()
|
||||
|
||||
val new = classLoader
|
||||
.defineClass(className, asmCompiledClassWriter.toByteArray())
|
||||
.constructors
|
||||
.first()
|
||||
.newInstance(algebra, constants) as AsmCompiledExpression<T>
|
||||
|
||||
generatedInstance = new
|
||||
return new
|
||||
}
|
||||
|
||||
internal fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS)
|
||||
|
||||
private fun visitLoadAnyFromConstants(value: Any, type: String) {
|
||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||
maxStack++
|
||||
|
||||
invokeMethodVisitor.run {
|
||||
visitLoadThis()
|
||||
visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;")
|
||||
visitLdcOrIConstInsn(idx)
|
||||
visitMethodInsn(Opcodes.INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true)
|
||||
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type)
|
||||
}
|
||||
}
|
||||
|
||||
private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar)
|
||||
|
||||
internal fun visitNumberConstant(value: Number) {
|
||||
maxStack++
|
||||
val clazz = value.javaClass
|
||||
val c = clazz.name.replace('.', '/')
|
||||
|
||||
val sigLetter = SIGNATURE_LETTERS[clazz]
|
||||
|
||||
if (sigLetter != null) {
|
||||
when (value) {
|
||||
is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value)
|
||||
is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value)
|
||||
is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value)
|
||||
else -> invokeMethodVisitor.visitLdcInsn(value)
|
||||
}
|
||||
|
||||
invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false)
|
||||
return
|
||||
}
|
||||
|
||||
visitLoadAnyFromConstants(value, c)
|
||||
}
|
||||
|
||||
internal fun visitLoadFromVariables(name: String, defaultValue: T? = null) = invokeMethodVisitor.run {
|
||||
maxStack += 2
|
||||
visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar)
|
||||
|
||||
if (defaultValue != null) {
|
||||
visitLdcInsn(name)
|
||||
visitLoadFromConstants(defaultValue)
|
||||
|
||||
visitMethodInsn(
|
||||
Opcodes.INVOKEINTERFACE,
|
||||
MAP_CLASS,
|
||||
"getOrDefault",
|
||||
"(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;",
|
||||
true
|
||||
)
|
||||
|
||||
visitCastToT()
|
||||
return
|
||||
}
|
||||
|
||||
visitLdcInsn(name)
|
||||
visitMethodInsn(Opcodes.INVOKEINTERFACE, MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true)
|
||||
visitCastToT()
|
||||
}
|
||||
|
||||
internal fun visitLoadAlgebra() {
|
||||
invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar)
|
||||
|
||||
invokeMethodVisitor.visitFieldInsn(
|
||||
Opcodes.GETFIELD,
|
||||
ASM_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;"
|
||||
)
|
||||
|
||||
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS)
|
||||
}
|
||||
|
||||
internal fun visitAlgebraOperation(owner: String, method: String, descriptor: String) {
|
||||
maxStack++
|
||||
invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKEINTERFACE, owner, method, descriptor, true)
|
||||
visitCastToT()
|
||||
}
|
||||
|
||||
private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS)
|
||||
|
||||
internal companion object {
|
||||
private val SIGNATURE_LETTERS = mapOf(
|
||||
java.lang.Byte::class.java to "B",
|
||||
java.lang.Short::class.java to "S",
|
||||
java.lang.Integer::class.java to "I",
|
||||
java.lang.Long::class.java to "J",
|
||||
java.lang.Float::class.java to "F",
|
||||
java.lang.Double::class.java to "D"
|
||||
)
|
||||
|
||||
internal const val ASM_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/AsmCompiledExpression"
|
||||
internal const val LIST_CLASS = "java/util/List"
|
||||
internal const val MAP_CLASS = "java/util/Map"
|
||||
internal const val OBJECT_CLASS = "java/lang/Object"
|
||||
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||
internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
||||
internal const val STRING_CLASS = "java/lang/String"
|
||||
internal const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations"
|
||||
internal const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations"
|
||||
internal const val NUMBER_CLASS = "java/lang/Number"
|
||||
}
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Space
|
||||
|
||||
@PublishedApi
|
||||
internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String {
|
||||
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision"
|
||||
|
||||
try {
|
||||
Class.forName(name)
|
||||
} catch (ignored: ClassNotFoundException) {
|
||||
return name
|
||||
}
|
||||
|
||||
return buildName(expression, collision + 1)
|
||||
}
|
||||
|
||||
|
||||
inline fun <reified T> asmSpace(
|
||||
algebra: Space<T>,
|
||||
block: AsmExpressionSpace<T>.() -> AsmExpression<T>
|
||||
): Expression<T> {
|
||||
val expression = AsmExpressionSpace(algebra).block()
|
||||
val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression))
|
||||
expression.invoke(ctx)
|
||||
return ctx.generate()
|
||||
}
|
||||
|
||||
inline fun <reified T> asmField(
|
||||
algebra: Field<T>,
|
||||
block: AsmExpressionField<T>.() -> AsmExpression<T>
|
||||
): Expression<T> {
|
||||
val expression = AsmExpressionField(algebra).block()
|
||||
val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression))
|
||||
expression.invoke(ctx)
|
||||
return ctx.generate()
|
||||
}
|
@ -64,7 +64,7 @@ class AsmTest {
|
||||
)
|
||||
|
||||
@Test
|
||||
fun testCProductWithOtherTypeNumber() = testDoubleExpressionValue(
|
||||
fun testCProductWithOtherTypeNumber(): Unit = testDoubleExpressionValue(
|
||||
25.0,
|
||||
AsmConstProductExpression(AsmVariableExpression("x"), 5f),
|
||||
mapOf("x" to 5.0)
|
||||
@ -81,21 +81,21 @@ class AsmTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCProductWithCustomTypeNumber() = testDoubleExpressionValue(
|
||||
fun testCProductWithCustomTypeNumber(): Unit = testDoubleExpressionValue(
|
||||
0.0,
|
||||
AsmConstProductExpression(AsmVariableExpression("x"), CustomZero),
|
||||
mapOf("x" to 5.0)
|
||||
)
|
||||
|
||||
@Test
|
||||
fun testVar() = testDoubleExpressionValue(
|
||||
fun testVar(): Unit = testDoubleExpressionValue(
|
||||
10000.0,
|
||||
AsmVariableExpression("x"),
|
||||
mapOf("x" to 10000.0)
|
||||
)
|
||||
|
||||
@Test
|
||||
fun testVarWithDefault() = testDoubleExpressionValue(
|
||||
fun testVarWithDefault(): Unit = testDoubleExpressionValue(
|
||||
10000.0,
|
||||
AsmVariableExpression("x", 10000.0),
|
||||
mapOf()
|
||||
|
Loading…
Reference in New Issue
Block a user