Dev #127
9
kmath-asm/build.gradle.kts
Normal file
9
kmath-asm/build.gradle.kts
Normal file
@ -0,0 +1,9 @@
|
||||
plugins {
|
||||
id("scientifik.jvm")
|
||||
}
|
||||
|
||||
dependencies {
|
||||
api(project(path = ":kmath-core"))
|
||||
api("org.ow2.asm:asm:8.0.1")
|
||||
api("org.ow2.asm:asm-commons:8.0.1")
|
||||
}
|
@ -0,0 +1,385 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import org.objectweb.asm.ClassWriter
|
||||
import org.objectweb.asm.Label
|
||||
import org.objectweb.asm.MethodVisitor
|
||||
import org.objectweb.asm.Opcodes.*
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Space
|
||||
import java.io.File
|
||||
|
||||
abstract class AsmCompiled<T>(@JvmField val algebra: Algebra<T>, @JvmField val constants: MutableList<T>) {
|
||||
abstract fun evaluate(arguments: Map<String, T>): T
|
||||
}
|
||||
|
||||
class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T>, private val className: String) {
|
||||
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
||||
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
||||
}
|
||||
|
||||
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
||||
|
||||
@Suppress("PrivatePropertyName")
|
||||
private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/')
|
||||
|
||||
@Suppress("PrivatePropertyName")
|
||||
private val T_CLASS: String = classOfT.name.replace('.', '/')
|
||||
private val constants: MutableList<T> = mutableListOf()
|
||||
private val asmCompiledClassWriter = ClassWriter(0)
|
||||
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/')
|
||||
private val evaluateMethodVisitor: MethodVisitor
|
||||
private val evaluateThisVar: Int = 0
|
||||
private val evaluateArgumentsVar: Int = 1
|
||||
private var evaluateL0: Label
|
||||
private lateinit var evaluateL1: Label
|
||||
var maxStack: Int = 0
|
||||
|
||||
init {
|
||||
asmCompiledClassWriter.visit(
|
||||
V1_8,
|
||||
ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
|
||||
slashesClassName,
|
||||
"L$ASM_COMPILED_CLASS<L$T_CLASS;>;",
|
||||
ASM_COMPILED_CLASS,
|
||||
arrayOf()
|
||||
)
|
||||
|
||||
asmCompiledClassWriter.run {
|
||||
visitMethod(ACC_PUBLIC, "<init>", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run {
|
||||
val thisVar = 0
|
||||
val algebraVar = 1
|
||||
val constantsVar = 2
|
||||
val l0 = Label()
|
||||
visitLabel(l0)
|
||||
visitVarInsn(ALOAD, thisVar)
|
||||
visitVarInsn(ALOAD, algebraVar)
|
||||
visitVarInsn(ALOAD, constantsVar)
|
||||
|
||||
visitMethodInsn(
|
||||
INVOKESPECIAL,
|
||||
ASM_COMPILED_CLASS,
|
||||
"<init>",
|
||||
"(L$ALGEBRA_CLASS;L$LIST_CLASS;)V",
|
||||
false
|
||||
)
|
||||
|
||||
val l1 = Label()
|
||||
visitLabel(l1)
|
||||
visitInsn(RETURN)
|
||||
val l2 = Label()
|
||||
visitLabel(l2)
|
||||
visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar)
|
||||
|
||||
visitLocalVariable(
|
||||
"algebra",
|
||||
"L$ALGEBRA_CLASS;",
|
||||
"L$ALGEBRA_CLASS<L$T_CLASS;>;",
|
||||
l0,
|
||||
l2,
|
||||
algebraVar
|
||||
)
|
||||
|
||||
visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS<L$T_CLASS;>;", l0, l2, constantsVar)
|
||||
visitMaxs(3, 3)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
evaluateMethodVisitor = visitMethod(
|
||||
ACC_PUBLIC or ACC_FINAL,
|
||||
"evaluate",
|
||||
"(L$MAP_CLASS;)L$T_CLASS;",
|
||||
"(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;",
|
||||
null
|
||||
)
|
||||
|
||||
evaluateMethodVisitor.run {
|
||||
visitCode()
|
||||
evaluateL0 = Label()
|
||||
visitLabel(evaluateL0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun generate(): AsmCompiled<T> {
|
||||
evaluateMethodVisitor.run {
|
||||
visitInsn(ARETURN)
|
||||
evaluateL1 = Label()
|
||||
visitLabel(evaluateL1)
|
||||
|
||||
visitLocalVariable(
|
||||
"this",
|
||||
"L$slashesClassName;",
|
||||
T_CLASS,
|
||||
evaluateL0,
|
||||
evaluateL1,
|
||||
evaluateThisVar
|
||||
)
|
||||
|
||||
visitLocalVariable(
|
||||
"arguments",
|
||||
"L$MAP_CLASS;",
|
||||
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;",
|
||||
evaluateL0,
|
||||
evaluateL1,
|
||||
evaluateArgumentsVar
|
||||
)
|
||||
|
||||
visitMaxs(maxStack + 1, 2)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
asmCompiledClassWriter.visitMethod(
|
||||
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
|
||||
"evaluate",
|
||||
"(L$MAP_CLASS;)L$OBJECT_CLASS;",
|
||||
null,
|
||||
null
|
||||
).run {
|
||||
val thisVar = 0
|
||||
visitCode()
|
||||
val l0 = Label()
|
||||
visitLabel(l0)
|
||||
visitVarInsn(ALOAD, 0)
|
||||
visitVarInsn(ALOAD, 1)
|
||||
visitMethodInsn(INVOKEVIRTUAL, slashesClassName, "evaluate", "(L$MAP_CLASS;)L$T_CLASS;", false)
|
||||
visitInsn(ARETURN)
|
||||
val l1 = Label()
|
||||
visitLabel(l1)
|
||||
|
||||
visitLocalVariable(
|
||||
"this",
|
||||
"L$slashesClassName;",
|
||||
T_CLASS,
|
||||
l0,
|
||||
l1,
|
||||
thisVar
|
||||
)
|
||||
|
||||
visitMaxs(2, 2)
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
asmCompiledClassWriter.visitEnd()
|
||||
|
||||
return classLoader
|
||||
.defineClass(className, asmCompiledClassWriter.toByteArray())
|
||||
.constructors
|
||||
.first()
|
||||
.newInstance(algebra, constants) as AsmCompiled<T>
|
||||
}
|
||||
|
||||
fun visitLoadFromConstants(value: T) {
|
||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||
maxStack++
|
||||
|
||||
evaluateMethodVisitor.run {
|
||||
visitLoadThis()
|
||||
visitFieldInsn(GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;")
|
||||
visitLdcOrIConstInsn(idx)
|
||||
visitMethodInsn(INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true)
|
||||
visitCastToT()
|
||||
}
|
||||
}
|
||||
|
||||
private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar)
|
||||
|
||||
fun visitNumberConstant(value: Number): Unit = evaluateMethodVisitor.visitLdcInsn(value)
|
||||
|
||||
fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run {
|
||||
maxStack++
|
||||
visitVarInsn(ALOAD, evaluateArgumentsVar)
|
||||
|
||||
if (defaultValue != null) {
|
||||
visitLdcInsn(name)
|
||||
visitLoadFromConstants(defaultValue)
|
||||
|
||||
visitMethodInsn(
|
||||
INVOKEINTERFACE,
|
||||
MAP_CLASS,
|
||||
"getOrDefault",
|
||||
"(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;",
|
||||
true
|
||||
)
|
||||
|
||||
visitCastToT()
|
||||
return
|
||||
}
|
||||
|
||||
visitLdcInsn(name)
|
||||
visitMethodInsn(INVOKEINTERFACE, MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true)
|
||||
visitCastToT()
|
||||
}
|
||||
|
||||
fun visitLoadAlgebra() {
|
||||
evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar)
|
||||
|
||||
evaluateMethodVisitor.visitFieldInsn(
|
||||
GETFIELD,
|
||||
ASM_COMPILED_CLASS, "algebra", "L$ALGEBRA_CLASS;"
|
||||
)
|
||||
|
||||
evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_ALGEBRA_CLASS)
|
||||
}
|
||||
|
||||
fun visitInvokeAlgebraOperation(owner: String, method: String, descriptor: String) {
|
||||
maxStack++
|
||||
evaluateMethodVisitor.visitMethodInsn(INVOKEINTERFACE, owner, method, descriptor, true)
|
||||
visitCastToT()
|
||||
}
|
||||
|
||||
fun visitCastToT() {
|
||||
evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled"
|
||||
const val LIST_CLASS = "java/util/List"
|
||||
const val MAP_CLASS = "java/util/Map"
|
||||
const val OBJECT_CLASS = "java/lang/Object"
|
||||
const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||
const val SPACE_CLASS = "scientifik/kmath/operations/Space"
|
||||
const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
||||
const val FIELD_CLASS = "scientifik/kmath/operations/Field"
|
||||
const val STRING_CLASS = "java/lang/String"
|
||||
}
|
||||
}
|
||||
|
||||
interface AsmExpression<T> {
|
||||
fun invoke(gen: AsmGenerationContext<T>)
|
||||
}
|
||||
|
||||
internal class AsmVariableExpression<T>(val name: String, val default: T? = null) :
|
||||
AsmExpression<T> {
|
||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||
gen.visitLoadFromVariables(name, default)
|
||||
}
|
||||
}
|
||||
|
||||
internal class AsmConstantExpression<T>(val value: T) :
|
||||
AsmExpression<T> {
|
||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||
gen.visitLoadFromConstants(value)
|
||||
}
|
||||
}
|
||||
|
||||
internal class AsmSumExpression<T>(
|
||||
val first: AsmExpression<T>,
|
||||
val second: AsmExpression<T>
|
||||
) : AsmExpression<T> {
|
||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||
gen.visitLoadAlgebra()
|
||||
first.invoke(gen)
|
||||
second.invoke(gen)
|
||||
|
||||
gen.visitInvokeAlgebraOperation(
|
||||
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
|
||||
method = "add",
|
||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
internal class AsmProductExpression<T>(
|
||||
val first: AsmExpression<T>,
|
||||
val second: AsmExpression<T>
|
||||
) : AsmExpression<T> {
|
||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||
gen.visitLoadAlgebra()
|
||||
first.invoke(gen)
|
||||
second.invoke(gen)
|
||||
|
||||
gen.visitInvokeAlgebraOperation(
|
||||
owner = AsmGenerationContext.SPACE_CLASS,
|
||||
method = "times",
|
||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
internal class AsmConstProductExpression<T>(
|
||||
val expr: AsmExpression<T>,
|
||||
val const: Number
|
||||
) : AsmExpression<T> {
|
||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||
gen.visitLoadAlgebra()
|
||||
expr.invoke(gen)
|
||||
gen.visitNumberConstant(const)
|
||||
|
||||
gen.visitInvokeAlgebraOperation(
|
||||
owner = AsmGenerationContext.SPACE_CLASS,
|
||||
method = "multiply",
|
||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
internal class AsmDivExpression<T>(
|
||||
val expr: AsmExpression<T>,
|
||||
val second: AsmExpression<T>
|
||||
) : AsmExpression<T> {
|
||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||
gen.visitLoadAlgebra()
|
||||
expr.invoke(gen)
|
||||
second.invoke(gen)
|
||||
|
||||
gen.visitInvokeAlgebraOperation(
|
||||
owner = AsmGenerationContext.FIELD_CLASS,
|
||||
method = "divide",
|
||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
open class AsmFunctionalExpressionSpace<T>(
|
||||
val space: Space<T>,
|
||||
one: T
|
||||
) : Space<AsmExpression<T>>,
|
||||
ExpressionSpace<T, AsmExpression<T>> {
|
||||
override val zero: AsmExpression<T> =
|
||||
AsmConstantExpression(space.zero)
|
||||
|
||||
override fun const(value: T): AsmExpression<T> =
|
||||
AsmConstantExpression(value)
|
||||
|
||||
override fun variable(name: String, default: T?): AsmExpression<T> =
|
||||
AsmVariableExpression(
|
||||
name,
|
||||
default
|
||||
)
|
||||
|
||||
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||
AsmSumExpression(a, b)
|
||||
|
||||
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> =
|
||||
AsmConstProductExpression(a, k)
|
||||
|
||||
|
||||
operator fun AsmExpression<T>.plus(arg: T) = this + const(arg)
|
||||
operator fun AsmExpression<T>.minus(arg: T) = this - const(arg)
|
||||
|
||||
operator fun T.plus(arg: AsmExpression<T>) = arg + this
|
||||
operator fun T.minus(arg: AsmExpression<T>) = arg - this
|
||||
}
|
||||
|
||||
class AsmFunctionalExpressionField<T>(val field: Field<T>) : ExpressionField<T, AsmExpression<T>>,
|
||||
AsmFunctionalExpressionSpace<T>(field, field.one) {
|
||||
override val one: AsmExpression<T>
|
||||
get() = const(this.field.one)
|
||||
|
||||
override fun const(value: Double): AsmExpression<T> = const(field.run { one * value })
|
||||
|
||||
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||
AsmProductExpression(a, b)
|
||||
|
||||
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||
AsmDivExpression(a, b)
|
||||
|
||||
operator fun AsmExpression<T>.times(arg: T) = this * const(arg)
|
||||
operator fun AsmExpression<T>.div(arg: T) = this / const(arg)
|
||||
|
||||
operator fun T.times(arg: AsmExpression<T>) = arg * this
|
||||
operator fun T.div(arg: AsmExpression<T>) = arg / this
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import org.objectweb.asm.MethodVisitor
|
||||
import org.objectweb.asm.Opcodes
|
||||
|
||||
fun MethodVisitor.visitLdcOrIConstInsn(value: Int) {
|
||||
when (value) {
|
||||
-1 -> visitInsn(Opcodes.ICONST_M1)
|
||||
0 -> visitInsn(Opcodes.ICONST_0)
|
||||
1 -> visitInsn(Opcodes.ICONST_1)
|
||||
2 -> visitInsn(Opcodes.ICONST_2)
|
||||
3 -> visitInsn(Opcodes.ICONST_3)
|
||||
4 -> visitInsn(Opcodes.ICONST_4)
|
||||
5 -> visitInsn(Opcodes.ICONST_5)
|
||||
else -> visitLdcInsn(value)
|
||||
}
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
package scientifik.kmath.expressions
|
||||
|
||||
import scientifik.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class AsmTest {
|
||||
@Test
|
||||
fun test() {
|
||||
val expr = AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x"))
|
||||
|
||||
val gen = AsmGenerationContext(
|
||||
java.lang.Double::class.java,
|
||||
RealField,
|
||||
"MyAsmCompiled"
|
||||
)
|
||||
|
||||
expr.invoke(gen)
|
||||
val compiled = gen.generate()
|
||||
val value = compiled.evaluate(mapOf("x" to 25.0))
|
||||
assertEquals(26.0, value)
|
||||
}
|
||||
}
|
@ -8,4 +8,4 @@ kotlin.sourceSets {
|
||||
api(project(":kmath-memory"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke
|
||||
/**
|
||||
* A context for expression construction
|
||||
*/
|
||||
interface ExpressionContext<T, E : Expression<T>> {
|
||||
interface ExpressionContext<T, E> {
|
||||
/**
|
||||
* Introduce a variable into expression context
|
||||
*/
|
||||
@ -29,11 +29,11 @@ interface ExpressionContext<T, E : Expression<T>> {
|
||||
fun produce(node: SyntaxTreeNode): E
|
||||
}
|
||||
|
||||
interface ExpressionSpace<T, E : Expression<T>> : Space<E>, ExpressionContext<T, E> {
|
||||
interface ExpressionSpace<T, E> : Space<E>, ExpressionContext<T, E> {
|
||||
|
||||
open fun produceSingular(value: String): E = variable(value)
|
||||
fun produceSingular(value: String): E = variable(value)
|
||||
|
||||
open fun produceUnary(operation: String, value: E): E {
|
||||
fun produceUnary(operation: String, value: E): E {
|
||||
return when (operation) {
|
||||
UnaryNode.PLUS_OPERATION -> value
|
||||
UnaryNode.MINUS_OPERATION -> -value
|
||||
@ -41,7 +41,7 @@ interface ExpressionSpace<T, E : Expression<T>> : Space<E>, ExpressionContext<T,
|
||||
}
|
||||
}
|
||||
|
||||
open fun produceBinary(operation: String, left: E, right: E): E {
|
||||
fun produceBinary(operation: String, left: E, right: E): E {
|
||||
return when (operation) {
|
||||
BinaryNode.PLUS_OPERATION -> left + right
|
||||
BinaryNode.MINUS_OPERATION -> left - right
|
||||
@ -75,7 +75,7 @@ interface ExpressionSpace<T, E : Expression<T>> : Space<E>, ExpressionContext<T,
|
||||
}
|
||||
}
|
||||
|
||||
interface ExpressionField<T, E : Expression<T>> : Field<E>, ExpressionSpace<T, E> {
|
||||
interface ExpressionField<T, E> : Field<E>, ExpressionSpace<T, E> {
|
||||
fun const(value: Double): E = one.times(value)
|
||||
|
||||
override fun produce(node: SyntaxTreeNode): E {
|
||||
|
@ -44,5 +44,6 @@ include(
|
||||
":kmath-dimensions",
|
||||
":kmath-for-real",
|
||||
":kmath-geometry",
|
||||
":examples"
|
||||
":examples",
|
||||
":kmath-asm"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user