Merge pull request #94 from CommanderTvis/adv-expr

Optimization of expressions via ASM Bytecode Generation
This commit is contained in:
Alexander Nozik 2020-06-13 21:07:14 +03:00 committed by GitHub
commit 878d1379e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1010 additions and 129 deletions

View File

@ -2,7 +2,7 @@ plugins {
id("scientifik.mpp") id("scientifik.mpp")
} }
repositories{ repositories {
maven("https://dl.bintray.com/hotkeytlt/maven") maven("https://dl.bintray.com/hotkeytlt/maven")
} }
@ -14,13 +14,18 @@ kotlin.sourceSets {
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3") implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3")
} }
} }
jvmMain{
dependencies{ jvmMain {
dependencies {
implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3") implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3")
implementation("org.ow2.asm:asm:8.0.1")
implementation("org.ow2.asm:asm-commons:8.0.1")
implementation(kotlin("reflect"))
} }
} }
jsMain{
dependencies{ jsMain {
dependencies {
implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3") implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3")
} }
} }

View File

@ -0,0 +1,56 @@
package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmGenerationContext
import scientifik.kmath.ast.MST
import scientifik.kmath.ast.evaluate
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.*
@PublishedApi
internal fun buildName(expression: AsmNode<*>, collision: Int = 0): String {
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision"
try {
Class.forName(name)
} catch (ignored: ClassNotFoundException) {
return name
}
return buildName(expression, collision + 1)
}
@PublishedApi
internal inline fun <reified T> AsmNode<T>.compile(algebra: Algebra<T>): Expression<T> {
val ctx =
AsmGenerationContext(T::class.java, algebra, buildName(this))
compile(ctx)
return ctx.generate()
}
inline fun <reified T, A : NumericAlgebra<T>, E : AsmExpressionAlgebra<T, A>> A.asm(
expressionAlgebra: E,
block: E.() -> AsmNode<T>
): Expression<T> = expressionAlgebra.block().compile(expressionAlgebra.algebra)
inline fun <reified T, A : NumericAlgebra<T>, E : AsmExpressionAlgebra<T, A>> A.asm(
expressionAlgebra: E,
ast: MST
): Expression<T> = asm(expressionAlgebra) { evaluate(ast) }
inline fun <reified T, A> A.asmSpace(block: AsmExpressionSpace<T, A>.() -> AsmNode<T>): Expression<T> where A : NumericAlgebra<T>, A : Space<T> =
AsmExpressionSpace(this).let { it.block().compile(it.algebra) }
inline fun <reified T, A> A.asmSpace(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Space<T> =
asmSpace { evaluate(ast) }
inline fun <reified T, A> A.asmRing(block: AsmExpressionRing<T, A>.() -> AsmNode<T>): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> =
AsmExpressionRing(this).let { it.block().compile(it.algebra) }
inline fun <reified T, A> A.asmRing(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> =
asmRing { evaluate(ast) }
inline fun <reified T, A> A.asmField(block: AsmExpressionField<T, A>.() -> AsmNode<T>): Expression<T> where A : NumericAlgebra<T>, A : Field<T> =
AsmExpressionField(this).let { it.block().compile(it.algebra) }
inline fun <reified T, A> A.asmField(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Field<T> =
asmRing { evaluate(ast) }

View File

@ -0,0 +1,265 @@
package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmGenerationContext
import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.optimize
import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.ExpressionAlgebra
import scientifik.kmath.operations.*
/**
* A function declaration that could be compiled to [AsmGenerationContext].
*
* @param T the type the stored function returns.
*/
abstract class AsmNode<T> internal constructor() {
/**
* Tries to evaluate this function without its variables. This method is intended for optimization.
*
* @return `null` if the function depends on its variables, the value if the function is a constant.
*/
internal open fun tryEvaluate(): T? = null
/**
* Compiles this declaration.
*
* @param gen the target [AsmGenerationContext].
*/
@PublishedApi
internal abstract fun compile(gen: AsmGenerationContext<T>)
}
internal class AsmUnaryOperation<T>(private val context: Algebra<T>, private val name: String, expr: AsmNode<T>) :
AsmNode<T>() {
private val expr: AsmNode<T> = expr.optimize()
override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) }
override fun compile(gen: AsmGenerationContext<T>) {
gen.visitLoadAlgebra()
if (!hasSpecific(context, name, 1))
gen.visitStringConstant(name)
expr.compile(gen)
if (gen.tryInvokeSpecific(context, name, 1))
return
gen.visitAlgebraOperation(
owner = AsmGenerationContext.ALGEBRA_CLASS,
method = "unaryOperation",
descriptor = "(L${AsmGenerationContext.STRING_CLASS};" +
"L${AsmGenerationContext.OBJECT_CLASS};)" +
"L${AsmGenerationContext.OBJECT_CLASS};"
)
}
}
internal class AsmBinaryOperation<T>(
private val context: Algebra<T>,
private val name: String,
first: AsmNode<T>,
second: AsmNode<T>
) : AsmNode<T>() {
private val first: AsmNode<T> = first.optimize()
private val second: AsmNode<T> = second.optimize()
override fun tryEvaluate(): T? = context {
binaryOperation(
name,
first.tryEvaluate() ?: return@context null,
second.tryEvaluate() ?: return@context null
)
}
override fun compile(gen: AsmGenerationContext<T>) {
gen.visitLoadAlgebra()
if (!hasSpecific(context, name, 2))
gen.visitStringConstant(name)
first.compile(gen)
second.compile(gen)
if (gen.tryInvokeSpecific(context, name, 2))
return
gen.visitAlgebraOperation(
owner = AsmGenerationContext.ALGEBRA_CLASS,
method = "binaryOperation",
descriptor = "(L${AsmGenerationContext.STRING_CLASS};" +
"L${AsmGenerationContext.OBJECT_CLASS};" +
"L${AsmGenerationContext.OBJECT_CLASS};)" +
"L${AsmGenerationContext.OBJECT_CLASS};"
)
}
}
internal class AsmVariableExpression<T>(private val name: String, private val default: T? = null) :
AsmNode<T>() {
override fun compile(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromVariables(name, default)
}
internal class AsmConstantExpression<T>(private val value: T) :
AsmNode<T>() {
override fun tryEvaluate(): T = value
override fun compile(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromConstants(value)
}
internal class AsmConstProductExpression<T>(
private val context: Space<T>,
expr: AsmNode<T>,
private val const: Number
) : AsmNode<T>() {
private val expr: AsmNode<T> = expr.optimize()
override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const }
override fun compile(gen: AsmGenerationContext<T>) {
gen.visitLoadAlgebra()
gen.visitNumberConstant(const)
expr.compile(gen)
gen.visitAlgebraOperation(
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
method = "multiply",
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};" +
"L${AsmGenerationContext.NUMBER_CLASS};)" +
"L${AsmGenerationContext.OBJECT_CLASS};"
)
}
}
internal class AsmNumberExpression<T>(private val context: NumericAlgebra<T>, private val value: Number) :
AsmNode<T>() {
override fun tryEvaluate(): T? = context.number(value)
override fun compile(gen: AsmGenerationContext<T>): Unit = gen.visitNumberConstant(value)
}
internal abstract class FunctionalCompiledExpression<T> internal constructor(
@JvmField protected val algebra: Algebra<T>,
@JvmField protected val constants: Array<Any>
) : Expression<T> {
abstract override fun invoke(arguments: Map<String, T>): T
}
/**
* A context class for [AsmNode] construction.
*/
interface AsmExpressionAlgebra<T, A : NumericAlgebra<T>> : NumericAlgebra<AsmNode<T>>,
ExpressionAlgebra<T, AsmNode<T>> {
/**
* The algebra to provide for AsmExpressions built.
*/
val algebra: A
/**
* Builds an AsmExpression to wrap a number.
*/
override fun number(value: Number): AsmNode<T> = AsmNumberExpression(algebra, value)
/**
* Builds an AsmExpression of constant expression which does not depend on arguments.
*/
override fun const(value: T): AsmNode<T> = AsmConstantExpression(value)
/**
* Builds an AsmExpression to access a variable.
*/
override fun variable(name: String, default: T?): AsmNode<T> = AsmVariableExpression(name, default)
/**
* Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right].
*/
override fun binaryOperation(operation: String, left: AsmNode<T>, right: AsmNode<T>): AsmNode<T> =
AsmBinaryOperation(algebra, operation, left, right)
/**
* Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg].
*/
override fun unaryOperation(operation: String, arg: AsmNode<T>): AsmNode<T> =
AsmUnaryOperation(algebra, operation, arg)
}
/**
* A context class for [AsmNode] construction for [Space] algebras.
*/
open class AsmExpressionSpace<T, A>(override val algebra: A) : AsmExpressionAlgebra<T, A>,
Space<AsmNode<T>> where A : Space<T>, A : NumericAlgebra<T> {
override val zero: AsmNode<T>
get() = const(algebra.zero)
/**
* Builds an AsmExpression of addition of two another expressions.
*/
override fun add(a: AsmNode<T>, b: AsmNode<T>): AsmNode<T> =
AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b)
/**
* Builds an AsmExpression of multiplication of expression by number.
*/
override fun multiply(a: AsmNode<T>, k: Number): AsmNode<T> = AsmConstProductExpression(algebra, a, k)
operator fun AsmNode<T>.plus(arg: T): AsmNode<T> = this + const(arg)
operator fun AsmNode<T>.minus(arg: T): AsmNode<T> = this - const(arg)
operator fun T.plus(arg: AsmNode<T>): AsmNode<T> = arg + this
operator fun T.minus(arg: AsmNode<T>): AsmNode<T> = arg - this
override fun unaryOperation(operation: String, arg: AsmNode<T>): AsmNode<T> =
super<AsmExpressionAlgebra>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmNode<T>, right: AsmNode<T>): AsmNode<T> =
super<AsmExpressionAlgebra>.binaryOperation(operation, left, right)
}
/**
* A context class for [AsmNode] construction for [Ring] algebras.
*/
open class AsmExpressionRing<T, A>(override val algebra: A) : AsmExpressionSpace<T, A>(algebra),
Ring<AsmNode<T>> where A : Ring<T>, A : NumericAlgebra<T> {
override val one: AsmNode<T>
get() = const(algebra.one)
/**
* Builds an AsmExpression of multiplication of two expressions.
*/
override fun multiply(a: AsmNode<T>, b: AsmNode<T>): AsmNode<T> =
AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b)
operator fun AsmNode<T>.times(arg: T): AsmNode<T> = this * const(arg)
operator fun T.times(arg: AsmNode<T>): AsmNode<T> = arg * this
override fun unaryOperation(operation: String, arg: AsmNode<T>): AsmNode<T> =
super<AsmExpressionSpace>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmNode<T>, right: AsmNode<T>): AsmNode<T> =
super<AsmExpressionSpace>.binaryOperation(operation, left, right)
override fun number(value: Number): AsmNode<T> = super<AsmExpressionSpace>.number(value)
}
/**
* A context class for [AsmNode] construction for [Field] algebras.
*/
open class AsmExpressionField<T, A>(override val algebra: A) :
AsmExpressionRing<T, A>(algebra),
Field<AsmNode<T>> where A : Field<T>, A : NumericAlgebra<T> {
/**
* Builds an AsmExpression of division an expression by another one.
*/
override fun divide(a: AsmNode<T>, b: AsmNode<T>): AsmNode<T> =
AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b)
operator fun AsmNode<T>.div(arg: T): AsmNode<T> = this / const(arg)
operator fun T.div(arg: AsmNode<T>): AsmNode<T> = arg / this
override fun unaryOperation(operation: String, arg: AsmNode<T>): AsmNode<T> =
super<AsmExpressionRing>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmNode<T>, right: AsmNode<T>): AsmNode<T> =
super<AsmExpressionRing>.binaryOperation(operation, left, right)
override fun number(value: Number): AsmNode<T> = super<AsmExpressionRing>.number(value)
}

