Implement kmath-asm module stubs

This commit is contained in:
Iaroslav 2020-06-05 22:05:16 +07:00
parent 1a869ace0e
commit 3ea76d56a5
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
9 changed files with 443 additions and 8 deletions

View File

@ -0,0 +1,9 @@
plugins {
id("scientifik.jvm")
}
dependencies {
api(project(path = ":kmath-core"))
api("org.ow2.asm:asm:8.0.1")
api("org.ow2.asm:asm-commons:8.0.1")
}

View File

@ -0,0 +1,385 @@
package scientifik.kmath.expressions
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.Label
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.*
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Space
import java.io.File
abstract class AsmCompiled<T>(@JvmField val algebra: Algebra<T>, @JvmField val constants: MutableList<T>) {
abstract fun evaluate(arguments: Map<String, T>): T
}
class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T>, private val className: String) {
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
}
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
@Suppress("PrivatePropertyName")
private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/')
@Suppress("PrivatePropertyName")
private val T_CLASS: String = classOfT.name.replace('.', '/')
private val constants: MutableList<T> = mutableListOf()
private val asmCompiledClassWriter = ClassWriter(0)
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/')
private val evaluateMethodVisitor: MethodVisitor
private val evaluateThisVar: Int = 0
private val evaluateArgumentsVar: Int = 1
private var evaluateL0: Label
private lateinit var evaluateL1: Label
var maxStack: Int = 0
init {
asmCompiledClassWriter.visit(
V1_8,
ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
slashesClassName,
"L$ASM_COMPILED_CLASS<L$T_CLASS;>;",
ASM_COMPILED_CLASS,
arrayOf()
)
asmCompiledClassWriter.run {
visitMethod(ACC_PUBLIC, "<init>", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run {
val thisVar = 0
val algebraVar = 1
val constantsVar = 2
val l0 = Label()
visitLabel(l0)
visitVarInsn(ALOAD, thisVar)
visitVarInsn(ALOAD, algebraVar)
visitVarInsn(ALOAD, constantsVar)
visitMethodInsn(
INVOKESPECIAL,
ASM_COMPILED_CLASS,
"<init>",
"(L$ALGEBRA_CLASS;L$LIST_CLASS;)V",
false
)
val l1 = Label()
visitLabel(l1)
visitInsn(RETURN)
val l2 = Label()
visitLabel(l2)
visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar)
visitLocalVariable(
"algebra",
"L$ALGEBRA_CLASS;",
"L$ALGEBRA_CLASS<L$T_CLASS;>;",
l0,
l2,
algebraVar
)
visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS<L$T_CLASS;>;", l0, l2, constantsVar)
visitMaxs(3, 3)
visitEnd()
}
evaluateMethodVisitor = visitMethod(
ACC_PUBLIC or ACC_FINAL,
"evaluate",
"(L$MAP_CLASS;)L$T_CLASS;",
"(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;",
null
)
evaluateMethodVisitor.run {
visitCode()
evaluateL0 = Label()
visitLabel(evaluateL0)
}
}
}
@Suppress("UNCHECKED_CAST")
fun generate(): AsmCompiled<T> {
evaluateMethodVisitor.run {
visitInsn(ARETURN)
evaluateL1 = Label()
visitLabel(evaluateL1)
visitLocalVariable(
"this",
"L$slashesClassName;",
T_CLASS,
evaluateL0,
evaluateL1,
evaluateThisVar
)
visitLocalVariable(
"arguments",
"L$MAP_CLASS;",
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;",
evaluateL0,
evaluateL1,
evaluateArgumentsVar
)
visitMaxs(maxStack + 1, 2)
visitEnd()
}
asmCompiledClassWriter.visitMethod(
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
"evaluate",
"(L$MAP_CLASS;)L$OBJECT_CLASS;",
null,
null
).run {
val thisVar = 0
visitCode()
val l0 = Label()
visitLabel(l0)
visitVarInsn(ALOAD, 0)
visitVarInsn(ALOAD, 1)
visitMethodInsn(INVOKEVIRTUAL, slashesClassName, "evaluate", "(L$MAP_CLASS;)L$T_CLASS;", false)
visitInsn(ARETURN)
val l1 = Label()
visitLabel(l1)
visitLocalVariable(
"this",
"L$slashesClassName;",
T_CLASS,
l0,
l1,
thisVar
)
visitMaxs(2, 2)
visitEnd()
}
asmCompiledClassWriter.visitEnd()
return classLoader
.defineClass(className, asmCompiledClassWriter.toByteArray())
.constructors
.first()
.newInstance(algebra, constants) as AsmCompiled<T>
}
fun visitLoadFromConstants(value: T) {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
maxStack++
evaluateMethodVisitor.run {
visitLoadThis()
visitFieldInsn(GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;")
visitLdcOrIConstInsn(idx)
visitMethodInsn(INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true)
visitCastToT()
}
}
private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar)
fun visitNumberConstant(value: Number): Unit = evaluateMethodVisitor.visitLdcInsn(value)
fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run {
maxStack++
visitVarInsn(ALOAD, evaluateArgumentsVar)
if (defaultValue != null) {
visitLdcInsn(name)
visitLoadFromConstants(defaultValue)
visitMethodInsn(
INVOKEINTERFACE,
MAP_CLASS,
"getOrDefault",
"(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;",
true
)
visitCastToT()
return
}
visitLdcInsn(name)
visitMethodInsn(INVOKEINTERFACE, MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true)
visitCastToT()
}
fun visitLoadAlgebra() {
evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar)
evaluateMethodVisitor.visitFieldInsn(
GETFIELD,
ASM_COMPILED_CLASS, "algebra", "L$ALGEBRA_CLASS;"
)
evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_ALGEBRA_CLASS)
}
fun visitInvokeAlgebraOperation(owner: String, method: String, descriptor: String) {
maxStack++
evaluateMethodVisitor.visitMethodInsn(INVOKEINTERFACE, owner, method, descriptor, true)
visitCastToT()
}
fun visitCastToT() {
evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
}
companion object {
const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled"
const val LIST_CLASS = "java/util/List"
const val MAP_CLASS = "java/util/Map"
const val OBJECT_CLASS = "java/lang/Object"
const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
const val SPACE_CLASS = "scientifik/kmath/operations/Space"
const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
const val FIELD_CLASS = "scientifik/kmath/operations/Field"
const val STRING_CLASS = "java/lang/String"
}
}
interface AsmExpression<T> {
fun invoke(gen: AsmGenerationContext<T>)
}
internal class AsmVariableExpression<T>(val name: String, val default: T? = null) :
AsmExpression<T> {
override fun invoke(gen: AsmGenerationContext<T>) {
gen.visitLoadFromVariables(name, default)
}
}
internal class AsmConstantExpression<T>(val value: T) :
AsmExpression<T> {
override fun invoke(gen: AsmGenerationContext<T>) {
gen.visitLoadFromConstants(value)
}
}
internal class AsmSumExpression<T>(
val first: AsmExpression<T>,
val second: AsmExpression<T>
) : AsmExpression<T> {
override fun invoke(gen: AsmGenerationContext<T>) {
gen.visitLoadAlgebra()
first.invoke(gen)
second.invoke(gen)
gen.visitInvokeAlgebraOperation(
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
method = "add",
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
)
}
}
internal class AsmProductExpression<T>(
val first: AsmExpression<T>,
val second: AsmExpression<T>
) : AsmExpression<T> {
override fun invoke(gen: AsmGenerationContext<T>) {
gen.visitLoadAlgebra()
first.invoke(gen)
second.invoke(gen)
gen.visitInvokeAlgebraOperation(
owner = AsmGenerationContext.SPACE_CLASS,
method = "times",
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
)
}
}
internal class AsmConstProductExpression<T>(
val expr: AsmExpression<T>,
val const: Number
) : AsmExpression<T> {
override fun invoke(gen: AsmGenerationContext<T>) {
gen.visitLoadAlgebra()
expr.invoke(gen)
gen.visitNumberConstant(const)
gen.visitInvokeAlgebraOperation(
owner = AsmGenerationContext.SPACE_CLASS,
method = "multiply",
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
)
}
}
internal class AsmDivExpression<T>(
val expr: AsmExpression<T>,
val second: AsmExpression<T>
) : AsmExpression<T> {
override fun invoke(gen: AsmGenerationContext<T>) {
gen.visitLoadAlgebra()
expr.invoke(gen)
second.invoke(gen)
gen.visitInvokeAlgebraOperation(
owner = AsmGenerationContext.FIELD_CLASS,
method = "divide",
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
)
}
}
open class AsmFunctionalExpressionSpace<T>(
val space: Space<T>,
one: T
) : Space<AsmExpression<T>>,
ExpressionSpace<T, AsmExpression<T>> {
override val zero: AsmExpression<T> =
AsmConstantExpression(space.zero)
override fun const(value: T): AsmExpression<T> =
AsmConstantExpression(value)
override fun variable(name: String, default: T?): AsmExpression<T> =
AsmVariableExpression(
name,
default
)
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmSumExpression(a, b)
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> =
AsmConstProductExpression(a, k)
operator fun AsmExpression<T>.plus(arg: T) = this + const(arg)
operator fun AsmExpression<T>.minus(arg: T) = this - const(arg)
operator fun T.plus(arg: AsmExpression<T>) = arg + this
operator fun T.minus(arg: AsmExpression<T>) = arg - this
}
class AsmFunctionalExpressionField<T>(val field: Field<T>) : ExpressionField<T, AsmExpression<T>>,
AsmFunctionalExpressionSpace<T>(field, field.one) {
override val one: AsmExpression<T>
get() = const(this.field.one)
override fun const(value: Double): AsmExpression<T> = const(field.run { one * value })
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmProductExpression(a, b)
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmDivExpression(a, b)
operator fun AsmExpression<T>.times(arg: T) = this * const(arg)
operator fun AsmExpression<T>.div(arg: T) = this / const(arg)
operator fun T.times(arg: AsmExpression<T>) = arg * this
operator fun T.div(arg: AsmExpression<T>) = arg / this
}

