Refactor ASM generation code #105
@ -1,6 +1,7 @@
|
|||||||
package scientifik.kmath.asm
|
package scientifik.kmath.asm
|
||||||
|
|
||||||
import scientifik.kmath.asm.internal.AsmBuilder
|
import scientifik.kmath.asm.internal.AsmBuilder
|
||||||
|
import scientifik.kmath.asm.internal.buildName
|
||||||
import scientifik.kmath.asm.internal.hasSpecific
|
import scientifik.kmath.asm.internal.hasSpecific
|
||||||
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
||||||
import scientifik.kmath.ast.MST
|
import scientifik.kmath.ast.MST
|
||||||
@ -11,44 +12,30 @@ import scientifik.kmath.expressions.Expression
|
|||||||
import scientifik.kmath.operations.*
|
import scientifik.kmath.operations.*
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to an Expression using AST compiler
|
* Compile given MST to an Expression using AST compiler
|
||||||
*/
|
*/
|
||||||
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
|
fun AsmBuilder<T>.visit(node: MST) {
|
||||||
fun buildName(mst: MST, collision: Int = 0): String {
|
|
||||||
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
|
|
||||||
|
|
||||||
try {
|
|
||||||
Class.forName(name)
|
|
||||||
} catch (ignored: ClassNotFoundException) {
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
|
|
||||||
return buildName(mst, collision + 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun AsmBuilder<T>.visit(node: MST): Unit {
|
|
||||||
when (node) {
|
when (node) {
|
||||||
is MST.Symbolic -> visitLoadFromVariables(node.value)
|
is MST.Symbolic -> loadVariable(node.value)
|
||||||
is MST.Numeric -> {
|
is MST.Numeric -> {
|
||||||
val constant = if (algebra is NumericAlgebra<T>) {
|
val constant = if (algebra is NumericAlgebra<T>) {
|
||||||
algebra.number(node.value)
|
algebra.number(node.value)
|
||||||
} else {
|
} else {
|
||||||
error("Number literals are not supported in $algebra")
|
error("Number literals are not supported in $algebra")
|
||||||
}
|
}
|
||||||
visitLoadFromConstants(constant)
|
loadTConstant(constant)
|
||||||
}
|
}
|
||||||
is MST.Unary -> {
|
is MST.Unary -> {
|
||||||
visitLoadAlgebra()
|
loadAlgebra()
|
||||||
|
|
||||||
if (!hasSpecific(algebra, node.operation, 1)) visitStringConstant(node.operation)
|
if (!hasSpecific(algebra, node.operation, 1)) loadStringConstant(node.operation)
|
||||||
|
|
||||||
visit(node.value)
|
visit(node.value)
|
||||||
|
|
||||||
if (!tryInvokeSpecific(algebra, node.operation, 1)) {
|
if (!tryInvokeSpecific(algebra, node.operation, 1)) {
|
||||||
visitAlgebraOperation(
|
invokeAlgebraOperation(
|
||||||
owner = AsmBuilder.ALGEBRA_CLASS,
|
owner = AsmBuilder.ALGEBRA_CLASS,
|
||||||
method = "unaryOperation",
|
method = "unaryOperation",
|
||||||
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
||||||
@ -58,17 +45,17 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
is MST.Binary -> {
|
is MST.Binary -> {
|
||||||
visitLoadAlgebra()
|
loadAlgebra()
|
||||||
|
|
||||||
if (!hasSpecific(algebra, node.operation, 2))
|
if (!hasSpecific(algebra, node.operation, 2))
|
||||||
visitStringConstant(node.operation)
|
loadStringConstant(node.operation)
|
||||||
|
|
||||||
visit(node.left)
|
visit(node.left)
|
||||||
visit(node.right)
|
visit(node.right)
|
||||||
|
|
||||||
if (!tryInvokeSpecific(algebra, node.operation, 2)) {
|
if (!tryInvokeSpecific(algebra, node.operation, 2)) {
|
||||||
|
|
||||||
visitAlgebraOperation(
|
invokeAlgebraOperation(
|
||||||
owner = AsmBuilder.ALGEBRA_CLASS,
|
owner = AsmBuilder.ALGEBRA_CLASS,
|
||||||
method = "binaryOperation",
|
method = "binaryOperation",
|
||||||
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
||||||
@ -81,9 +68,7 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val builder = AsmBuilder(type.java, algebra, buildName(this))
|
return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
|
||||||
builder.visit(this)
|
|
||||||
return builder.generate()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun <reified T : Any> Algebra<T>.compile(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
inline fun <reified T : Any> Algebra<T>.compile(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
||||||
|
@ -4,9 +4,8 @@ import org.objectweb.asm.ClassWriter
|
|||||||
import org.objectweb.asm.Label
|
import org.objectweb.asm.Label
|
||||||
import org.objectweb.asm.MethodVisitor
|
import org.objectweb.asm.MethodVisitor
|
||||||
import org.objectweb.asm.Opcodes
|
import org.objectweb.asm.Opcodes
|
||||||
import scientifik.kmath.asm.AsmExpression
|
|
||||||
import scientifik.kmath.asm.FunctionalCompiledExpression
|
|
||||||
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||||
|
import scientifik.kmath.ast.MST
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -22,7 +21,7 @@ internal class AsmBuilder<T> @PublishedApi internal constructor(
|
|||||||
private val classOfT: Class<*>,
|
private val classOfT: Class<*>,
|
||||||
private val algebra: Algebra<T>,
|
private val algebra: Algebra<T>,
|
||||||
private val className: String,
|
private val className: String,
|
||||||
private val root: AsmExpression<T>
|
private val evaluateMethodVisitor: AsmBuilder<T>.() -> Unit
|
||||||
) {
|
) {
|
||||||
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
||||||
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
||||||
@ -42,10 +41,10 @@ internal class AsmBuilder<T> @PublishedApi internal constructor(
|
|||||||
private val invokeArgumentsVar: Int = 1
|
private val invokeArgumentsVar: Int = 1
|
||||||
private val constants: MutableList<Any> = mutableListOf()
|
private val constants: MutableList<Any> = mutableListOf()
|
||||||
private lateinit var invokeMethodVisitor: MethodVisitor
|
private lateinit var invokeMethodVisitor: MethodVisitor
|
||||||
private var generatedInstance: FunctionalCompiledExpression<T>? = null
|
private var generatedInstance: AsmCompiledExpression<T>? = null
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
fun getInstance(): FunctionalCompiledExpression<T> {
|
fun getInstance(): AsmCompiledExpression<T> {
|
||||||
generatedInstance?.let { return it }
|
generatedInstance?.let { return it }
|
||||||
|
|
||||||
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
||||||
@ -112,7 +111,7 @@ internal class AsmBuilder<T> @PublishedApi internal constructor(
|
|||||||
visitCode()
|
visitCode()
|
||||||
val l0 = Label()
|
val l0 = Label()
|
||||||
visitLabel(l0)
|
visitLabel(l0)
|
||||||
root.compile(this@AsmBuilder)
|
evaluateMethodVisitor()
|
||||||
visitReturnObject()
|
visitReturnObject()
|
||||||
val l1 = Label()
|
val l1 = Label()
|
||||||
visitLabel(l1)
|
visitLabel(l1)
|
||||||
@ -178,7 +177,7 @@ internal class AsmBuilder<T> @PublishedApi internal constructor(
|
|||||||
.defineClass(className, classWriter.toByteArray())
|
.defineClass(className, classWriter.toByteArray())
|
||||||
.constructors
|
.constructors
|
||||||
.first()
|
.first()
|
||||||
.newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression<T>
|
.newInstance(algebra, constants.toTypedArray()) as AsmCompiledExpression<T>
|
||||||
|
|
||||||
generatedInstance = new
|
generatedInstance = new
|
||||||
return new
|
return new
|
||||||
@ -296,13 +295,11 @@ internal class AsmBuilder<T> @PublishedApi internal constructor(
|
|||||||
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
||||||
|
|
||||||
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =
|
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =
|
||||||
"scientifik/kmath/asm/FunctionalCompiledExpression"
|
"scientifik/kmath/asm/internal/AsmCompiledExpression"
|
||||||
|
|
||||||
internal const val MAP_CLASS = "java/util/Map"
|
internal const val MAP_CLASS = "java/util/Map"
|
||||||
internal const val OBJECT_CLASS = "java/lang/Object"
|
internal const val OBJECT_CLASS = "java/lang/Object"
|
||||||
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||||
internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
|
||||||
internal const val STRING_CLASS = "java/lang/String"
|
internal const val STRING_CLASS = "java/lang/String"
|
||||||
internal const val NUMBER_CLASS = "java/lang/Number"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,12 @@
|
|||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
|
|
||||||
|
internal abstract class AsmCompiledExpression<T> internal constructor(
|
||||||
|
@JvmField protected val algebra: Algebra<T>,
|
||||||
|
@JvmField protected val constants: Array<Any>
|
||||||
|
) : Expression<T> {
|
||||||
|
abstract override fun invoke(arguments: Map<String, T>): T
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,15 @@
|
|||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
import scientifik.kmath.ast.MST
|
||||||
|
|
||||||
|
internal fun buildName(mst: MST, collision: Int = 0): String {
|
||||||
|
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
|
||||||
|
|
||||||
|
try {
|
||||||
|
Class.forName(name)
|
||||||
|
} catch (ignored: ClassNotFoundException) {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildName(mst, collision + 1)
|
||||||
|
}
|
@ -29,7 +29,7 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
|
|||||||
append("L${AsmBuilder.OBJECT_CLASS};")
|
append("L${AsmBuilder.OBJECT_CLASS};")
|
||||||
}
|
}
|
||||||
|
|
||||||
visitAlgebraOperation(
|
invokeAlgebraOperation(
|
||||||
owner = owner,
|
owner = owner,
|
||||||
method = aName,
|
method = aName,
|
||||||
descriptor = sig,
|
descriptor = sig,
|
||||||
@ -39,8 +39,3 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
|
|||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
//
|
|
||||||
//internal fun <T : Any> AsmExpression<T>.optimize(): AsmExpression<T> {
|
|
||||||
// val a = tryEvaluate()
|
|
||||||
// return if (a == null) this else AsmConstantExpression(type, algebra, a)
|
|
||||||
//}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user