View File

@ -0,0 +1,319 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.Label
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.FunctionalCompiledExpression
import scientifik.kmath.asm.internal.AsmGenerationContext.ClassLoader
import scientifik.kmath.operations.Algebra
/**
* AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java
* expression. This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new
* class.
*
* @param T the type of AsmExpression to unwrap.
* @param algebra the algebra the applied AsmExpressions use.
* @param className the unique class name of new loaded class.
*/
@PublishedApi
internal class AsmGenerationContext<T> @PublishedApi internal constructor(
private val classOfT: Class<*>,
private val algebra: Algebra<T>,
private val className: String
) {
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
}
private val classLoader: ClassLoader =
ClassLoader(javaClass.classLoader)
@Suppress("PrivatePropertyName")
private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/')
@Suppress("PrivatePropertyName")
private val T_CLASS: String = classOfT.name.replace('.', '/')
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/')
private val invokeThisVar: Int = 0
private val invokeArgumentsVar: Int = 1
private var maxStack: Int = 0
private val constants: MutableList<Any> = mutableListOf()
private val asmCompiledClassWriter: ClassWriter = ClassWriter(0)
private val invokeMethodVisitor: MethodVisitor
private val invokeL0: Label
private lateinit var invokeL1: Label
private var generatedInstance: FunctionalCompiledExpression<T>? = null
init {
asmCompiledClassWriter.visit(
Opcodes.V1_8,
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
slashesClassName,
"L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS<L$T_CLASS;>;",
FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
arrayOf()
)
asmCompiledClassWriter.run {
visitMethod(Opcodes.ACC_PUBLIC, "<init>", "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", null, null).run {
val thisVar = 0
val algebraVar = 1
val constantsVar = 2
val l0 = Label()
visitLabel(l0)
visitVarInsn(Opcodes.ALOAD, thisVar)
visitVarInsn(Opcodes.ALOAD, algebraVar)
visitVarInsn(Opcodes.ALOAD, constantsVar)
visitMethodInsn(
Opcodes.INVOKESPECIAL,
FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
"<init>",
"(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V",
false
)
val l1 = Label()
visitLabel(l1)
visitInsn(Opcodes.RETURN)
val l2 = Label()
visitLabel(l2)
visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar)
visitLocalVariable(
"algebra",
"L$ALGEBRA_CLASS;",
"L$ALGEBRA_CLASS<L$T_CLASS;>;",
l0,
l2,
algebraVar
)
visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar)
visitMaxs(3, 3)
visitEnd()
}
invokeMethodVisitor = visitMethod(
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
"invoke",
"(L$MAP_CLASS;)L$T_CLASS;",
"(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;",
null
)
invokeMethodVisitor.run {
visitCode()
invokeL0 = Label()
visitLabel(invokeL0)
}
}
}
@PublishedApi
@Suppress("UNCHECKED_CAST")
internal fun generate(): FunctionalCompiledExpression<T> {
generatedInstance?.let { return it }
invokeMethodVisitor.run {
visitInsn(Opcodes.ARETURN)
invokeL1 = Label()
visitLabel(invokeL1)
visitLocalVariable(
"this",
"L$slashesClassName;",
T_CLASS,
invokeL0,
invokeL1,
invokeThisVar
)
visitLocalVariable(
"arguments",
"L$MAP_CLASS;",
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;",
invokeL0,
invokeL1,
invokeArgumentsVar
)
visitMaxs(maxStack + 1, 2)
visitEnd()
}
asmCompiledClassWriter.visitMethod(
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
"invoke",
"(L$MAP_CLASS;)L$OBJECT_CLASS;",
null,
null
).run {
val thisVar = 0
visitCode()
val l0 = Label()
visitLabel(l0)
visitVarInsn(Opcodes.ALOAD, 0)
visitVarInsn(Opcodes.ALOAD, 1)
visitMethodInsn(Opcodes.INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false)
visitInsn(Opcodes.ARETURN)
val l1 = Label()
visitLabel(l1)
visitLocalVariable(
"this",
"L$slashesClassName;",
T_CLASS,
l0,
l1,
thisVar
)
visitMaxs(2, 2)
visitEnd()
}
asmCompiledClassWriter.visitEnd()
val new = classLoader
.defineClass(className, asmCompiledClassWriter.toByteArray())
.constructors
.first()
.newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression<T>
generatedInstance = new
return new
}
internal fun visitLoadFromConstants(value: T) {
if (classOfT in INLINABLE_NUMBERS) {
visitNumberConstant(value as Number)
visitCastToT()
return
}
visitLoadAnyFromConstants(value as Any, T_CLASS)
}
private fun visitLoadAnyFromConstants(value: Any, type: String) {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
maxStack++
invokeMethodVisitor.run {
visitLoadThis()
visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "[L$OBJECT_CLASS;")
visitLdcOrIConstInsn(idx)
visitInsn(Opcodes.AALOAD)
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type)
}
}
private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar)
internal fun visitNumberConstant(value: Number) {
maxStack++
val clazz = value.javaClass
val c = clazz.name.replace('.', '/')
val sigLetter = SIGNATURE_LETTERS[clazz]
if (sigLetter != null) {
when (value) {
is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value)
is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value)
is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value)
else -> invokeMethodVisitor.visitLdcInsn(value)
}
invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false)
return
}
visitLoadAnyFromConstants(value, c)
}
internal fun visitLoadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
maxStack += 2
visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar)
if (defaultValue != null) {
visitLdcInsn(name)
visitLoadFromConstants(defaultValue)
visitMethodInsn(
Opcodes.INVOKEINTERFACE,
MAP_CLASS,
"getOrDefault",
"(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;",
true
)
visitCastToT()
return
}
visitLdcInsn(name)
visitMethodInsn(
Opcodes.INVOKEINTERFACE,
MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true
)
visitCastToT()
}
internal fun visitLoadAlgebra() {
maxStack++
invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar)
invokeMethodVisitor.visitFieldInsn(
Opcodes.GETFIELD,
FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;"
)
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS)
}
internal fun visitAlgebraOperation(
owner: String,
method: String,
descriptor: String,
opcode: Int = Opcodes.INVOKEINTERFACE,
isInterface: Boolean = true
) {
maxStack++
invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface)
visitCastToT()
}
private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS)
internal fun visitStringConstant(string: String) {
invokeMethodVisitor.visitLdcInsn(string)
}
internal companion object {
private val SIGNATURE_LETTERS by lazy {
mapOf(
java.lang.Byte::class.java to "B",
java.lang.Short::class.java to "S",
java.lang.Integer::class.java to "I",
java.lang.Long::class.java to "J",
java.lang.Float::class.java to "F",
java.lang.Double::class.java to "D"
)
}
private val INLINABLE_NUMBERS by lazy { SIGNATURE_LETTERS.keys }
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =
"scientifik/kmath/asm/FunctionalCompiledExpression"
internal const val MAP_CLASS = "java/util/Map"
internal const val OBJECT_CLASS = "java/lang/Object"
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
internal const val STRING_CLASS = "java/lang/String"
internal const val NUMBER_CLASS = "java/lang/Number"
}
}

