Completely rework Expression API to expose direct unaryOperation and binaryOperation, improve ASM API accordingly
This commit is contained in:
parent
33a519c10b
commit
1b6a0a13d8
@ -1,38 +0,0 @@
|
|||||||
package scientifik.kmath.expressions
|
|
||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
|
||||||
import scientifik.kmath.operations.Space
|
|
||||||
import scientifik.kmath.operations.invoke
|
|
||||||
|
|
||||||
open class AsmExpressionSpace<T>(private val space: Space<T>) :
|
|
||||||
Space<AsmExpression<T>>,
|
|
||||||
ExpressionContext<T, AsmExpression<T>> {
|
|
||||||
override val zero: AsmExpression<T> = AsmConstantExpression(space.zero)
|
|
||||||
override fun const(value: T): AsmExpression<T> = AsmConstantExpression(value)
|
|
||||||
override fun variable(name: String, default: T?): AsmExpression<T> = AsmVariableExpression(name, default)
|
|
||||||
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> = AsmSumExpression(space, a, b)
|
|
||||||
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(space, a, k)
|
|
||||||
operator fun AsmExpression<T>.plus(arg: T): AsmExpression<T> = this + const(arg)
|
|
||||||
operator fun AsmExpression<T>.minus(arg: T): AsmExpression<T> = this - const(arg)
|
|
||||||
operator fun T.plus(arg: AsmExpression<T>): AsmExpression<T> = arg + this
|
|
||||||
operator fun T.minus(arg: AsmExpression<T>): AsmExpression<T> = arg - this
|
|
||||||
}
|
|
||||||
|
|
||||||
class AsmExpressionField<T>(private val field: Field<T>) :
|
|
||||||
ExpressionContext<T, AsmExpression<T>>,
|
|
||||||
Field<AsmExpression<T>>,
|
|
||||||
AsmExpressionSpace<T>(field) {
|
|
||||||
override val one: AsmExpression<T>
|
|
||||||
get() = const(this.field.one)
|
|
||||||
|
|
||||||
fun number(value: Number): AsmExpression<T> = const(field { one * value })
|
|
||||||
|
|
||||||
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
|
||||||
AsmProductExpression(field, a, b)
|
|
||||||
|
|
||||||
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> = AsmDivExpression(field, a, b)
|
|
||||||
operator fun AsmExpression<T>.times(arg: T): AsmExpression<T> = this * const(arg)
|
|
||||||
operator fun AsmExpression<T>.div(arg: T): AsmExpression<T> = this / const(arg)
|
|
||||||
operator fun T.times(arg: AsmExpression<T>): AsmExpression<T> = arg * this
|
|
||||||
operator fun T.div(arg: AsmExpression<T>): AsmExpression<T> = arg / this
|
|
||||||
}
|
|
@ -1,123 +0,0 @@
|
|||||||
package scientifik.kmath.expressions
|
|
||||||
|
|
||||||
import scientifik.kmath.operations.*
|
|
||||||
|
|
||||||
abstract class AsmCompiledExpression<T> internal constructor(
|
|
||||||
@JvmField protected val algebra: Algebra<T>,
|
|
||||||
@JvmField protected val constants: MutableList<out Any>
|
|
||||||
) : Expression<T> {
|
|
||||||
abstract override fun invoke(arguments: Map<String, T>): T
|
|
||||||
}
|
|
||||||
|
|
||||||
interface AsmExpression<T> {
|
|
||||||
fun tryEvaluate(): T? = null
|
|
||||||
fun invoke(gen: AsmGenerationContext<T>)
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class AsmVariableExpression<T>(val name: String, val default: T? = null) :
|
|
||||||
AsmExpression<T> {
|
|
||||||
override fun invoke(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromVariables(name, default)
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class AsmConstantExpression<T>(val value: T) :
|
|
||||||
AsmExpression<T> {
|
|
||||||
override fun tryEvaluate(): T = value
|
|
||||||
override fun invoke(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromConstants(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class AsmSumExpression<T>(
|
|
||||||
private val algebra: SpaceOperations<T>,
|
|
||||||
first: AsmExpression<T>,
|
|
||||||
second: AsmExpression<T>
|
|
||||||
) : AsmExpression<T> {
|
|
||||||
private val first: AsmExpression<T> = first.optimize()
|
|
||||||
private val second: AsmExpression<T> = second.optimize()
|
|
||||||
|
|
||||||
override fun tryEvaluate(): T? = algebra {
|
|
||||||
(first.tryEvaluate() ?: return@algebra null) + (second.tryEvaluate() ?: return@algebra null)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
|
||||||
gen.visitLoadAlgebra()
|
|
||||||
first.invoke(gen)
|
|
||||||
second.invoke(gen)
|
|
||||||
|
|
||||||
gen.visitAlgebraOperation(
|
|
||||||
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
|
|
||||||
method = "add",
|
|
||||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class AsmProductExpression<T>(
|
|
||||||
private val algebra: RingOperations<T>,
|
|
||||||
first: AsmExpression<T>,
|
|
||||||
second: AsmExpression<T>
|
|
||||||
) : AsmExpression<T> {
|
|
||||||
private val first: AsmExpression<T> = first.optimize()
|
|
||||||
private val second: AsmExpression<T> = second.optimize()
|
|
||||||
|
|
||||||
override fun tryEvaluate(): T? = algebra {
|
|
||||||
(first.tryEvaluate() ?: return@algebra null) * (second.tryEvaluate() ?: return@algebra null)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
|
||||||
gen.visitLoadAlgebra()
|
|
||||||
first.invoke(gen)
|
|
||||||
second.invoke(gen)
|
|
||||||
|
|
||||||
gen.visitAlgebraOperation(
|
|
||||||
owner = AsmGenerationContext.RING_OPERATIONS_CLASS,
|
|
||||||
method = "multiply",
|
|
||||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class AsmConstProductExpression<T>(
|
|
||||||
private val algebra: SpaceOperations<T>,
|
|
||||||
expr: AsmExpression<T>,
|
|
||||||
private val const: Number
|
|
||||||
) : AsmExpression<T> {
|
|
||||||
private val expr: AsmExpression<T> = expr.optimize()
|
|
||||||
|
|
||||||
override fun tryEvaluate(): T? = algebra { (expr.tryEvaluate() ?: return@algebra null) * const }
|
|
||||||
|
|
||||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
|
||||||
gen.visitLoadAlgebra()
|
|
||||||
gen.visitNumberConstant(const)
|
|
||||||
expr.invoke(gen)
|
|
||||||
|
|
||||||
gen.visitAlgebraOperation(
|
|
||||||
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
|
|
||||||
method = "multiply",
|
|
||||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class AsmDivExpression<T>(
|
|
||||||
private val algebra: FieldOperations<T>,
|
|
||||||
expr: AsmExpression<T>,
|
|
||||||
second: AsmExpression<T>
|
|
||||||
) : AsmExpression<T> {
|
|
||||||
private val expr: AsmExpression<T> = expr.optimize()
|
|
||||||
private val second: AsmExpression<T> = second.optimize()
|
|
||||||
|
|
||||||
override fun tryEvaluate(): T? = algebra {
|
|
||||||
(expr.tryEvaluate() ?: return@algebra null) / (second.tryEvaluate() ?: return@algebra null)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
|
||||||
gen.visitLoadAlgebra()
|
|
||||||
expr.invoke(gen)
|
|
||||||
second.invoke(gen)
|
|
||||||
|
|
||||||
gen.visitAlgebraOperation(
|
|
||||||
owner = AsmGenerationContext.FIELD_OPERATIONS_CLASS,
|
|
||||||
method = "divide",
|
|
||||||
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,6 +0,0 @@
|
|||||||
package scientifik.kmath.expressions
|
|
||||||
|
|
||||||
fun <T> AsmExpression<T>.optimize(): AsmExpression<T> {
|
|
||||||
val a = tryEvaluate()
|
|
||||||
return if (a == null) this else AsmConstantExpression(a)
|
|
||||||
}
|
|
@ -0,0 +1,74 @@
|
|||||||
|
package scientifik.kmath.expressions.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.expressions.ExpressionContext
|
||||||
|
import scientifik.kmath.operations.*
|
||||||
|
|
||||||
|
open class AsmExpressionAlgebra<T>(val algebra: Algebra<T>) :
|
||||||
|
Algebra<AsmExpression<T>>,
|
||||||
|
ExpressionContext<T, AsmExpression<T>> {
|
||||||
|
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
|
||||||
|
AsmUnaryOperation(algebra, operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
|
||||||
|
AsmBinaryOperation(algebra, operation, left, right)
|
||||||
|
|
||||||
|
override fun const(value: T): AsmExpression<T> = AsmConstantExpression(value)
|
||||||
|
override fun variable(name: String, default: T?): AsmExpression<T> = AsmVariableExpression(name, default)
|
||||||
|
}
|
||||||
|
|
||||||
|
open class AsmExpressionSpace<T>(
|
||||||
|
val space: Space<T>
|
||||||
|
) : AsmExpressionAlgebra<T>(space), Space<AsmExpression<T>> {
|
||||||
|
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
|
||||||
|
AsmUnaryOperation(algebra, operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
|
||||||
|
AsmBinaryOperation(algebra, operation, left, right)
|
||||||
|
|
||||||
|
override val zero: AsmExpression<T> = AsmConstantExpression(space.zero)
|
||||||
|
|
||||||
|
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||||
|
AsmBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(space, a, k)
|
||||||
|
operator fun AsmExpression<T>.plus(arg: T) = this + const(arg)
|
||||||
|
operator fun AsmExpression<T>.minus(arg: T) = this - const(arg)
|
||||||
|
operator fun T.plus(arg: AsmExpression<T>) = arg + this
|
||||||
|
operator fun T.minus(arg: AsmExpression<T>) = arg - this
|
||||||
|
}
|
||||||
|
|
||||||
|
open class AsmExpressionRing<T>(private val ring: Ring<T>) : AsmExpressionSpace<T>(ring), Ring<AsmExpression<T>> {
|
||||||
|
override val one: AsmExpression<T>
|
||||||
|
get() = const(this.ring.one)
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
|
||||||
|
AsmUnaryOperation(algebra, operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
|
||||||
|
AsmBinaryOperation(algebra, operation, left, right)
|
||||||
|
|
||||||
|
fun number(value: Number): AsmExpression<T> = const(ring { one * value })
|
||||||
|
|
||||||
|
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||||
|
AsmBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
|
operator fun AsmExpression<T>.times(arg: T) = this * const(arg)
|
||||||
|
operator fun T.times(arg: AsmExpression<T>) = arg * this
|
||||||
|
}
|
||||||
|
|
||||||
|
open class AsmExpressionField<T>(private val field: Field<T>) :
|
||||||
|
AsmExpressionRing<T>(field),
|
||||||
|
Field<AsmExpression<T>> {
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
|
||||||
|
AsmUnaryOperation(algebra, operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
|
||||||
|
AsmBinaryOperation(algebra, operation, left, right)
|
||||||
|
|
||||||
|
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||||
|
AsmBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
|
operator fun AsmExpression<T>.div(arg: T) = this / const(arg)
|
||||||
|
operator fun T.div(arg: AsmExpression<T>) = arg / this
|
||||||
|
}
|
@ -0,0 +1,106 @@
|
|||||||
|
package scientifik.kmath.expressions.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
|
import scientifik.kmath.operations.Algebra
|
||||||
|
import scientifik.kmath.operations.Space
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
|
|
||||||
|
interface AsmExpression<T> {
|
||||||
|
fun tryEvaluate(): T? = null
|
||||||
|
fun invoke(gen: AsmGenerationContext<T>)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class AsmUnaryOperation<T>(private val context: Algebra<T>, private val name: String, expr: AsmExpression<T>) :
|
||||||
|
AsmExpression<T> {
|
||||||
|
private val expr: AsmExpression<T> = expr.optimize()
|
||||||
|
|
||||||
|
override fun tryEvaluate(): T? = context {
|
||||||
|
unaryOperation(
|
||||||
|
name,
|
||||||
|
expr.tryEvaluate() ?: return@context null
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||||
|
gen.visitLoadAlgebra()
|
||||||
|
gen.visitStringConstant(name)
|
||||||
|
expr.invoke(gen)
|
||||||
|
|
||||||
|
gen.visitAlgebraOperation(
|
||||||
|
owner = AsmGenerationContext.ALGEBRA_CLASS,
|
||||||
|
method = "unaryOperation",
|
||||||
|
descriptor = "(L${AsmGenerationContext.STRING_CLASS};" +
|
||||||
|
"L${AsmGenerationContext.OBJECT_CLASS};)" +
|
||||||
|
"L${AsmGenerationContext.OBJECT_CLASS};"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class AsmBinaryOperation<T>(
|
||||||
|
private val context: Algebra<T>,
|
||||||
|
private val name: String,
|
||||||
|
first: AsmExpression<T>,
|
||||||
|
second: AsmExpression<T>
|
||||||
|
) : AsmExpression<T> {
|
||||||
|
private val first: AsmExpression<T> = first.optimize()
|
||||||
|
private val second: AsmExpression<T> = second.optimize()
|
||||||
|
|
||||||
|
override fun tryEvaluate(): T? = context {
|
||||||
|
binaryOperation(
|
||||||
|
name,
|
||||||
|
first.tryEvaluate() ?: return@context null,
|
||||||
|
second.tryEvaluate() ?: return@context null
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||||
|
gen.visitLoadAlgebra()
|
||||||
|
gen.visitStringConstant(name)
|
||||||
|
first.invoke(gen)
|
||||||
|
second.invoke(gen)
|
||||||
|
|
||||||
|
gen.visitAlgebraOperation(
|
||||||
|
owner = AsmGenerationContext.ALGEBRA_CLASS,
|
||||||
|
method = "binaryOperation",
|
||||||
|
descriptor = "(L${AsmGenerationContext.STRING_CLASS};" +
|
||||||
|
"L${AsmGenerationContext.OBJECT_CLASS};" +
|
||||||
|
"L${AsmGenerationContext.OBJECT_CLASS};)" +
|
||||||
|
"L${AsmGenerationContext.OBJECT_CLASS};"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class AsmVariableExpression<T>(private val name: String, private val default: T? = null) : AsmExpression<T> {
|
||||||
|
override fun invoke(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromVariables(name, default)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class AsmConstantExpression<T>(private val value: T) : AsmExpression<T> {
|
||||||
|
override fun tryEvaluate(): T = value
|
||||||
|
override fun invoke(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromConstants(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class AsmConstProductExpression<T>(private val context: Space<T>, expr: AsmExpression<T>, private val const: Number) :
|
||||||
|
AsmExpression<T> {
|
||||||
|
private val expr: AsmExpression<T> = expr.optimize()
|
||||||
|
|
||||||
|
override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const }
|
||||||
|
|
||||||
|
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||||
|
gen.visitLoadAlgebra()
|
||||||
|
gen.visitNumberConstant(const)
|
||||||
|
expr.invoke(gen)
|
||||||
|
|
||||||
|
gen.visitAlgebraOperation(
|
||||||
|
owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS,
|
||||||
|
method = "multiply",
|
||||||
|
descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal abstract class FunctionalCompiledExpression<T> internal constructor(
|
||||||
|
@JvmField protected val algebra: Algebra<T>,
|
||||||
|
@JvmField protected val constants: MutableList<out Any>
|
||||||
|
) : Expression<T> {
|
||||||
|
abstract override fun invoke(arguments: Map<String, T>): T
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.expressions.asm
|
||||||
|
|
||||||
import org.objectweb.asm.ClassWriter
|
import org.objectweb.asm.ClassWriter
|
||||||
import org.objectweb.asm.Label
|
import org.objectweb.asm.Label
|
||||||
@ -33,15 +33,15 @@ class AsmGenerationContext<T>(
|
|||||||
private val invokeMethodVisitor: MethodVisitor
|
private val invokeMethodVisitor: MethodVisitor
|
||||||
private val invokeL0: Label
|
private val invokeL0: Label
|
||||||
private lateinit var invokeL1: Label
|
private lateinit var invokeL1: Label
|
||||||
private var generatedInstance: AsmCompiledExpression<T>? = null
|
private var generatedInstance: FunctionalCompiledExpression<T>? = null
|
||||||
|
|
||||||
init {
|
init {
|
||||||
asmCompiledClassWriter.visit(
|
asmCompiledClassWriter.visit(
|
||||||
Opcodes.V1_8,
|
Opcodes.V1_8,
|
||||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
|
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
|
||||||
slashesClassName,
|
slashesClassName,
|
||||||
"L$ASM_COMPILED_EXPRESSION_CLASS<L$T_CLASS;>;",
|
"L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS<L$T_CLASS;>;",
|
||||||
ASM_COMPILED_EXPRESSION_CLASS,
|
FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
|
||||||
arrayOf()
|
arrayOf()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ class AsmGenerationContext<T>(
|
|||||||
|
|
||||||
visitMethodInsn(
|
visitMethodInsn(
|
||||||
Opcodes.INVOKESPECIAL,
|
Opcodes.INVOKESPECIAL,
|
||||||
ASM_COMPILED_EXPRESSION_CLASS,
|
FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
|
||||||
"<init>",
|
"<init>",
|
||||||
"(L$ALGEBRA_CLASS;L$LIST_CLASS;)V",
|
"(L$ALGEBRA_CLASS;L$LIST_CLASS;)V",
|
||||||
false
|
false
|
||||||
@ -103,7 +103,7 @@ class AsmGenerationContext<T>(
|
|||||||
|
|
||||||
@PublishedApi
|
@PublishedApi
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
internal fun generate(): AsmCompiledExpression<T> {
|
internal fun generate(): FunctionalCompiledExpression<T> {
|
||||||
generatedInstance?.let { return it }
|
generatedInstance?.let { return it }
|
||||||
|
|
||||||
invokeMethodVisitor.run {
|
invokeMethodVisitor.run {
|
||||||
@ -170,7 +170,7 @@ class AsmGenerationContext<T>(
|
|||||||
.defineClass(className, asmCompiledClassWriter.toByteArray())
|
.defineClass(className, asmCompiledClassWriter.toByteArray())
|
||||||
.constructors
|
.constructors
|
||||||
.first()
|
.first()
|
||||||
.newInstance(algebra, constants) as AsmCompiledExpression<T>
|
.newInstance(algebra, constants) as FunctionalCompiledExpression<T>
|
||||||
|
|
||||||
generatedInstance = new
|
generatedInstance = new
|
||||||
return new
|
return new
|
||||||
@ -245,7 +245,7 @@ class AsmGenerationContext<T>(
|
|||||||
|
|
||||||
invokeMethodVisitor.visitFieldInsn(
|
invokeMethodVisitor.visitFieldInsn(
|
||||||
Opcodes.GETFIELD,
|
Opcodes.GETFIELD,
|
||||||
ASM_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;"
|
FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;"
|
||||||
)
|
)
|
||||||
|
|
||||||
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS)
|
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS)
|
||||||
@ -259,6 +259,10 @@ class AsmGenerationContext<T>(
|
|||||||
|
|
||||||
private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS)
|
private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS)
|
||||||
|
|
||||||
|
internal fun visitStringConstant(string: String) {
|
||||||
|
invokeMethodVisitor.visitLdcInsn(string)
|
||||||
|
}
|
||||||
|
|
||||||
internal companion object {
|
internal companion object {
|
||||||
private val SIGNATURE_LETTERS = mapOf(
|
private val SIGNATURE_LETTERS = mapOf(
|
||||||
java.lang.Byte::class.java to "B",
|
java.lang.Byte::class.java to "B",
|
||||||
@ -269,15 +273,13 @@ class AsmGenerationContext<T>(
|
|||||||
java.lang.Double::class.java to "D"
|
java.lang.Double::class.java to "D"
|
||||||
)
|
)
|
||||||
|
|
||||||
internal const val ASM_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/AsmCompiledExpression"
|
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/asm/FunctionalCompiledExpression"
|
||||||
internal const val LIST_CLASS = "java/util/List"
|
internal const val LIST_CLASS = "java/util/List"
|
||||||
internal const val MAP_CLASS = "java/util/Map"
|
internal const val MAP_CLASS = "java/util/Map"
|
||||||
internal const val OBJECT_CLASS = "java/lang/Object"
|
internal const val OBJECT_CLASS = "java/lang/Object"
|
||||||
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||||
internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations"
|
||||||
internal const val STRING_CLASS = "java/lang/String"
|
internal const val STRING_CLASS = "java/lang/String"
|
||||||
internal const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations"
|
|
||||||
internal const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations"
|
|
||||||
internal const val NUMBER_CLASS = "java/lang/Number"
|
internal const val NUMBER_CLASS = "java/lang/Number"
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,7 +1,9 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.expressions.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.Ring
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
|
|
||||||
@PublishedApi
|
@PublishedApi
|
||||||
@ -25,11 +27,21 @@ inline fun <reified T, I> asm(i: I, algebra: Algebra<T>, block: I.() -> AsmExpre
|
|||||||
return ctx.generate()
|
return ctx.generate()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline fun <reified T> asmAlgebra(
|
||||||
|
algebra: Algebra<T>,
|
||||||
|
block: AsmExpressionAlgebra<T>.() -> AsmExpression<T>
|
||||||
|
): Expression<T> = asm(AsmExpressionAlgebra(algebra), algebra, block)
|
||||||
|
|
||||||
inline fun <reified T> asmSpace(
|
inline fun <reified T> asmSpace(
|
||||||
algebra: Space<T>,
|
algebra: Space<T>,
|
||||||
block: AsmExpressionSpace<T>.() -> AsmExpression<T>
|
block: AsmExpressionSpace<T>.() -> AsmExpression<T>
|
||||||
): Expression<T> = asm(AsmExpressionSpace(algebra), algebra, block)
|
): Expression<T> = asm(AsmExpressionSpace(algebra), algebra, block)
|
||||||
|
|
||||||
|
inline fun <reified T> asmRing(
|
||||||
|
algebra: Ring<T>,
|
||||||
|
block: AsmExpressionRing<T>.() -> AsmExpression<T>
|
||||||
|
): Expression<T> = asm(AsmExpressionRing(algebra), algebra, block)
|
||||||
|
|
||||||
inline fun <reified T> asmField(
|
inline fun <reified T> asmField(
|
||||||
algebra: Field<T>,
|
algebra: Field<T>,
|
||||||
block: AsmExpressionField<T>.() -> AsmExpression<T>
|
block: AsmExpressionField<T>.() -> AsmExpression<T>
|
@ -1,4 +1,4 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.expressions.asm
|
||||||
|
|
||||||
import org.objectweb.asm.MethodVisitor
|
import org.objectweb.asm.MethodVisitor
|
||||||
import org.objectweb.asm.Opcodes.*
|
import org.objectweb.asm.Opcodes.*
|
@ -0,0 +1,9 @@
|
|||||||
|
package scientifik.kmath.expressions.asm
|
||||||
|
|
||||||
|
import scientifik.kmath.expressions.asm.AsmConstantExpression
|
||||||
|
import scientifik.kmath.expressions.asm.AsmExpression
|
||||||
|
|
||||||
|
fun <T> AsmExpression<T>.optimize(): AsmExpression<T> {
|
||||||
|
val a = tryEvaluate()
|
||||||
|
return if (a == null) this else AsmConstantExpression(a)
|
||||||
|
}
|
@ -1,103 +1,40 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.expressions.asm.AsmExpression
|
||||||
|
import scientifik.kmath.expressions.asm.AsmExpressionField
|
||||||
|
import scientifik.kmath.expressions.asm.asmField
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class AsmTest {
|
class AsmTest {
|
||||||
private fun <T> testExpressionValue(
|
private fun testDoubleExpression(
|
||||||
expectedValue: T,
|
expected: Double?,
|
||||||
expr: AsmExpression<T>,
|
arguments: Map<String, Double> = emptyMap(),
|
||||||
arguments: Map<String, T>,
|
block: AsmExpressionField<Double>.() -> AsmExpression<Double>
|
||||||
algebra: Algebra<T>,
|
): Unit = assertEquals(expected = expected, actual = asmField(RealField, block)(arguments))
|
||||||
clazz: Class<*>
|
|
||||||
): Unit = assertEquals(
|
|
||||||
expectedValue, AsmGenerationContext(clazz, algebra, "TestAsmCompiled")
|
|
||||||
.also(expr::invoke)
|
|
||||||
.generate()
|
|
||||||
.invoke(arguments)
|
|
||||||
)
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
private fun testDoubleExpressionValue(
|
|
||||||
expectedValue: Double,
|
|
||||||
expr: AsmExpression<Double>,
|
|
||||||
arguments: Map<String, Double>,
|
|
||||||
algebra: Algebra<Double> = RealField,
|
|
||||||
clazz: Class<Double> = java.lang.Double::class.java as Class<Double>
|
|
||||||
): Unit = testExpressionValue(expectedValue, expr, arguments, algebra, clazz)
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSum() = testDoubleExpressionValue(
|
fun testConstantsSum() = testDoubleExpression(16.0) { const(8.0) + 8.0 }
|
||||||
25.0,
|
|
||||||
AsmSumExpression(RealField, AsmConstantExpression(1.0), AsmVariableExpression("x")),
|
|
||||||
mapOf("x" to 24.0)
|
|
||||||
)
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testConst(): Unit = testDoubleExpressionValue(
|
fun testVarsSum() = testDoubleExpression(1000.0, mapOf("x" to 500.0)) { variable("x") + 500.0 }
|
||||||
123.0,
|
|
||||||
AsmConstantExpression(123.0),
|
|
||||||
mapOf()
|
|
||||||
)
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testDiv(): Unit = testDoubleExpressionValue(
|
fun testProduct() = testDoubleExpression(24.0) { const(4.0) * const(6.0) }
|
||||||
0.5,
|
|
||||||
AsmDivExpression(RealField, AsmConstantExpression(1.0), AsmConstantExpression(2.0)),
|
|
||||||
mapOf()
|
|
||||||
)
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testProduct(): Unit = testDoubleExpressionValue(
|
fun testConstantProduct() = testDoubleExpression(984.0) { const(8.0) * 123 }
|
||||||
25.0,
|
|
||||||
AsmProductExpression(RealField,AsmVariableExpression("x"), AsmVariableExpression("x")),
|
|
||||||
mapOf("x" to 5.0)
|
|
||||||
)
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testCProduct(): Unit = testDoubleExpressionValue(
|
fun testSubtraction() = testDoubleExpression(2.0) { const(4.0) - 2.0 }
|
||||||
25.0,
|
|
||||||
AsmConstProductExpression(RealField,AsmVariableExpression("x"), 5.0),
|
|
||||||
mapOf("x" to 5.0)
|
|
||||||
)
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testCProductWithOtherTypeNumber(): Unit = testDoubleExpressionValue(
|
fun testDivision() = testDoubleExpression(64.0) { const(128.0) / 2 }
|
||||||
25.0,
|
|
||||||
AsmConstProductExpression(RealField,AsmVariableExpression("x"), 5f),
|
|
||||||
mapOf("x" to 5.0)
|
|
||||||
)
|
|
||||||
|
|
||||||
object CustomZero : Number() {
|
|
||||||
override fun toByte(): Byte = 0
|
|
||||||
override fun toChar(): Char = 0.toChar()
|
|
||||||
override fun toDouble(): Double = 0.0
|
|
||||||
override fun toFloat(): Float = 0f
|
|
||||||
override fun toInt(): Int = 0
|
|
||||||
override fun toLong(): Long = 0L
|
|
||||||
override fun toShort(): Short = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testCProductWithCustomTypeNumber(): Unit = testDoubleExpressionValue(
|
fun testDirectCall() = testDoubleExpression(4096.0) { binaryOperation("*", const(64.0), const(64.0)) }
|
||||||
0.0,
|
|
||||||
AsmConstProductExpression(RealField,AsmVariableExpression("x"), CustomZero),
|
|
||||||
mapOf("x" to 5.0)
|
|
||||||
)
|
|
||||||
|
|
||||||
@Test
|
// @Test
|
||||||
fun testVar(): Unit = testDoubleExpressionValue(
|
// fun testSine() = testDoubleExpression(0.0) { unaryOperation("sin", const(PI)) }
|
||||||
10000.0,
|
|
||||||
AsmVariableExpression("x"),
|
|
||||||
mapOf("x" to 10000.0)
|
|
||||||
)
|
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testVarWithDefault(): Unit = testDoubleExpressionValue(
|
|
||||||
10000.0,
|
|
||||||
AsmVariableExpression("x", 10000.0),
|
|
||||||
mapOf()
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
@ -1,80 +1,102 @@
|
|||||||
package scientifik.kmath.expressions
|
package scientifik.kmath.expressions
|
||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.*
|
||||||
import scientifik.kmath.operations.Ring
|
|
||||||
import scientifik.kmath.operations.Space
|
|
||||||
|
|
||||||
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
internal class FunctionalUnaryOperation<T>(val context: Algebra<T>, val name: String, val expr: Expression<T>) :
|
||||||
|
Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T = context.unaryOperation(name, expr.invoke(arguments))
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class FunctionalBinaryOperation<T>(
|
||||||
|
val context: Algebra<T>,
|
||||||
|
val name: String,
|
||||||
|
val first: Expression<T>,
|
||||||
|
val second: Expression<T>
|
||||||
|
) : Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<String, T>): T =
|
||||||
|
context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments))
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class FunctionalVariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T =
|
override fun invoke(arguments: Map<String, T>): T =
|
||||||
arguments[name] ?: default ?: error("Parameter not found: $name")
|
arguments[name] ?: default ?: error("Parameter not found: $name")
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class ConstantExpression<T>(val value: T) : Expression<T> {
|
internal class FunctionalConstantExpression<T>(val value: T) : Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T = value
|
override fun invoke(arguments: Map<String, T>): T = value
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class SumExpression<T>(
|
internal class FunctionalConstProductExpression<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
|
||||||
val context: Space<T>,
|
|
||||||
val first: Expression<T>,
|
|
||||||
val second: Expression<T>
|
|
||||||
) : Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.add(first.invoke(arguments), second.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class ProductExpression<T>(val context: Ring<T>, val first: Expression<T>, val second: Expression<T>) :
|
|
||||||
Expression<T> {
|
|
||||||
override fun invoke(arguments: Map<String, T>): T =
|
|
||||||
context.multiply(first.invoke(arguments), second.invoke(arguments))
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class ConstProductExpession<T>(val context: Space<T>, val expr: Expression<T>, val const: Number) :
|
|
||||||
Expression<T> {
|
Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
override fun invoke(arguments: Map<String, T>): T = context.multiply(expr.invoke(arguments), const)
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class DivExpession<T>(val context: Field<T>, val expr: Expression<T>, val second: Expression<T>) :
|
open class FunctionalExpressionAlgebra<T>(val algebra: Algebra<T>) :
|
||||||
Expression<T> {
|
Algebra<Expression<T>>,
|
||||||
override fun invoke(arguments: Map<String, T>): T = context.divide(expr.invoke(arguments), second.invoke(arguments))
|
ExpressionContext<T, Expression<T>> {
|
||||||
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
FunctionalUnaryOperation(algebra, operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
FunctionalBinaryOperation(algebra, operation, left, right)
|
||||||
|
|
||||||
|
override fun const(value: T): Expression<T> = FunctionalConstantExpression(value)
|
||||||
|
override fun variable(name: String, default: T?): Expression<T> = FunctionalVariableExpression(name, default)
|
||||||
}
|
}
|
||||||
|
|
||||||
open class FunctionalExpressionSpace<T>(
|
open class FunctionalExpressionSpace<T>(val space: Space<T>) :
|
||||||
val space: Space<T>
|
FunctionalExpressionAlgebra<T>(space),
|
||||||
) : Space<Expression<T>>, ExpressionContext<T, Expression<T>> {
|
Space<Expression<T>> {
|
||||||
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
FunctionalUnaryOperation(algebra, operation, arg)
|
||||||
|
|
||||||
override val zero: Expression<T> = ConstantExpression(space.zero)
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
FunctionalBinaryOperation(algebra, operation, left, right)
|
||||||
|
|
||||||
override fun const(value: T): Expression<T> = ConstantExpression(value)
|
override val zero: Expression<T> = FunctionalConstantExpression(space.zero)
|
||||||
|
|
||||||
override fun variable(name: String, default: T?): Expression<T> = VariableExpression(name, default)
|
|
||||||
|
|
||||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = SumExpression(space, a, b)
|
|
||||||
|
|
||||||
override fun multiply(a: Expression<T>, k: Number): Expression<T> = ConstProductExpession(space, a, k)
|
|
||||||
|
|
||||||
|
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
|
FunctionalBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
|
override fun multiply(a: Expression<T>, k: Number): Expression<T> = FunctionalConstProductExpression(space, a, k)
|
||||||
operator fun Expression<T>.plus(arg: T) = this + const(arg)
|
operator fun Expression<T>.plus(arg: T) = this + const(arg)
|
||||||
operator fun Expression<T>.minus(arg: T) = this - const(arg)
|
operator fun Expression<T>.minus(arg: T) = this - const(arg)
|
||||||
|
|
||||||
operator fun T.plus(arg: Expression<T>) = arg + this
|
operator fun T.plus(arg: Expression<T>) = arg + this
|
||||||
operator fun T.minus(arg: Expression<T>) = arg - this
|
operator fun T.minus(arg: Expression<T>) = arg - this
|
||||||
}
|
}
|
||||||
|
|
||||||
open class FunctionalExpressionField<T>(
|
open class FunctionalExpressionRing<T>(val ring: Ring<T>) : FunctionalExpressionSpace<T>(ring), Ring<Expression<T>> {
|
||||||
val field: Field<T>
|
|
||||||
) : Field<Expression<T>>, ExpressionContext<T, Expression<T>>, FunctionalExpressionSpace<T>(field) {
|
|
||||||
|
|
||||||
override val one: Expression<T>
|
override val one: Expression<T>
|
||||||
get() = const(this.field.one)
|
get() = const(this.ring.one)
|
||||||
|
|
||||||
fun number(value: Number): Expression<T> = const(field.run { one * value })
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
FunctionalUnaryOperation(algebra, operation, arg)
|
||||||
|
|
||||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = ProductExpression(field, a, b)
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
FunctionalBinaryOperation(algebra, operation, left, right)
|
||||||
|
|
||||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = DivExpession(field, a, b)
|
fun number(value: Number): Expression<T> = const(ring { one * value })
|
||||||
|
|
||||||
|
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
|
FunctionalBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
operator fun Expression<T>.times(arg: T) = this * const(arg)
|
operator fun Expression<T>.times(arg: T) = this * const(arg)
|
||||||
operator fun Expression<T>.div(arg: T) = this / const(arg)
|
|
||||||
|
|
||||||
operator fun T.times(arg: Expression<T>) = arg * this
|
operator fun T.times(arg: Expression<T>) = arg * this
|
||||||
|
}
|
||||||
|
|
||||||
|
open class FunctionalExpressionField<T>(val field: Field<T>) :
|
||||||
|
FunctionalExpressionRing<T>(field),
|
||||||
|
Field<Expression<T>> {
|
||||||
|
|
||||||
|
override fun unaryOperation(operation: String, arg: Expression<T>): Expression<T> =
|
||||||
|
FunctionalUnaryOperation(algebra, operation, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||||
|
FunctionalBinaryOperation(algebra, operation, left, right)
|
||||||
|
|
||||||
|
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
|
FunctionalBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
|
operator fun Expression<T>.div(arg: T) = this / const(arg)
|
||||||
operator fun T.div(arg: Expression<T>) = arg / this
|
operator fun T.div(arg: Expression<T>) = arg / this
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user