View File

@ -0,0 +1,17 @@
package scientifik.kmath.expressions
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes
fun MethodVisitor.visitLdcOrIConstInsn(value: Int) {
when (value) {
-1 -> visitInsn(Opcodes.ICONST_M1)
0 -> visitInsn(Opcodes.ICONST_0)
1 -> visitInsn(Opcodes.ICONST_1)
2 -> visitInsn(Opcodes.ICONST_2)
3 -> visitInsn(Opcodes.ICONST_3)
4 -> visitInsn(Opcodes.ICONST_4)
5 -> visitInsn(Opcodes.ICONST_5)
else -> visitLdcInsn(value)
}
}

View File

@ -0,0 +1,23 @@
package scientifik.kmath.expressions
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
class AsmTest {
@Test
fun test() {
val expr = AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x"))
val gen = AsmGenerationContext(
java.lang.Double::class.java,
RealField,
"MyAsmCompiled"
)
expr.invoke(gen)
val compiled = gen.generate()
val value = compiled.evaluate(mapOf("x" to 25.0))
assertEquals(26.0, value)
}
}

View File

@ -8,4 +8,4 @@ kotlin.sourceSets {
api(project(":kmath-memory"))
}
}
}
}

View File

@ -15,7 +15,7 @@ operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = invoke
/**
* A context for expression construction
*/
interface ExpressionContext<T, E : Expression<T>> {
interface ExpressionContext<T, E> {
/**
* Introduce a variable into expression context
*/
@ -29,11 +29,11 @@ interface ExpressionContext<T, E : Expression<T>> {
fun produce(node: SyntaxTreeNode): E
}
interface ExpressionSpace<T, E : Expression<T>> : Space<E>, ExpressionContext<T, E> {
interface ExpressionSpace<T, E> : Space<E>, ExpressionContext<T, E> {
open fun produceSingular(value: String): E = variable(value)
fun produceSingular(value: String): E = variable(value)
open fun produceUnary(operation: String, value: E): E {
fun produceUnary(operation: String, value: E): E {
return when (operation) {
UnaryNode.PLUS_OPERATION -> value
UnaryNode.MINUS_OPERATION -> -value
@ -41,7 +41,7 @@ interface ExpressionSpace<T, E : Expression<T>> : Space<E>, ExpressionContext<T,
}
}
open fun produceBinary(operation: String, left: E, right: E): E {
fun produceBinary(operation: String, left: E, right: E): E {
return when (operation) {
BinaryNode.PLUS_OPERATION -> left + right
BinaryNode.MINUS_OPERATION -> left - right
@ -75,7 +75,7 @@ interface ExpressionSpace<T, E : Expression<T>> : Space<E>, ExpressionContext<T,
}
}
interface ExpressionField<T, E : Expression<T>> : Field<E>, ExpressionSpace<T, E> {
interface ExpressionField<T, E> : Field<E>, ExpressionSpace<T, E> {
fun const(value: Double): E = one.times(value)
override fun produce(node: SyntaxTreeNode): E {

View File

@ -44,5 +44,6 @@ include(
":kmath-dimensions",
":kmath-for-real",
":kmath-geometry",
":examples"
":examples",
":kmath-asm"
)