View File

@ -0,0 +1,28 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.*
internal fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) {
-1 -> visitInsn(ICONST_M1)
0 -> visitInsn(ICONST_0)
1 -> visitInsn(ICONST_1)
2 -> visitInsn(ICONST_2)
3 -> visitInsn(ICONST_3)
4 -> visitInsn(ICONST_4)
5 -> visitInsn(ICONST_5)
else -> visitLdcInsn(value)
}
internal fun MethodVisitor.visitLdcOrDConstInsn(value: Double) = when (value) {
0.0 -> visitInsn(DCONST_0)
1.0 -> visitInsn(DCONST_1)
else -> visitLdcInsn(value)
}
internal fun MethodVisitor.visitLdcOrFConstInsn(value: Float) = when (value) {
0f -> visitInsn(FCONST_0)
1f -> visitInsn(FCONST_1)
2f -> visitInsn(FCONST_2)
else -> visitLdcInsn(value)
}

View File

@ -0,0 +1,49 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.AsmConstantExpression
import scientifik.kmath.asm.AsmNode
import scientifik.kmath.operations.Algebra
private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide")
internal fun <T> hasSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
val aName = methodNameAdapters[name] ?: name
context::class.java.methods.find { it.name == aName && it.parameters.size == arity }
?: return false
return true
}
internal fun <T> AsmGenerationContext<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
val aName = methodNameAdapters[name] ?: name
context::class.java.methods.find { it.name == aName && it.parameters.size == arity }
?: return false
val owner = context::class.java.name.replace('.', '/')
val sig = buildString {
append('(')
repeat(arity) { append("L${AsmGenerationContext.OBJECT_CLASS};") }
append(')')
append("L${AsmGenerationContext.OBJECT_CLASS};")
}
visitAlgebraOperation(
owner = owner,
method = aName,
descriptor = sig,
opcode = Opcodes.INVOKEVIRTUAL,
isInterface = false
)
return true
}
@PublishedApi
internal fun <T> AsmNode<T>.optimize(): AsmNode<T> {
val a = tryEvaluate()
return if (a == null) this else AsmConstantExpression(a)
}

