Merge pull request #94 from CommanderTvis/adv-expr
Optimization of expressions via ASM Bytecode Generation
This commit is contained in:
commit
878d1379e1
@ -2,7 +2,7 @@ plugins {
|
|||||||
id("scientifik.mpp")
|
id("scientifik.mpp")
|
||||||
}
|
}
|
||||||
|
|
||||||
repositories{
|
repositories {
|
||||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -14,13 +14,18 @@ kotlin.sourceSets {
|
|||||||
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3")
|
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jvmMain{
|
|
||||||
dependencies{
|
jvmMain {
|
||||||
|
dependencies {
|
||||||
implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3")
|
implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3")
|
||||||
|
implementation("org.ow2.asm:asm:8.0.1")
|
||||||
|
implementation("org.ow2.asm:asm-commons:8.0.1")
|
||||||
|
implementation(kotlin("reflect"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jsMain{
|
|
||||||
dependencies{
|
jsMain {
|
||||||
|
dependencies {
|
||||||
implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3")
|
implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,56 @@
|
|||||||
|
package scientifik.kmath.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.internal.AsmGenerationContext
|
||||||
|
import scientifik.kmath.ast.MST
|
||||||
|
import scientifik.kmath.ast.evaluate
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
|
@PublishedApi
|
||||||
|
internal fun buildName(expression: AsmNode<*>, collision: Int = 0): String {
|
||||||
|
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision"
|
||||||
|
|
||||||
|
try {
|
||||||
|
Class.forName(name)
|
||||||
|
} catch (ignored: ClassNotFoundException) {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildName(expression, collision + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
@PublishedApi
|
||||||
|
internal inline fun <reified T> AsmNode<T>.compile(algebra: Algebra<T>): Expression<T> {
|
||||||
|
val ctx =
|
||||||
|
AsmGenerationContext(T::class.java, algebra, buildName(this))
|
||||||
|
compile(ctx)
|
||||||
|
return ctx.generate()
|
||||||
|
}
|
||||||
|
|
||||||
|
inline fun <reified T, A : NumericAlgebra<T>, E : AsmExpressionAlgebra<T, A>> A.asm(
|
||||||
|
expressionAlgebra: E,
|
||||||
|
block: E.() -> AsmNode<T>
|
||||||
|
): Expression<T> = expressionAlgebra.block().compile(expressionAlgebra.algebra)
|
||||||
|
|
||||||
|
inline fun <reified T, A : NumericAlgebra<T>, E : AsmExpressionAlgebra<T, A>> A.asm(
|
||||||
|
expressionAlgebra: E,
|
||||||
|
ast: MST
|
||||||
|
): Expression<T> = asm(expressionAlgebra) { evaluate(ast) }
|
||||||
|
|
||||||
|
inline fun <reified T, A> A.asmSpace(block: AsmExpressionSpace<T, A>.() -> AsmNode<T>): Expression<T> where A : NumericAlgebra<T>, A : Space<T> =
|
||||||
|
AsmExpressionSpace(this).let { it.block().compile(it.algebra) }
|
||||||
|
|
||||||
|
inline fun <reified T, A> A.asmSpace(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Space<T> =
|
||||||
|
asmSpace { evaluate(ast) }
|
||||||
|
|
||||||
|
inline fun <reified T, A> A.asmRing(block: AsmExpressionRing<T, A>.() -> AsmNode<T>): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> =
|
||||||
|
AsmExpressionRing(this).let { it.block().compile(it.algebra) }
|
||||||
|
|
||||||
|
inline fun <reified T, A> A.asmRing(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> =
|
||||||
|
asmRing { evaluate(ast) }
|
||||||
|
|
||||||
|
inline fun <reified T, A> A.asmField(block: AsmExpressionField<T, A>.() -> AsmNode<T>): Expression<T> where A : NumericAlgebra<T>, A : Field<T> =
|
||||||
|
AsmExpressionField(this).let { it.block().compile(it.algebra) }
|
||||||
|
|
||||||
|
inline fun <reified T, A> A.asmField(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Field<T> =
|
||||||
|
asmRing { evaluate(ast) }
|
@ -0,0 +1,265 @@
|
|||||||
|
package scientifik.kmath.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.internal.AsmGenerationContext
|
||||||
|
import scientifik.kmath.asm.internal.hasSpecific
|
||||||
|
import scientifik.kmath.asm.internal.optimize
|
||||||
|
import scientifik.kmath.asm.internal.tryInvokeSpecific
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.expressions.ExpressionAlgebra
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A function declaration that could be compiled to [AsmGenerationContext].
|
||||||
|
*
|
||||||
|
* @param T the type the stored function returns.
|
||||||
|
*/
|
||||||
|
abstract class AsmNode<T> internal constructor() {
|
||||||
|
/**
|
||||||
|
* Tries to evaluate this function without its variables. This method is intended for optimization.
|
||||||
|
*
|
||||||
|
* @return `null` if the function depends on its variables, the value if the function is a constant.
|
||||||
|
*/
|
||||||
|
internal open fun tryEvaluate(): T? = null
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compiles this declaration.
|
||||||
|
*
|
||||||
|
* @param gen the target [AsmGenerationContext].
|
||||||
|
*/
|
||||||
|
@PublishedApi
|
||||||
|
internal abstract fun compile(gen: AsmGenerationContext<T>)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class AsmUnaryOperation<T>(private val context: Algebra<T>, private val name: String, expr: AsmNode<T>) :
|
||||||
|
AsmNode<T>() {
|
||||||
|
private val expr: AsmNode<T> = expr.optimize()
|
||||||
|
override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) }
|
||||||
|
|
||||||
|
override fun compile(gen: AsmGenerationContext<T>) {
|
||||||
|
gen.visitLoadAlgebra()
|
||||||
|
|
||||||
|
if (!hasSpecific(context, name, 1))
|
||||||
|
gen.visitStringConstant(name)
|
||||||
|
|
||||||
|
expr.compile(gen)
|
||||||
|
|
||||||
|
if (gen.tryInvokeSpecific(context, name, 1))
|
||||||
|
return
|
||||||
|
|
||||||
|
gen.visitAlgebraOperation(
|
||||||
|
owner = AsmGenerationContext.ALGEBRA_CLASS,
|
||||||
|
method = "unaryOperation",
|
||||||
|
descriptor = "(L${AsmGenerationContext.STRING_CLASS};" +
|
||||||
|
"L${AsmGenerationContext.OBJECT_CLASS};)" +
|
||||||
|
"L${AsmGenerationContext.OBJECT_CLASS};"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class AsmBinaryOperation<T>(
|
||||||
|
private val context: Algebra<T>,
|
||||||
|
private val name: String,
|
||||||
|
first: AsmNode<T>,
|
||||||
|
second: AsmNode<T>
|
||||||
|
) : AsmNode<T>() {
|
||||||
|
private val first: AsmNode<T> = first.optimize()
|
||||||
|
private val second: AsmNode<T> = second.optimize()
|
||||||
|
|
||||||
|
override fun tryEvaluate(): T? = context {
|
||||||
|
binaryOperation(
|
||||||
|
name,
|
||||||
|
first.tryEvaluate() ?: return@context null,
|
||||||
|
second.tryEvaluate() ?: return@context null
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun compile(gen: AsmGenerationContext<T>) {
|
||||||
|
gen.visitLoadAlgebra()
|
||||||
|
|
||||||
|
if (!hasSpecific(context, name, 2))
|
||||||
|
gen.visitStringConstant(name)
|
||||||
|
|
||||||
|
first.compile(gen)
|
||||||
|
second.compile(gen)
|
||||||
|
|
||||||
|
if (gen.tryInvokeSpecific(context, name, 2))
|
||||||
|
return
|
||||||
|
|
||||||
|
gen.visitAlgebraOperation(
|
||||||
|
owner = AsmGenerationContext.ALGEBRA_CLASS,
|
||||||
|
method = "binaryOperation",
|
||||||
|
descriptor = "(L${AsmGenerationContext.STRING_CLASS};" +
|
||||||
|
"L${AsmGenerationContext.OBJECT_CLASS};" +
|
||||||
|
"L${AsmGenerationContext.OBJECT_CLASS};)" +
|
||||||
|
"L${AsmGenerationContext.OBJECT_CLASS};"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class AsmVariableExpression<T>(private val name: String, private val default: T? = null) :
|
||||||
|
AsmNode<T>() {
|
||||||
|
override fun compile(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromVariables(name, default)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class AsmConstantExpression<T>(private val value: T) :
|
||||||
|
AsmNode<T>() {
|
||||||
|
override fun tryEvaluate(): T = value
|
||||||
|
override fun compile(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromConstants(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class AsmConstProductExpression<T>(
|
||||||
|
private val context: Space<T>,
|
||||||
|
expr: AsmNode<T>,
|
||||||
|
private val const: Number
|
||||||
|
) : AsmNode<T>() {
|
||||||
|
private val expr: AsmNode<T> = expr.optimize()
|
||||||
|
|
||||||
|
override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const }
|
||||||
|
|
||||||
|
override fun compile(gen: AsmGenerationContext<T>) {
|
||||||
|
gen.visitLoadAlgebra()
|
||||||
|
gen.visitNumberConstant(const)
|
||||||
|
expr.compile(gen)
|
||||||
|
|
||||||
|
gen.visitAlgebraOperation(
|
||||||
|
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
|
||||||
|
method = "multiply",
|
||||||
|
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};" +
|
||||||
|
"L${AsmGenerationContext.NUMBER_CLASS};)" +
|
||||||
|
"L${AsmGenerationContext.OBJECT_CLASS};"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class AsmNumberExpression<T>(private val context: NumericAlgebra<T>, private val value: Number) :
|
||||||
|
AsmNode<T>() {
|
||||||
|
override fun tryEvaluate(): T? = context.number(value)
|
||||||
|
|
||||||
|
override fun compile(gen: AsmGenerationContext<T>): Unit = gen.visitNumberConstant(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal abstract class FunctionalCompiledExpression<T> internal constructor(
|
||||||
|
@JvmField protected val algebra: Algebra<T>,
|
||||||
|
@JvmField protected val constants: Array<Any>
|
||||||
|
) : Expression<T> {
|
||||||
|
abstract override fun invoke(arguments: Map<String, T>): T
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A context class for [AsmNode] construction.
|
||||||
|
*/
|
||||||
|
interface AsmExpressionAlgebra<T, A : NumericAlgebra<T>> : NumericAlgebra<AsmNode<T>>,
|
||||||
|
ExpressionAlgebra<T, AsmNode<T>> {
|
||||||
|
/**
|
||||||
|
* The algebra to provide for AsmExpressions built.
|
||||||
|
*/
|
||||||
|
val algebra: A
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an AsmExpression to wrap a number.
|
||||||
|
*/
|
||||||
|
override fun number(value: Number): AsmNode<T> = AsmNumberExpression(algebra, value)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an AsmExpression of constant expression which does not depend on arguments.
|
||||||
|
*/
|
||||||
|
override fun const(value: T): AsmNode<T> = AsmConstantExpression(value)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an AsmExpression to access a variable.
|
||||||
|
*/
|
||||||
|
override fun variable(name: String, default: T?): AsmNode<T> = AsmVariableExpression(name, default)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right].
|
||||||
|
*/
|
||||||
|
override fun binaryOperation(operation: String, left: AsmNode<T>, right: AsmNode<T>): AsmNode<T> =
|
||||||
|
AsmBinaryOperation(algebra, operation, left, right)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg].
|
||||||
|
*/
|
||||||
|
override fun unaryOperation(operation: String, arg: AsmNode<T>): AsmNode<T> =
|
||||||
|
AsmUnaryOperation(algebra, operation, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A context class for [AsmNode] construction for [Space] algebras.
|
||||||
|
*/
|
||||||
|
open class AsmExpressionSpace<T, A>(override val algebra: A) : AsmExpressionAlgebra<T, A>,
|
||||||
|
Space<AsmNode<T>> where A : Space<T>, A : NumericAlgebra<T> {
|
||||||
|
override val zero: AsmNode<T>
|
||||||
|
get() = const(algebra.zero)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an AsmExpression of addition of two another expressions.
|
||||||
|
*/
|
||||||
|
override fun add(a: AsmNode<T>, b: AsmNode<T>): AsmNode<T> =
|
||||||
|
AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an AsmExpression of multiplication of expression by number.
|
||||||
|
*/
|
||||||
|
override fun multiply(a: AsmNode<T>, k: Number): AsmNode<T> = AsmConstProductExpression(algebra, a, k)
|
||||||
|
|
||||||
|
operator fun AsmNode<T>.plus(arg: T): AsmNode<T> = this + const(arg)
|
||||||
|
operator fun AsmNode<T>.minus(arg: T): AsmNode<T> = this - const(arg)
|
||||||
|
operator fun T.plus(arg: AsmNode<T>): AsmNode<T> = arg + this
|
||||||
|
operator fun T.minus(arg: AsmNode<T>): AsmNode<T> = arg - this
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: AsmNode<T>): AsmNode<T> =
|
||||||
|
super<AsmExpressionAlgebra>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: AsmNode<T>, right: AsmNode<T>): AsmNode<T> =
|
||||||
|
super<AsmExpressionAlgebra>.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A context class for [AsmNode] construction for [Ring] algebras.
|
||||||
|
*/
|
||||||
|
open class AsmExpressionRing<T, A>(override val algebra: A) : AsmExpressionSpace<T, A>(algebra),
|
||||||
|
Ring<AsmNode<T>> where A : Ring<T>, A : NumericAlgebra<T> {
|
||||||
|
override val one: AsmNode<T>
|
||||||
|
get() = const(algebra.one)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an AsmExpression of multiplication of two expressions.
|
||||||
|
*/
|
||||||
|
override fun multiply(a: AsmNode<T>, b: AsmNode<T>): AsmNode<T> =
|
||||||
|
AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
|
operator fun AsmNode<T>.times(arg: T): AsmNode<T> = this * const(arg)
|
||||||
|
operator fun T.times(arg: AsmNode<T>): AsmNode<T> = arg * this
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: AsmNode<T>): AsmNode<T> =
|
||||||
|
super<AsmExpressionSpace>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: AsmNode<T>, right: AsmNode<T>): AsmNode<T> =
|
||||||
|
super<AsmExpressionSpace>.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun number(value: Number): AsmNode<T> = super<AsmExpressionSpace>.number(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A context class for [AsmNode] construction for [Field] algebras.
|
||||||
|
*/
|
||||||
|
open class AsmExpressionField<T, A>(override val algebra: A) :
|
||||||
|
AsmExpressionRing<T, A>(algebra),
|
||||||
|
Field<AsmNode<T>> where A : Field<T>, A : NumericAlgebra<T> {
|
||||||
|
/**
|
||||||
|
* Builds an AsmExpression of division an expression by another one.
|
||||||
|
*/
|
||||||
|
override fun divide(a: AsmNode<T>, b: AsmNode<T>): AsmNode<T> =
|
||||||
|
AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
|
operator fun AsmNode<T>.div(arg: T): AsmNode<T> = this / const(arg)
|
||||||
|
operator fun T.div(arg: AsmNode<T>): AsmNode<T> = arg / this
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: AsmNode<T>): AsmNode<T> =
|
||||||
|
super<AsmExpressionRing>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: AsmNode<T>, right: AsmNode<T>): AsmNode<T> =
|
||||||
|
super<AsmExpressionRing>.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
|
override fun number(value: Number): AsmNode<T> = super<AsmExpressionRing>.number(value)
|
||||||
|
}
|
@ -0,0 +1,319 @@
|
|||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
import org.objectweb.asm.ClassWriter
|
||||||
|
import org.objectweb.asm.Label
|
||||||
|
import org.objectweb.asm.MethodVisitor
|
||||||
|
import org.objectweb.asm.Opcodes
|
||||||
|
import scientifik.kmath.asm.FunctionalCompiledExpression
|
||||||
|
import scientifik.kmath.asm.internal.AsmGenerationContext.ClassLoader
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java
|
||||||
|
* expression. This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new
|
||||||
|
* class.
|
||||||
|
*
|
||||||
|
* @param T the type of AsmExpression to unwrap.
|
||||||
|
* @param algebra the algebra the applied AsmExpressions use.
|
||||||
|
* @param className the unique class name of new loaded class.
|
||||||
|
*/
|
||||||
|
@PublishedApi
|
||||||
|
internal class AsmGenerationContext<T> @PublishedApi internal constructor(
|
||||||
|
private val classOfT: Class<*>,
|
||||||
|
private val algebra: Algebra<T>,
|
||||||
|
private val className: String
|
||||||
|
) {
|
||||||
|
private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) {
|
||||||
|
internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
private val classLoader: ClassLoader =
|
||||||
|
ClassLoader(javaClass.classLoader)
|
||||||
|
|
||||||
|
@Suppress("PrivatePropertyName")
|
||||||
|
private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/')
|
||||||
|
|
||||||
|
@Suppress("PrivatePropertyName")
|
||||||
|
private val T_CLASS: String = classOfT.name.replace('.', '/')
|
||||||
|
|
||||||
|
private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/')
|
||||||
|
private val invokeThisVar: Int = 0
|
||||||
|
private val invokeArgumentsVar: Int = 1
|
||||||
|
private var maxStack: Int = 0
|
||||||
|
private val constants: MutableList<Any> = mutableListOf()
|
||||||
|
private val asmCompiledClassWriter: ClassWriter = ClassWriter(0)
|
||||||
|
private val invokeMethodVisitor: MethodVisitor
|
||||||
|
private val invokeL0: Label
|
||||||
|
private lateinit var invokeL1: Label
|
||||||
|
private var generatedInstance: FunctionalCompiledExpression<T>? = null
|
||||||
|
|
||||||
|
init {
|
||||||
|
asmCompiledClassWriter.visit(
|
||||||
|
Opcodes.V1_8,
|
||||||
|
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
|
||||||
|
slashesClassName,
|
||||||
|
"L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS<L$T_CLASS;>;",
|
||||||
|
FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
|
||||||
|
arrayOf()
|
||||||
|
)
|
||||||
|
|
||||||
|
asmCompiledClassWriter.run {
|
||||||
|
visitMethod(Opcodes.ACC_PUBLIC, "<init>", "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", null, null).run {
|
||||||
|
val thisVar = 0
|
||||||
|
val algebraVar = 1
|
||||||
|
val constantsVar = 2
|
||||||
|
val l0 = Label()
|
||||||
|
visitLabel(l0)
|
||||||
|
visitVarInsn(Opcodes.ALOAD, thisVar)
|
||||||
|
visitVarInsn(Opcodes.ALOAD, algebraVar)
|
||||||
|
visitVarInsn(Opcodes.ALOAD, constantsVar)
|
||||||
|
|
||||||
|
visitMethodInsn(
|
||||||
|
Opcodes.INVOKESPECIAL,
|
||||||
|
FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
|
||||||
|
"<init>",
|
||||||
|
"(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V",
|
||||||
|
false
|
||||||
|
)
|
||||||
|
|
||||||
|
val l1 = Label()
|
||||||
|
visitLabel(l1)
|
||||||
|
visitInsn(Opcodes.RETURN)
|
||||||
|
val l2 = Label()
|
||||||
|
visitLabel(l2)
|
||||||
|
visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar)
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"algebra",
|
||||||
|
"L$ALGEBRA_CLASS;",
|
||||||
|
"L$ALGEBRA_CLASS<L$T_CLASS;>;",
|
||||||
|
l0,
|
||||||
|
l2,
|
||||||
|
algebraVar
|
||||||
|
)
|
||||||
|
|
||||||
|
visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar)
|
||||||
|
visitMaxs(3, 3)
|
||||||
|
visitEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
invokeMethodVisitor = visitMethod(
|
||||||
|
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
|
||||||
|
"invoke",
|
||||||
|
"(L$MAP_CLASS;)L$T_CLASS;",
|
||||||
|
"(L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;)L$T_CLASS;",
|
||||||
|
null
|
||||||
|
)
|
||||||
|
|
||||||
|
invokeMethodVisitor.run {
|
||||||
|
visitCode()
|
||||||
|
invokeL0 = Label()
|
||||||
|
visitLabel(invokeL0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@PublishedApi
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
internal fun generate(): FunctionalCompiledExpression<T> {
|
||||||
|
generatedInstance?.let { return it }
|
||||||
|
|
||||||
|
invokeMethodVisitor.run {
|
||||||
|
visitInsn(Opcodes.ARETURN)
|
||||||
|
invokeL1 = Label()
|
||||||
|
visitLabel(invokeL1)
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"this",
|
||||||
|
"L$slashesClassName;",
|
||||||
|
T_CLASS,
|
||||||
|
invokeL0,
|
||||||
|
invokeL1,
|
||||||
|
invokeThisVar
|
||||||
|
)
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"arguments",
|
||||||
|
"L$MAP_CLASS;",
|
||||||
|
"L$MAP_CLASS<L$STRING_CLASS;+L$T_CLASS;>;",
|
||||||
|
invokeL0,
|
||||||
|
invokeL1,
|
||||||
|
invokeArgumentsVar
|
||||||
|
)
|
||||||
|
|
||||||
|
visitMaxs(maxStack + 1, 2)
|
||||||
|
visitEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
asmCompiledClassWriter.visitMethod(
|
||||||
|
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
|
||||||
|
"invoke",
|
||||||
|
"(L$MAP_CLASS;)L$OBJECT_CLASS;",
|
||||||
|
null,
|
||||||
|
null
|
||||||
|
).run {
|
||||||
|
val thisVar = 0
|
||||||
|
visitCode()
|
||||||
|
val l0 = Label()
|
||||||
|
visitLabel(l0)
|
||||||
|
visitVarInsn(Opcodes.ALOAD, 0)
|
||||||
|
visitVarInsn(Opcodes.ALOAD, 1)
|
||||||
|
visitMethodInsn(Opcodes.INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false)
|
||||||
|
visitInsn(Opcodes.ARETURN)
|
||||||
|
val l1 = Label()
|
||||||
|
visitLabel(l1)
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"this",
|
||||||
|
"L$slashesClassName;",
|
||||||
|
T_CLASS,
|
||||||
|
l0,
|
||||||
|
l1,
|
||||||
|
thisVar
|
||||||
|
)
|
||||||
|
|
||||||
|
visitMaxs(2, 2)
|
||||||
|
visitEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
asmCompiledClassWriter.visitEnd()
|
||||||
|
|
||||||
|
val new = classLoader
|
||||||
|
.defineClass(className, asmCompiledClassWriter.toByteArray())
|
||||||
|
.constructors
|
||||||
|
.first()
|
||||||
|
.newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression<T>
|
||||||
|
|
||||||
|
generatedInstance = new
|
||||||
|
return new
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun visitLoadFromConstants(value: T) {
|
||||||
|
if (classOfT in INLINABLE_NUMBERS) {
|
||||||
|
visitNumberConstant(value as Number)
|
||||||
|
visitCastToT()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
visitLoadAnyFromConstants(value as Any, T_CLASS)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun visitLoadAnyFromConstants(value: Any, type: String) {
|
||||||
|
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||||
|
maxStack++
|
||||||
|
|
||||||
|
invokeMethodVisitor.run {
|
||||||
|
visitLoadThis()
|
||||||
|
visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "[L$OBJECT_CLASS;")
|
||||||
|
visitLdcOrIConstInsn(idx)
|
||||||
|
visitInsn(Opcodes.AALOAD)
|
||||||
|
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar)
|
||||||
|
|
||||||
|
internal fun visitNumberConstant(value: Number) {
|
||||||
|
maxStack++
|
||||||
|
val clazz = value.javaClass
|
||||||
|
val c = clazz.name.replace('.', '/')
|
||||||
|
val sigLetter = SIGNATURE_LETTERS[clazz]
|
||||||
|
|
||||||
|
if (sigLetter != null) {
|
||||||
|
when (value) {
|
||||||
|
is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value)
|
||||||
|
is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value)
|
||||||
|
is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value)
|
||||||
|
else -> invokeMethodVisitor.visitLdcInsn(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
visitLoadAnyFromConstants(value, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun visitLoadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run {
|
||||||
|
maxStack += 2
|
||||||
|
visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar)
|
||||||
|
|
||||||
|
if (defaultValue != null) {
|
||||||
|
visitLdcInsn(name)
|
||||||
|
visitLoadFromConstants(defaultValue)
|
||||||
|
|
||||||
|
visitMethodInsn(
|
||||||
|
Opcodes.INVOKEINTERFACE,
|
||||||
|
MAP_CLASS,
|
||||||
|
"getOrDefault",
|
||||||
|
"(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;",
|
||||||
|
true
|
||||||
|
)
|
||||||
|
|
||||||
|
visitCastToT()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
visitLdcInsn(name)
|
||||||
|
visitMethodInsn(
|
||||||
|
Opcodes.INVOKEINTERFACE,
|
||||||
|
MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true
|
||||||
|
)
|
||||||
|
visitCastToT()
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun visitLoadAlgebra() {
|
||||||
|
maxStack++
|
||||||
|
invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar)
|
||||||
|
|
||||||
|
invokeMethodVisitor.visitFieldInsn(
|
||||||
|
Opcodes.GETFIELD,
|
||||||
|
FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;"
|
||||||
|
)
|
||||||
|
|
||||||
|
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun visitAlgebraOperation(
|
||||||
|
owner: String,
|
||||||
|
method: String,
|
||||||
|
descriptor: String,
|
||||||
|
opcode: Int = Opcodes.INVOKEINTERFACE,
|
||||||
|
isInterface: Boolean = true
|
||||||
|
) {
|
||||||
|
maxStack++
|
||||||
|
invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface)
|
||||||
|
visitCastToT()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS)
|
||||||
|
|
||||||
|
internal fun visitStringConstant(string: String) {
|
||||||
|
invokeMethodVisitor.visitLdcInsn(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal companion object {
|
||||||
|
private val SIGNATURE_LETTERS by lazy {
|
||||||
|
mapOf(
|
||||||
|
java.lang.Byte::class.java to "B",
|
||||||
|
java.lang.Short::class.java to "S",
|
||||||
|
java.lang.Integer::class.java to "I",
|
||||||
|
java.lang.Long::class.java to "J",
|
||||||
|
java.lang.Float::class.java to "F",
|
||||||
|
java.lang.Double::class.java to "D"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private val INLINABLE_NUMBERS by lazy { SIGNATURE_LETTERS.keys }
|
||||||
|
|
||||||
|
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =
|
||||||
|
"scientifik/kmath/asm/FunctionalCompiledExpression"
|
||||||
|
|
||||||
|
internal const val MAP_CLASS = "java/util/Map"
|
||||||
|
internal const val OBJECT_CLASS = "java/lang/Object"
|
||||||
|
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||||
|
internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
||||||
|
internal const val STRING_CLASS = "java/lang/String"
|
||||||
|
internal const val NUMBER_CLASS = "java/lang/Number"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,28 @@
|
|||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
import org.objectweb.asm.MethodVisitor
|
||||||
|
import org.objectweb.asm.Opcodes.*
|
||||||
|
|
||||||
|
internal fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) {
|
||||||
|
-1 -> visitInsn(ICONST_M1)
|
||||||
|
0 -> visitInsn(ICONST_0)
|
||||||
|
1 -> visitInsn(ICONST_1)
|
||||||
|
2 -> visitInsn(ICONST_2)
|
||||||
|
3 -> visitInsn(ICONST_3)
|
||||||
|
4 -> visitInsn(ICONST_4)
|
||||||
|
5 -> visitInsn(ICONST_5)
|
||||||
|
else -> visitLdcInsn(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun MethodVisitor.visitLdcOrDConstInsn(value: Double) = when (value) {
|
||||||
|
0.0 -> visitInsn(DCONST_0)
|
||||||
|
1.0 -> visitInsn(DCONST_1)
|
||||||
|
else -> visitLdcInsn(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun MethodVisitor.visitLdcOrFConstInsn(value: Float) = when (value) {
|
||||||
|
0f -> visitInsn(FCONST_0)
|
||||||
|
1f -> visitInsn(FCONST_1)
|
||||||
|
2f -> visitInsn(FCONST_2)
|
||||||
|
else -> visitLdcInsn(value)
|
||||||
|
}
|
@ -0,0 +1,49 @@
|
|||||||
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
|
import org.objectweb.asm.Opcodes
|
||||||
|
import scientifik.kmath.asm.AsmConstantExpression
|
||||||
|
import scientifik.kmath.asm.AsmNode
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
|
|
||||||
|
private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide")
|
||||||
|
|
||||||
|
internal fun <T> hasSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||||
|
val aName = methodNameAdapters[name] ?: name
|
||||||
|
|
||||||
|
context::class.java.methods.find { it.name == aName && it.parameters.size == arity }
|
||||||
|
?: return false
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun <T> AsmGenerationContext<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||||
|
val aName = methodNameAdapters[name] ?: name
|
||||||
|
|
||||||
|
context::class.java.methods.find { it.name == aName && it.parameters.size == arity }
|
||||||
|
?: return false
|
||||||
|
|
||||||
|
val owner = context::class.java.name.replace('.', '/')
|
||||||
|
|
||||||
|
val sig = buildString {
|
||||||
|
append('(')
|
||||||
|
repeat(arity) { append("L${AsmGenerationContext.OBJECT_CLASS};") }
|
||||||
|
append(')')
|
||||||
|
append("L${AsmGenerationContext.OBJECT_CLASS};")
|
||||||
|
}
|
||||||
|
|
||||||
|
visitAlgebraOperation(
|
||||||
|
owner = owner,
|
||||||
|
method = aName,
|
||||||
|
descriptor = sig,
|
||||||
|
opcode = Opcodes.INVOKEVIRTUAL,
|
||||||
|
isInterface = false
|
||||||
|
)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
@PublishedApi
|
||||||
|
internal fun <T> AsmNode<T>.optimize(): AsmNode<T> {
|
||||||
|
val a = tryEvaluate()
|
||||||
|
return if (a == null) this else AsmConstantExpression(a)
|
||||||
|
}
|
@ -1,23 +0,0 @@
|
|||||||
package scientifik.kmath.ast
|
|
||||||
|
|
||||||
import scientifik.kmath.expressions.Expression
|
|
||||||
import scientifik.kmath.operations.Algebra
|
|
||||||
import scientifik.kmath.operations.NumericAlgebra
|
|
||||||
|
|
||||||
//TODO stubs for asm generation
|
|
||||||
|
|
||||||
interface AsmExpression<T>
|
|
||||||
|
|
||||||
interface AsmExpressionAlgebra<T, A : Algebra<T>> : NumericAlgebra<AsmExpression<T>> {
|
|
||||||
val algebra: A
|
|
||||||
}
|
|
||||||
|
|
||||||
fun <T> AsmExpression<T>.compile(): Expression<T> = TODO()
|
|
||||||
|
|
||||||
//TODO add converter for functional expressions
|
|
||||||
|
|
||||||
inline fun <reified T : Any, A : Algebra<T>> A.asm(
|
|
||||||
block: AsmExpressionAlgebra<T, A>.() -> AsmExpression<T>
|
|
||||||
): Expression<T> = TODO()
|
|
||||||
|
|
||||||
inline fun <reified T : Any, A : Algebra<T>> A.asm(ast: MST): Expression<T> = asm { evaluate(ast) }
|
|
18
kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt
Normal file
18
kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package scietifik.kmath.ast
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.asmField
|
||||||
|
import scientifik.kmath.ast.parseMath
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.Complex
|
||||||
|
import scientifik.kmath.operations.ComplexField
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
class AsmTest {
|
||||||
|
@Test
|
||||||
|
fun parsedExpression() {
|
||||||
|
val mst = "2+2*(2+2)".parseMath()
|
||||||
|
val res = ComplexField.asmField(mst)()
|
||||||
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
|
}
|
||||||
|
}
|
@ -1,17 +1,17 @@
|
|||||||
package scietifik.kmath.ast
|
package scietifik.kmath.ast
|
||||||
|
|
||||||
import org.junit.jupiter.api.Assertions.assertEquals
|
|
||||||
import org.junit.jupiter.api.Test
|
|
||||||
import scientifik.kmath.ast.evaluate
|
import scientifik.kmath.ast.evaluate
|
||||||
import scientifik.kmath.ast.parseMath
|
import scientifik.kmath.ast.parseMath
|
||||||
import scientifik.kmath.operations.Complex
|
import scientifik.kmath.operations.Complex
|
||||||
import scientifik.kmath.operations.ComplexField
|
import scientifik.kmath.operations.ComplexField
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
internal class ParserTest{
|
internal class ParserTest {
|
||||||
@Test
|
@Test
|
||||||
fun parsedExpression(){
|
fun parsedExpression() {
|
||||||
val mst = "2+2*(2+2)".parseMath()
|
val mst = "2+2*(2+2)".parseMath()
|
||||||
val res = ComplexField.evaluate(mst)
|
val res = ComplexField.evaluate(mst)
|
||||||
assertEquals(Complex(10.0,0.0), res)
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,75 @@
|
|||||||
|
package scietifik.kmath.ast.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.asmField
|
||||||
|
import scientifik.kmath.asm.asmRing
|
||||||
|
import scientifik.kmath.asm.asmSpace
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.ByteRing
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class TestAsmAlgebras {
|
||||||
|
@Test
|
||||||
|
fun space() {
|
||||||
|
val res = ByteRing.asmSpace {
|
||||||
|
binaryOperation(
|
||||||
|
"+",
|
||||||
|
|
||||||
|
unaryOperation(
|
||||||
|
"+",
|
||||||
|
3.toByte() - (2.toByte() + (multiply(
|
||||||
|
add(const(1), const(1)),
|
||||||
|
2
|
||||||
|
) + 1.toByte()) * 3.toByte() - 1.toByte())
|
||||||
|
),
|
||||||
|
|
||||||
|
number(1)
|
||||||
|
) + variable("x") + zero
|
||||||
|
}("x" to 2.toByte())
|
||||||
|
|
||||||
|
assertEquals(16, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun ring() {
|
||||||
|
val res = ByteRing.asmRing {
|
||||||
|
binaryOperation(
|
||||||
|
"+",
|
||||||
|
|
||||||
|
unaryOperation(
|
||||||
|
"+",
|
||||||
|
(3.toByte() - (2.toByte() + (multiply(
|
||||||
|
add(const(1), const(1)),
|
||||||
|
2
|
||||||
|
) + 1.toByte()))) * 3.0 - 1.toByte()
|
||||||
|
),
|
||||||
|
|
||||||
|
number(1)
|
||||||
|
) * const(2)
|
||||||
|
}()
|
||||||
|
|
||||||
|
assertEquals(24, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun field() {
|
||||||
|
val res = RealField.asmField {
|
||||||
|
divide(binaryOperation(
|
||||||
|
"+",
|
||||||
|
|
||||||
|
unaryOperation(
|
||||||
|
"+",
|
||||||
|
(3.0 - (2.0 + (multiply(
|
||||||
|
add(const(1.0), const(1.0)),
|
||||||
|
2
|
||||||
|
) + 1.0))) * 3 - 1.0
|
||||||
|
),
|
||||||
|
|
||||||
|
number(1)
|
||||||
|
) / 2, const(2.0)) * one
|
||||||
|
}()
|
||||||
|
|
||||||
|
assertEquals(3.0, res)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,21 @@
|
|||||||
|
package scietifik.kmath.ast.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.asm.asmField
|
||||||
|
import scientifik.kmath.expressions.invoke
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class TestAsmExpressions {
|
||||||
|
@Test
|
||||||
|
fun testUnaryOperationInvocation() {
|
||||||
|
val res = RealField.asmField { unaryOperation("+", variable("x")) }("x" to 2.0)
|
||||||
|
assertEquals(2.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testConstProductInvocation() {
|
||||||
|
val res = RealField.asmField { variable("x") * 2 }("x" to 2.0)
|
||||||
|
assertEquals(4.0, res)
|
||||||
|
}
|
||||||
|
}
|
@ -8,4 +8,4 @@ kotlin.sourceSets {
|
|||||||
api(project(":kmath-memory"))
|
api(project(":kmath-memory"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,23 @@
|
|||||||
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.Ring
|
||||||
|
import scientifik.kmath.operations.Space
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a functional expression on this [Space]
|
||||||
|
*/
|
||||||
|
fun <T> Space<T>.buildExpression(block: FunctionalExpressionSpace<T, Space<T>>.() -> Expression<T>): Expression<T> =
|
||||||
|
FunctionalExpressionSpace(this).run(block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a functional expression on this [Ring]
|
||||||
|
*/
|
||||||
|
fun <T> Ring<T>.buildExpression(block: FunctionalExpressionRing<T, Ring<T>>.() -> Expression<T>): Expression<T> =
|
||||||
|
FunctionalExpressionRing(this).run(block)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a functional expression on this [Field]
|
||||||
|
*/
|
||||||
|
fun <T> Field<T>.buildExpression(block: FunctionalExpressionField<T, Field<T>>.() -> Expression<T>): Expression<T> =
|
||||||
|
FunctionalExpressionField(this).run(block)
|
@ -0,0 +1,139 @@
|
|||||||
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
|
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, private val expr: Expression<T>) :
|
||||||
|
Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments))
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class FunctionalBinaryOperation<T>(
|
||||||
|
val context: Algebra<T>,
|
||||||
|
val name: String,
|
||||||
|
val first: Expression<T>,
|
||||||
|
val second: Expression<T>
|
||||||
|
) : Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T =
|
||||||
|
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T =
|
||||||
|
arguments[name] ?: default ?: error("Parameter not found: $name")
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T = value
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class FunctionalConstProductExpression<T>(
|
||||||
|
val context: Space<T>,
|
||||||
|
private val expr: Expression<T>,
|
||||||
|
val const: Number
|
||||||
|
) : Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A context class for [Expression] construction.
|
||||||
|
*/
|
||||||
|
interface FunctionalExpressionAlgebra<T, A : Algebra<T>> : ExpressionAlgebra<T, Expression<T>> {
|
||||||
|
/**
|
||||||
|
* The algebra to provide for Expressions built.
|
||||||
|
*/
|
||||||
|
val algebra: A
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression of constant expression which does not depend on arguments.
|
||||||
|
*/
|
||||||
|
override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression to access a variable.
|
||||||
|
*/
|
||||||
|
override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression of dynamic call of binary operation [operation] on [left] and [right].
|
||||||
|
*/
|
||||||
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
FunctionalBinaryOperation(algebra, operation, left, right)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression of dynamic call of unary operation with name [operation] on [arg].
|
||||||
|
*/
|
||||||
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
FunctionalUnaryOperation(algebra, operation, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A context class for [Expression] construction for [Space] algebras.
|
||||||
|
*/
|
||||||
|
open class FunctionalExpressionSpace<T, A>(override val algebra: A) : FunctionalExpressionAlgebra<T, A>,
|
||||||
|
Space<Expression<T>> where A : Space<T> {
|
||||||
|
override val zero: Expression<T>
|
||||||
|
get() = const(algebra.zero)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression of addition of two another expressions.
|
||||||
|
*/
|
||||||
|
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
|
FunctionalBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression of multiplication of expression by number.
|
||||||
|
*/
|
||||||
|
override fun multiply(a: Expression<T>, k: Number): Expression<T> =
|
||||||
|
FunctionalConstProductExpression(algebra, a, k)
|
||||||
|
|
||||||
|
operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
||||||
|
operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
|
||||||
|
operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
|
||||||
|
operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionAlgebra>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionAlgebra>.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
open class FunctionalExpressionRing<T, A>(override val algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
|
||||||
|
Ring<Expression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
|
||||||
|
override val one: Expression<T>
|
||||||
|
get() = const(algebra.one)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds an Expression of multiplication of two expressions.
|
||||||
|
*/
|
||||||
|
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
|
FunctionalBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
|
operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||||
|
operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionSpace>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionSpace>.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
open class FunctionalExpressionField<T, A>(override val algebra: A) :
|
||||||
|
FunctionalExpressionRing<T, A>(algebra),
|
||||||
|
Field<Expression<T>> where A : Field<T>, A : NumericAlgebra<T> {
|
||||||
|
/**
|
||||||
|
* Builds an Expression of division an expression by another one.
|
||||||
|
*/
|
||||||
|
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
|
FunctionalBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
|
operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||||
|
operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionRing>.unaryOperation(operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
super<FunctionalExpressionRing>.binaryOperation(operation, left, right)
|
||||||
|
}
|
@ -1,94 +0,0 @@
|
|||||||
package scientifik.kmath.expressions
|
|
||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
|
||||||
import scientifik.kmath.operations.Ring
|
|
||||||
import scientifik.kmath.operations.Space
|
|
||||||
|
|
||||||
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T =
|
|
||||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class ConstantExpression<T>(val value: T) : Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = value
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class SumExpression<T>(
|
|
||||||
val context: Space<T>,
|
|
||||||
val first: Expression<T>,
|
|
||||||
val second: Expression<T>
|
|
||||||
) : Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class ProductExpression<T>(
|
|
||||||
val context: Ring<T>,
|
|
||||||
val first: Expression<T>,
|
|
||||||
val second: Expression<T>
|
|
||||||
) : Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T =
|
|
||||||
context.multiply(first.invoke(arguments), second.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
|
|
||||||
Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
|
|
||||||
Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
open class FunctionalExpressionSpace<T>(
|
|
||||||
val space: Space<T>
|
|
||||||
) : Space<Expression<T>>, ExpressionAlgebra<T, Expression<T>> {
|
|
||||||
|
|
||||||
override val zero: Expression<T> = ConstantExpression(space.zero)
|
|
||||||
|
|
||||||
override fun const(value: T): Expression<T> = ConstantExpression(value)
|
|
||||||
|
|
||||||
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
|
|
||||||
|
|
||||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
|
|
||||||
|
|
||||||
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
|
|
||||||
|
|
||||||
|
|
||||||
operator fun Expression<T>.plus(arg: T) = this + const(arg)
|
|
||||||
operator fun Expression<T>.minus(arg: T) = this - const(arg)
|
|
||||||
|
|
||||||
operator fun T.plus(arg: Expression<T>) = arg + this
|
|
||||||
operator fun T.minus(arg: Expression<T>) = arg - this
|
|
||||||
}
|
|
||||||
|
|
||||||
open class FunctionalExpressionField<T>(
|
|
||||||
val field: Field<T>
|
|
||||||
) : Field<Expression<T>>, ExpressionAlgebra<T, Expression<T>>, FunctionalExpressionSpace<T>(field) {
|
|
||||||
|
|
||||||
override val one: Expression<T> = ConstantExpression(this.field.one)
|
|
||||||
|
|
||||||
fun const(value: Double): Expression<T> = const(field.run { one * value })
|
|
||||||
|
|
||||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
|
||||||
|
|
||||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
|
|
||||||
|
|
||||||
operator fun Expression<T>.times(arg: T) = this * const(arg)
|
|
||||||
operator fun Expression<T>.div(arg: T) = this / const(arg)
|
|
||||||
|
|
||||||
operator fun T.times(arg: Expression<T>) = arg * this
|
|
||||||
operator fun T.div(arg: Expression<T>) = arg / this
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a functional expression on this [Space]
|
|
||||||
*/
|
|
||||||
fun <T> Space<T>.buildExpression(block: FunctionalExpressionSpace<T>.() -> Expression<T>): Expression<T> =
|
|
||||||
FunctionalExpressionSpace(this).run(block)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a functional expression on this [Field]
|
|
||||||
*/
|
|
||||||
fun <T> Field<T>.buildExpression(block: FunctionalExpressionField<T>.() -> Expression<T>): Expression<T> =
|
|
||||||
FunctionalExpressionField(this).run(block)
|
|
Loading…
Reference in New Issue
Block a user