forked from kscience/kmath
Refactor, replace constants List with Array, create specification of named functions
This commit is contained in:
parent
1b6a0a13d8
commit
a0453da4b3
@ -6,4 +6,5 @@ dependencies {
|
|||||||
api(project(path = ":kmath-core"))
|
api(project(path = ":kmath-core"))
|
||||||
implementation("org.ow2.asm:asm:8.0.1")
|
implementation("org.ow2.asm:asm:8.0.1")
|
||||||
implementation("org.ow2.asm:asm-commons:8.0.1")
|
implementation("org.ow2.asm:asm-commons:8.0.1")
|
||||||
|
implementation(kotlin("reflect"))
|
||||||
}
|
}
|
||||||
|
@ -31,10 +31,10 @@ open class AsmExpressionSpace<T>(
|
|||||||
AsmBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b)
|
AsmBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(space, a, k)
|
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(space, a, k)
|
||||||
operator fun AsmExpression<T>.plus(arg: T) = this + const(arg)
|
operator fun AsmExpression<T>.plus(arg: T): AsmExpression<T> = this + const(arg)
|
||||||
operator fun AsmExpression<T>.minus(arg: T) = this - const(arg)
|
operator fun AsmExpression<T>.minus(arg: T): AsmExpression<T> = this - const(arg)
|
||||||
operator fun T.plus(arg: AsmExpression<T>) = arg + this
|
operator fun T.plus(arg: AsmExpression<T>): AsmExpression<T> = arg + this
|
||||||
operator fun T.minus(arg: AsmExpression<T>) = arg - this
|
operator fun T.minus(arg: AsmExpression<T>): AsmExpression<T> = arg - this
|
||||||
}
|
}
|
||||||
|
|
||||||
open class AsmExpressionRing<T>(private val ring: Ring<T>) : AsmExpressionSpace<T>(ring), Ring<AsmExpression<T>> {
|
open class AsmExpressionRing<T>(private val ring: Ring<T>) : AsmExpressionSpace<T>(ring), Ring<AsmExpression<T>> {
|
||||||
@ -52,8 +52,8 @@ open class AsmExpressionRing<T>(private val ring: Ring<T>) : AsmExpressionSpace<
|
|||||||
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||||
AsmBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b)
|
AsmBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
operator fun AsmExpression<T>.times(arg: T) = this * const(arg)
|
operator fun AsmExpression<T>.times(arg: T): AsmExpression<T> = this * const(arg)
|
||||||
operator fun T.times(arg: AsmExpression<T>) = arg * this
|
operator fun T.times(arg: AsmExpression<T>): AsmExpression<T> = arg * this
|
||||||
}
|
}
|
||||||
|
|
||||||
open class AsmExpressionField<T>(private val field: Field<T>) :
|
open class AsmExpressionField<T>(private val field: Field<T>) :
|
||||||
@ -69,6 +69,6 @@ open class AsmExpressionField<T>(private val field: Field<T>) :
|
|||||||
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
|
||||||
AsmBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b)
|
AsmBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
operator fun AsmExpression<T>.div(arg: T) = this / const(arg)
|
operator fun AsmExpression<T>.div(arg: T): AsmExpression<T> = this / const(arg)
|
||||||
operator fun T.div(arg: AsmExpression<T>) = arg / this
|
operator fun T.div(arg: AsmExpression<T>): AsmExpression<T> = arg / this
|
||||||
}
|
}
|
||||||
|
@ -4,28 +4,55 @@ import scientifik.kmath.expressions.Expression
|
|||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.operations.Algebra
|
||||||
import scientifik.kmath.operations.Space
|
import scientifik.kmath.operations.Space
|
||||||
import scientifik.kmath.operations.invoke
|
import scientifik.kmath.operations.invoke
|
||||||
|
import kotlin.reflect.full.memberFunctions
|
||||||
|
import kotlin.reflect.jvm.jvmName
|
||||||
|
|
||||||
interface AsmExpression<T> {
|
interface AsmExpression<T> {
|
||||||
fun tryEvaluate(): T? = null
|
fun tryEvaluate(): T? = null
|
||||||
fun invoke(gen: AsmGenerationContext<T>)
|
fun invoke(gen: AsmGenerationContext<T>)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal fun <T> hasSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||||
|
context::class.memberFunctions.find { it.name == name && it.parameters.size == arity }
|
||||||
|
?: return false
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun <T> AsmGenerationContext<T>.tryInvokeSpecific(context: Algebra<T>, name: String, arity: Int): Boolean {
|
||||||
|
context::class.memberFunctions.find { it.name == name && it.parameters.size == arity }
|
||||||
|
?: return false
|
||||||
|
|
||||||
|
val owner = context::class.jvmName.replace('.', '/')
|
||||||
|
|
||||||
|
val sig = buildString {
|
||||||
|
append('(')
|
||||||
|
repeat(arity) { append("L${AsmGenerationContext.OBJECT_CLASS};") }
|
||||||
|
append(')')
|
||||||
|
append("L${AsmGenerationContext.OBJECT_CLASS};")
|
||||||
|
}
|
||||||
|
|
||||||
|
visitAlgebraOperation(owner = owner, method = name, descriptor = sig)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
internal class AsmUnaryOperation<T>(private val context: Algebra<T>, private val name: String, expr: AsmExpression<T>) :
|
internal class AsmUnaryOperation<T>(private val context: Algebra<T>, private val name: String, expr: AsmExpression<T>) :
|
||||||
AsmExpression<T> {
|
AsmExpression<T> {
|
||||||
private val expr: AsmExpression<T> = expr.optimize()
|
private val expr: AsmExpression<T> = expr.optimize()
|
||||||
|
override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) }
|
||||||
override fun tryEvaluate(): T? = context {
|
|
||||||
unaryOperation(
|
|
||||||
name,
|
|
||||||
expr.tryEvaluate() ?: return@context null
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||||
gen.visitLoadAlgebra()
|
gen.visitLoadAlgebra()
|
||||||
gen.visitStringConstant(name)
|
|
||||||
|
if (!hasSpecific(context, name, 1))
|
||||||
|
gen.visitStringConstant(name)
|
||||||
|
|
||||||
expr.invoke(gen)
|
expr.invoke(gen)
|
||||||
|
|
||||||
|
if (gen.tryInvokeSpecific(context, name, 1))
|
||||||
|
return
|
||||||
|
|
||||||
gen.visitAlgebraOperation(
|
gen.visitAlgebraOperation(
|
||||||
owner = AsmGenerationContext.ALGEBRA_CLASS,
|
owner = AsmGenerationContext.ALGEBRA_CLASS,
|
||||||
method = "unaryOperation",
|
method = "unaryOperation",
|
||||||
@ -55,10 +82,16 @@ internal class AsmBinaryOperation<T>(
|
|||||||
|
|
||||||
override fun invoke(gen: AsmGenerationContext<T>) {
|
override fun invoke(gen: AsmGenerationContext<T>) {
|
||||||
gen.visitLoadAlgebra()
|
gen.visitLoadAlgebra()
|
||||||
gen.visitStringConstant(name)
|
|
||||||
|
if (!hasSpecific(context, name, 1))
|
||||||
|
gen.visitStringConstant(name)
|
||||||
|
|
||||||
first.invoke(gen)
|
first.invoke(gen)
|
||||||
second.invoke(gen)
|
second.invoke(gen)
|
||||||
|
|
||||||
|
if (gen.tryInvokeSpecific(context, name, 1))
|
||||||
|
return
|
||||||
|
|
||||||
gen.visitAlgebraOperation(
|
gen.visitAlgebraOperation(
|
||||||
owner = AsmGenerationContext.ALGEBRA_CLASS,
|
owner = AsmGenerationContext.ALGEBRA_CLASS,
|
||||||
method = "binaryOperation",
|
method = "binaryOperation",
|
||||||
@ -79,7 +112,11 @@ internal class AsmConstantExpression<T>(private val value: T) : AsmExpression<T>
|
|||||||
override fun invoke(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromConstants(value)
|
override fun invoke(gen: AsmGenerationContext<T>): Unit = gen.visitLoadFromConstants(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class AsmConstProductExpression<T>(private val context: Space<T>, expr: AsmExpression<T>, private val const: Number) :
|
internal class AsmConstProductExpression<T>(
|
||||||
|
private val context: Space<T>,
|
||||||
|
expr: AsmExpression<T>,
|
||||||
|
private val const: Number
|
||||||
|
) :
|
||||||
AsmExpression<T> {
|
AsmExpression<T> {
|
||||||
private val expr: AsmExpression<T> = expr.optimize()
|
private val expr: AsmExpression<T> = expr.optimize()
|
||||||
|
|
||||||
@ -100,7 +137,7 @@ internal class AsmConstProductExpression<T>(private val context: Space<T>, expr:
|
|||||||
|
|
||||||
internal abstract class FunctionalCompiledExpression<T> internal constructor(
|
internal abstract class FunctionalCompiledExpression<T> internal constructor(
|
||||||
@JvmField protected val algebra: Algebra<T>,
|
@JvmField protected val algebra: Algebra<T>,
|
||||||
@JvmField protected val constants: MutableList<out Any>
|
@JvmField protected val constants: Array<Any>
|
||||||
) : Expression<T> {
|
) : Expression<T> {
|
||||||
abstract override fun invoke(arguments: Map<String, T>): T
|
abstract override fun invoke(arguments: Map<String, T>): T
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,7 @@ class AsmGenerationContext<T>(
|
|||||||
)
|
)
|
||||||
|
|
||||||
asmCompiledClassWriter.run {
|
asmCompiledClassWriter.run {
|
||||||
visitMethod(Opcodes.ACC_PUBLIC, "<init>", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run {
|
visitMethod(Opcodes.ACC_PUBLIC, "<init>", "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", null, null).run {
|
||||||
val thisVar = 0
|
val thisVar = 0
|
||||||
val algebraVar = 1
|
val algebraVar = 1
|
||||||
val constantsVar = 2
|
val constantsVar = 2
|
||||||
@ -60,7 +60,7 @@ class AsmGenerationContext<T>(
|
|||||||
Opcodes.INVOKESPECIAL,
|
Opcodes.INVOKESPECIAL,
|
||||||
FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
|
FUNCTIONAL_COMPILED_EXPRESSION_CLASS,
|
||||||
"<init>",
|
"<init>",
|
||||||
"(L$ALGEBRA_CLASS;L$LIST_CLASS;)V",
|
"(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V",
|
||||||
false
|
false
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -80,7 +80,7 @@ class AsmGenerationContext<T>(
|
|||||||
algebraVar
|
algebraVar
|
||||||
)
|
)
|
||||||
|
|
||||||
visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS<L$T_CLASS;>;", l0, l2, constantsVar)
|
visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar)
|
||||||
visitMaxs(3, 3)
|
visitMaxs(3, 3)
|
||||||
visitEnd()
|
visitEnd()
|
||||||
}
|
}
|
||||||
@ -170,7 +170,7 @@ class AsmGenerationContext<T>(
|
|||||||
.defineClass(className, asmCompiledClassWriter.toByteArray())
|
.defineClass(className, asmCompiledClassWriter.toByteArray())
|
||||||
.constructors
|
.constructors
|
||||||
.first()
|
.first()
|
||||||
.newInstance(algebra, constants) as FunctionalCompiledExpression<T>
|
.newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression<T>
|
||||||
|
|
||||||
generatedInstance = new
|
generatedInstance = new
|
||||||
return new
|
return new
|
||||||
@ -184,9 +184,9 @@ class AsmGenerationContext<T>(
|
|||||||
|
|
||||||
invokeMethodVisitor.run {
|
invokeMethodVisitor.run {
|
||||||
visitLoadThis()
|
visitLoadThis()
|
||||||
visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;")
|
visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "[L$OBJECT_CLASS;")
|
||||||
visitLdcOrIConstInsn(idx)
|
visitLdcOrIConstInsn(idx)
|
||||||
visitMethodInsn(Opcodes.INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true)
|
visitInsn(Opcodes.AALOAD)
|
||||||
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type)
|
invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -273,8 +273,9 @@ class AsmGenerationContext<T>(
|
|||||||
java.lang.Double::class.java to "D"
|
java.lang.Double::class.java to "D"
|
||||||
)
|
)
|
||||||
|
|
||||||
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/asm/FunctionalCompiledExpression"
|
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =
|
||||||
internal const val LIST_CLASS = "java/util/List"
|
"scientifik/kmath/expressions/asm/FunctionalCompiledExpression"
|
||||||
|
|
||||||
internal const val MAP_CLASS = "java/util/Map"
|
internal const val MAP_CLASS = "java/util/Map"
|
||||||
internal const val OBJECT_CLASS = "java/lang/Object"
|
internal const val OBJECT_CLASS = "java/lang/Object"
|
||||||
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
package scientifik.kmath.expressions.asm
|
package scientifik.kmath.expressions.asm
|
||||||
|
|
||||||
import scientifik.kmath.expressions.Expression
|
import scientifik.kmath.expressions.Expression
|
||||||
import scientifik.kmath.operations.Algebra
|
import scientifik.kmath.expressions.ExpressionContext
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.*
|
||||||
import scientifik.kmath.operations.Ring
|
|
||||||
import scientifik.kmath.operations.Space
|
|
||||||
|
|
||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String {
|
internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String {
|
||||||
@ -20,7 +18,11 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
inline fun <reified T, I> asm(i: I, algebra: Algebra<T>, block: I.() -> AsmExpression<T>): Expression<T> {
|
inline fun <reified T, E : ExpressionContext<T, AsmExpression<T>>> asm(
|
||||||
|
i: E,
|
||||||
|
algebra: Algebra<T>,
|
||||||
|
block: E.() -> AsmExpression<T>
|
||||||
|
): Expression<T> {
|
||||||
val expression = i.block().optimize()
|
val expression = i.block().optimize()
|
||||||
val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression))
|
val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression))
|
||||||
expression.invoke(ctx)
|
expression.invoke(ctx)
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
package scientifik.kmath.expressions.asm
|
package scientifik.kmath.expressions.asm
|
||||||
|
|
||||||
import scientifik.kmath.expressions.asm.AsmConstantExpression
|
@PublishedApi
|
||||||
import scientifik.kmath.expressions.asm.AsmExpression
|
internal fun <T> AsmExpression<T>.optimize(): AsmExpression<T> {
|
||||||
|
|
||||||
fun <T> AsmExpression<T>.optimize(): AsmExpression<T> {
|
|
||||||
val a = tryEvaluate()
|
val a = tryEvaluate()
|
||||||
return if (a == null) this else AsmConstantExpression(a)
|
return if (a == null) this else AsmConstantExpression(a)
|
||||||
}
|
}
|
||||||
|
@ -26,11 +26,9 @@ interface ExpressionContext<T, E> : Algebra<E> {
|
|||||||
fun const(value: T): E
|
fun const(value: T): E
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T, E> ExpressionContext<T, E>.produce(node: SyntaxTreeNode): E {
|
fun <T, E> ExpressionContext<T, E>.produce(node: SyntaxTreeNode): E = when (node) {
|
||||||
return when (node) {
|
is NumberNode -> error("Single number nodes are not supported")
|
||||||
is NumberNode -> error("Single number nodes are not supported")
|
is SingularNode -> variable(node.value)
|
||||||
is SingularNode -> variable(node.value)
|
is UnaryNode -> unaryOperation(node.operation, produce(node.value))
|
||||||
is UnaryNode -> unaryOperation(node.operation, produce(node.value))
|
is BinaryNode -> binaryOperation(node.operation, produce(node.left), produce(node.right))
|
||||||
is BinaryNode -> binaryOperation(node.operation, produce(node.left), produce(node.right))
|
|
||||||
}
|
|
||||||
}
|
}
|
@ -59,10 +59,10 @@ open class FunctionalExpressionSpace<T>(val space: Space<T>) :
|
|||||||
FunctionalBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b)
|
FunctionalBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
override fun multiply(a: Expression<T>, k: Number): Expression<T> = FunctionalConstProductExpression(space, a, k)
|
override fun multiply(a: Expression<T>, k: Number): Expression<T> = FunctionalConstProductExpression(space, a, k)
|
||||||
operator fun Expression<T>.plus(arg: T) = this + const(arg)
|
operator fun Expression<T>.plus(arg: T): Expression<T> = this + const(arg)
|
||||||
operator fun Expression<T>.minus(arg: T) = this - const(arg)
|
operator fun Expression<T>.minus(arg: T): Expression<T> = this - const(arg)
|
||||||
operator fun T.plus(arg: Expression<T>) = arg + this
|
operator fun T.plus(arg: Expression<T>): Expression<T> = arg + this
|
||||||
operator fun T.minus(arg: Expression<T>) = arg - this
|
operator fun T.minus(arg: Expression<T>): Expression<T> = arg - this
|
||||||
}
|
}
|
||||||
|
|
||||||
open class FunctionalExpressionRing<T>(val ring: Ring<T>) : FunctionalExpressionSpace<T>(ring), Ring<Expression<T>> {
|
open class FunctionalExpressionRing<T>(val ring: Ring<T>) : FunctionalExpressionSpace<T>(ring), Ring<Expression<T>> {
|
||||||
@ -80,8 +80,8 @@ open class FunctionalExpressionRing<T>(val ring: Ring<T>) : FunctionalExpression
|
|||||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
FunctionalBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b)
|
FunctionalBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
operator fun Expression<T>.times(arg: T) = this * const(arg)
|
operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||||
operator fun T.times(arg: Expression<T>) = arg * this
|
operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||||
}
|
}
|
||||||
|
|
||||||
open class FunctionalExpressionField<T>(val field: Field<T>) :
|
open class FunctionalExpressionField<T>(val field: Field<T>) :
|
||||||
@ -97,6 +97,6 @@ open class FunctionalExpressionField<T>(val field: Field<T>) :
|
|||||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
FunctionalBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b)
|
FunctionalBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
operator fun Expression<T>.div(arg: T) = this / const(arg)
|
operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||||
operator fun T.div(arg: Expression<T>) = arg / this
|
operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user