View File

@ -1,23 +0,0 @@
package scientifik.kmath.ast
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
//TODO stubs for asm generation
interface AsmExpression<T>
interface AsmExpressionAlgebra<T, A : Algebra<T>> : NumericAlgebra<AsmExpression<T>> {
val algebra: A
}
fun <T> AsmExpression<T>.compile(): Expression<T> = TODO()
//TODO add converter for functional expressions
inline fun <reified T : Any, A : Algebra<T>> A.asm(
block: AsmExpressionAlgebra<T, A>.() -> AsmExpression<T>
): Expression<T> = TODO()
inline fun <reified T : Any, A : Algebra<T>> A.asm(ast: MST): Expression<T> = asm { evaluate(ast) }

View File

@ -0,0 +1,18 @@
package scietifik.kmath.ast
import scientifik.kmath.asm.asmField
import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField
import kotlin.test.assertEquals
import kotlin.test.Test
class AsmTest {
@Test
fun parsedExpression() {
val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.asmField(mst)()
assertEquals(Complex(10.0, 0.0), res)
}
}

View File

@ -1,17 +1,17 @@
package scietifik.kmath.ast package scietifik.kmath.ast
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import scientifik.kmath.ast.evaluate import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.parseMath import scientifik.kmath.ast.parseMath
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import kotlin.test.assertEquals
import kotlin.test.Test
internal class ParserTest{ internal class ParserTest {
@Test @Test
fun parsedExpression(){ fun parsedExpression() {
val mst = "2+2*(2+2)".parseMath() val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.evaluate(mst) val res = ComplexField.evaluate(mst)
assertEquals(Complex(10.0,0.0), res) assertEquals(Complex(10.0, 0.0), res)
} }
} }

