Refactor ASM generation code #105
@ -1,6 +1,7 @@
|
||||
package scientifik.kmath.asm
|
||||
|
||||
import scientifik.kmath.asm.internal.AsmBuilder
|
||||
import scientifik.kmath.asm.internal.buildName
|
||||
import scientifik.kmath.asm.internal.hasSpecific
|
||||
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
||||
import scientifik.kmath.ast.MST
|
||||
@ -11,44 +12,30 @@ import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.operations.*
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
|
||||
/**
|
||||
* Compile given MST to an Expression using AST compiler
|
||||
*/
|
||||
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
|
||||
|
||||
fun buildName(mst: MST, collision: Int = 0): String {
|
||||
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
|
||||
|
||||
try {
|
||||
Class.forName(name)
|
||||
} catch (ignored: ClassNotFoundException) {
|
||||
return name
|
||||
}
|
||||
|
||||
return buildName(mst, collision + 1)
|
||||
}
|
||||
|
||||
fun AsmBuilder<T>.visit(node: MST): Unit {
|
||||
fun AsmBuilder<T>.visit(node: MST) {
|
||||
when (node) {
|
||||
is MST.Symbolic -> visitLoadFromVariables(node.value)
|
||||
is MST.Symbolic -> loadVariable(node.value)
|
||||
is MST.Numeric -> {
|
||||
val constant = if (algebra is NumericAlgebra<T>) {
|
||||
algebra.number(node.value)
|
||||
} else {
|
||||
error("Number literals are not supported in $algebra")
|
||||
}
|
||||
visitLoadFromConstants(constant)
|
||||
loadTConstant(constant)
|
||||
}
|
||||
is MST.Unary -> {
|
||||
visitLoadAlgebra()
|
||||
loadAlgebra()
|
||||
|
||||
if (!hasSpecific(algebra, node.operation, 1)) visitStringConstant(node.operation)
|
||||
if (!hasSpecific(algebra, node.operation, 1)) loadStringConstant(node.operation)
|
||||
|
||||
visit(node.value)
|
||||
|
||||
if (!tryInvokeSpecific(algebra, node.operation, 1)) {
|
||||
visitAlgebraOperation(
|
||||
invokeAlgebraOperation(
|
||||
owner = AsmBuilder.ALGEBRA_CLASS,
|
||||
method = "unaryOperation",
|
||||
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
||||
@ -58,17 +45,17 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
||||
}
|
||||
}
|
||||
is MST.Binary -> {
|
||||
visitLoadAlgebra()
|
||||
loadAlgebra()
|
||||
|
||||
if (!hasSpecific(algebra, node.operation, 2))
|
||||
visitStringConstant(node.operation)
|
||||
loadStringConstant(node.operation)
|
||||
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
|
||||
if (!tryInvokeSpecific(algebra, node.operation, 2)) {
|
||||
|
||||
visitAlgebraOperation(
|
||||
invokeAlgebraOperation(
|
||||
owner = AsmBuilder.ALGEBRA_CLASS,
|
||||
method = "binaryOperation",
|
||||
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
|
||||
@ -81,9 +68,7 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
||||
}
|
||||
}
|
||||
|
||||
val builder = AsmBuilder(type.java, algebra, buildName(this))
|
||||
builder.visit(this)
|
||||
return builder.generate()
|
||||
return AsmBuilder(type.java, algebra, buildName(this)) { visit(this@compileWith) }.getInstance()
|
||||
}
|
||||
|
||||
inline fun <reified T : Any> Algebra<T>.compile(mst: MST): Expression<T> = mst.compileWith(T::class, this)
|
||||
|
@ -4,9 +4,8 @@ import org.objectweb.asm.ClassWriter
|
||||
import org.objectweb.asm.Label
|
||||
import org.objectweb.asm.MethodVisitor
|
||||
import org.objectweb.asm.Opcodes
|
||||
import scientifik.kmath.asm.AsmExpression
|
||||
import scientifik.kmath.asm.FunctionalCompiledExpression
|
||||
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||
import scientifik.kmath.ast.MST
|
||||
import scientifik.kmath.operations.Algebra
|
||||
|
||||
/**
|
||||
@ -22,7 +21,7 @@ internal class AsmBuilder<T> @PublishedApi internal constructor(
|
||||
private val classOfT: Class<*>,
|
||||
private val algebra: Algebra<T>,
|
||||
private val className: String,
|
||||
private val root: AsmExpression<T>
|
||||
private val evaluateMethodVisitor: AsmBuilder<T>.() -> Unit
|
||||
) {
|
||||
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
||||
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
||||
@ -42,10 +41,10 @@ internal class AsmBuilder<T> @PublishedApi internal constructor(
|
||||
private val invokeArgumentsVar: Int = 1
|
||||
private val constants: MutableList<Any> = mutableListOf()
|
||||
private lateinit var invokeMethodVisitor: MethodVisitor
|
||||
private var generatedInstance: FunctionalCompiledExpression<T>? = null
|
||||
private var generatedInstance: AsmCompiledExpression<T>? = null
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun getInstance(): FunctionalCompiledExpression<T> {
|
||||
fun getInstance(): AsmCompiledExpression<T> {
|
||||
generatedInstance?.let { return it }
|
||||
|
||||
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
||||
@ -112,7 +111,7 @@ internal class AsmBuilder<T> @PublishedApi internal constructor(
|
||||
visitCode()
|
||||
val l0 = Label()
|
||||
visitLabel(l0)
|
||||
root.compile(this@AsmBuilder)
|
||||
evaluateMethodVisitor()
|
||||
visitReturnObject()
|
||||
val l1 = Label()
|
||||
visitLabel(l1)
|
||||
@ -178,7 +177,7 @@ internal class AsmBuilder<T> @PublishedApi internal constructor(
|
||||
.defineClass(className, classWriter.toByteArray())
|
||||
.constructors
|
||||
.first()
|
||||
.newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression<T>
|
||||
.newInstance(algebra, constants.toTypedArray()) as AsmCompiledExpression<T>
|
||||
|
||||
generatedInstance = new
|
||||
return new
|
||||
@ -296,13 +295,11 @@ internal class AsmBuilder<T> @PublishedApi internal constructor(
|
||||
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
||||
|
||||
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =
|
||||
"scientifik/kmath/asm/FunctionalCompiledExpression"
|
||||
"scientifik/kmath/asm/internal/AsmCompiledExpression"
|
||||
|
||||
internal const val MAP_CLASS = "java/util/Map"
|
||||
internal const val OBJECT_CLASS = "java/lang/Object"
|
||||
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||
internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
||||
internal const val STRING_CLASS = "java/lang/String"
|
||||
internal const val NUMBER_CLASS = "java/lang/Number"
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,12 @@
|
||||
package scientifik.kmath.asm.internal
|
||||
|
||||
import scientifik.kmath.expressions.Expression
|
||||
import scientifik.kmath.operations.Algebra
|
||||
|
||||
internal abstract class AsmCompiledExpression<T> internal constructor(
|
||||
@JvmField protected val algebra: Algebra<T>,
|
||||
@JvmField protected val constants: Array<Any>
|
||||
) : Expression<T> {
|
||||
abstract override fun invoke(arguments: Map<String, T>): T
|
||||
}
|
||||
|
@ -0,0 +1,15 @@
|
||||
package scientifik.kmath.asm.internal
|
||||
|
||||
import scientifik.kmath.ast.MST
|
||||
|
||||
internal fun buildName(mst: MST, collision: Int = 0): String {
|
||||
val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
|
||||
|
||||
try {
|
||||
Class.forName(name)
|
||||
} catch (ignored: ClassNotFoundException) {
|
||||
return name
|
||||
}
|
||||
|
||||
return buildName(mst, collision + 1)
|
||||
}
|
@ -29,7 +29,7 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
|
||||
append("L${AsmBuilder.OBJECT_CLASS};")
|
||||
}
|
||||
|
||||
visitAlgebraOperation(
|
||||
invokeAlgebraOperation(
|
||||
owner = owner,
|
||||
method = aName,
|
||||
descriptor = sig,
|
||||
@ -39,8 +39,3 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
|
||||
|
||||
return true
|
||||
}
|
||||
//
|
||||
//internal fun <T : Any> AsmExpression<T>.optimize(): AsmExpression<T> {
|
||||
// val a = tryEvaluate()
|
||||
// return if (a == null) this else AsmConstantExpression(type, algebra, a)
|
||||
//}
|
||||
|
Loading…
Reference in New Issue
Block a user