View File

@ -0,0 +1,75 @@
package scietifik.kmath.ast.asm
import scientifik.kmath.asm.asmField
import scientifik.kmath.asm.asmRing
import scientifik.kmath.asm.asmSpace
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.ByteRing
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
class TestAsmAlgebras {
@Test
fun space() {
val res = ByteRing.asmSpace {
binaryOperation(
"+",
unaryOperation(
"+",
3.toByte() - (2.toByte() + (multiply(
add(const(1), const(1)),
2
) + 1.toByte()) * 3.toByte() - 1.toByte())
),
number(1)
) + variable("x") + zero
}("x" to 2.toByte())
assertEquals(16, res)
}
@Test
fun ring() {
val res = ByteRing.asmRing {
binaryOperation(
"+",
unaryOperation(
"+",
(3.toByte() - (2.toByte() + (multiply(
add(const(1), const(1)),
2
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * const(2)
}()
assertEquals(24, res)
}
@Test
fun field() {
val res = RealField.asmField {
divide(binaryOperation(
"+",
unaryOperation(
"+",
(3.0 - (2.0 + (multiply(
add(const(1.0), const(1.0)),
2
) + 1.0))) * 3 - 1.0
),
number(1)
) / 2, const(2.0)) * one
}()
assertEquals(3.0, res)
}
}

View File

@ -0,0 +1,21 @@
package scietifik.kmath.ast.asm
import scientifik.kmath.asm.asmField
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
class TestAsmExpressions {
@Test
fun testUnaryOperationInvocation() {
val res = RealField.asmField { unaryOperation("+", variable("x")) }("x" to 2.0)
assertEquals(2.0, res)
}
@Test
fun testConstProductInvocation() {
val res = RealField.asmField { variable("x") * 2 }("x" to 2.0)
assertEquals(4.0, res)
}
}

View File

@ -8,4 +8,4 @@ kotlin.sourceSets {
api(project(":kmath-memory")) api(project(":kmath-memory"))
} }
} }
} }

View File

@ -0,0 +1,23 @@
package scientifik.kmath.expressions
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space
/**
* Create a functional expression on this [Space]
*/
fun <T> Space<T>.buildExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> =
FunctionalExpressionSpace(this).run(block)
/**
* Create a functional expression on this [Ring]
*/
fun <T> Ring<T>.buildExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> =
FunctionalExpressionRing(this).run(block)
/**
* Create a functional expression on this [Field]
*/
fun <T> Field<T>.buildExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> =
FunctionalExpressionField(this).run(block)

View File

@ -0,0 +1,139 @@
package scientifik.kmath.expressions
import scientifik.kmath.operations.*
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments))
}
internal class FunctionalBinaryOperation<T>(
val context: Algebra<T>,
val name: String,
val first: Expression<T>,
val second: Expression<T>
) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
}
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name")
}
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = value
}
internal class FunctionalConstProductExpression<T>(
val context: Space<T>,
private val expr: Expression<T>,
val const: Number
) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
}
/**
* A context class for [Expression] construction.
*/
interface FunctionalExpressionAlgebra<T, A : Algebra<T>> : ExpressionAlgebra<T, Expression<T>> {
/**
* The algebra to provide for Expressions built.
*/
val algebra: A
/**
* Builds an Expression of constant expression which does not depend on arguments.
*/
override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
/**
* Builds an Expression to access a variable.
*/
override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
/**
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
*/
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, operation, left, right)
/**
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
*/
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
FunctionalUnaryOperation(algebra, operation, arg)
}
/**
* A context class for [Expression] construction for [Space] algebras.
*/
open class FunctionalExpressionSpace<T, A>(override val algebra: A) : FunctionalExpressionAlgebra<T, A>,
Space<Expression<T>> where A : Space<T> {
override val zero: Expression<T>
get() = const(algebra.zero)
/**
* Builds an Expression of addition of two another expressions.
*/
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b)
/**
* Builds an Expression of multiplication of expression by number.
*/
override fun multiply(a: Expression<T>, k: Number): Expression<T> =
FunctionalConstProductExpression(algebra, a, k)
operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
}
open class FunctionalExpressionRing<T, A>(override val algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
Ring<Expression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
override val one: Expression<T>
get() = const(algebra.one)
/**
* Builds an Expression of multiplication of two expressions.
*/
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b)
operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
}
open class FunctionalExpressionField<T, A>(override val algebra: A) :
FunctionalExpressionRing<T, A>(algebra),
Field<Expression<T>> where A : Field<T>, A : NumericAlgebra<T> {
/**
* Builds an Expression of division an expression by another one.
*/
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
FunctionalBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b)
operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
super<FunctionalExpressionRing>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
}

View File

@ -1,94 +0,0 @@
package scientifik.kmath.expressions
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name")
}
internal class ConstantExpression<T>(val value: T) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = value
}
internal class SumExpression<T>(
val context: Space<T>,
val first: Expression<T>,
val second: Expression<T>
) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
}
internal class ProductExpression<T>(
val context: Ring<T>,
val first: Expression<T>,
val second: Expression<T>
) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T =
context.multiply(first.invoke(arguments), second.invoke(arguments))
}
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
}
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
Expression<T> {
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
}
open class FunctionalExpressionSpace<T>(
val space: Space<T>
) : Space<Expression<T>>, ExpressionAlgebra<T, Expression<T>> {
override val zero: Expression<T> = ConstantExpression(space.zero)
override fun const(value: T): Expression<T> = ConstantExpression(value)
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
operator fun Expression<T>.plus(arg: T) = this + const(arg)
operator fun Expression<T>.minus(arg: T) = this - const(arg)
operator fun T.plus(arg: Expression<T>) = arg + this
operator fun T.minus(arg: Expression<T>) = arg - this
}
open class FunctionalExpressionField<T>(
val field: Field<T>
) : Field<Expression<T>>, ExpressionAlgebra<T, Expression<T>>, FunctionalExpressionSpace<T>(field) {
override val one: Expression<T> = ConstantExpression(this.field.one)
fun const(value: Double): Expression<T> = const(field.run { one * value })
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
operator fun Expression<T>.times(arg: T) = this * const(arg)
operator fun Expression<T>.div(arg: T) = this / const(arg)
operator fun T.times(arg: Expression<T>) = arg * this
operator fun T.div(arg: Expression<T>) = arg / this
}
/**
* Create a functional expression on this [Space]
*/
fun <T> Space<T>.buildExpression(block: FunctionalExpressionSpace<T>.() -> Expression<T>): Expression<T> =
FunctionalExpressionSpace(this).run(block)
/**
* Create a functional expression on this [Field]
*/
fun <T> Field<T>.buildExpression(block: FunctionalExpressionField<T>.() -> Expression<T>): Expression<T> =
FunctionalExpressionField(this).run(block)