diff --git a/kmath-asm/build.gradle.kts b/kmath-asm/build.gradle.kts index 4104349f8..3b004eb8e 100644 --- a/kmath-asm/build.gradle.kts +++ b/kmath-asm/build.gradle.kts @@ -6,4 +6,5 @@ dependencies { api(project(path = ":kmath-core")) implementation("org.ow2.asm:asm:8.0.1") implementation("org.ow2.asm:asm-commons:8.0.1") + implementation(kotlin("reflect")) } diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt index 17b6dc023..24260b871 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt @@ -31,10 +31,10 @@ open class AsmExpressionSpace( AsmBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(space, a, k) - operator fun AsmExpression.plus(arg: T) = this + const(arg) - operator fun AsmExpression.minus(arg: T) = this - const(arg) - operator fun T.plus(arg: AsmExpression) = arg + this - operator fun T.minus(arg: AsmExpression) = arg - this + operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) + operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) + operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this + operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this } open class AsmExpressionRing(private val ring: Ring) : AsmExpressionSpace(ring), Ring> { @@ -52,8 +52,8 @@ open class AsmExpressionRing(private val ring: Ring) : AsmExpressionSpace< override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b) - operator fun AsmExpression.times(arg: T) = this * const(arg) - operator fun T.times(arg: AsmExpression) = arg * this + operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) + operator fun T.times(arg: AsmExpression): AsmExpression = arg * this } open class AsmExpressionField(private val field: Field) : @@ -69,6 +69,6 @@ open class AsmExpressionField(private val field: Field) : override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = AsmBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b) - operator fun AsmExpression.div(arg: T) = this / const(arg) - operator fun T.div(arg: AsmExpression) = arg / this + operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) + operator fun T.div(arg: AsmExpression): AsmExpression = arg / this } diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt index d7d655d6e..662e1279e 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt @@ -4,28 +4,55 @@ import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Space import scientifik.kmath.operations.invoke +import kotlin.reflect.full.memberFunctions +import kotlin.reflect.jvm.jvmName interface AsmExpression { fun tryEvaluate(): T? = null fun invoke(gen: AsmGenerationContext) } +internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { + context::class.memberFunctions.find { it.name == name && it.parameters.size == arity } + ?: return false + + return true +} + +internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { + context::class.memberFunctions.find { it.name == name && it.parameters.size == arity } + ?: return false + + val owner = context::class.jvmName.replace('.', '/') + + val sig = buildString { + append('(') + repeat(arity) { append("L${AsmGenerationContext.OBJECT_CLASS};") } + append(')') + append("L${AsmGenerationContext.OBJECT_CLASS};") + } + + visitAlgebraOperation(owner = owner, method = name, descriptor = sig) + + return true +} + internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmExpression) : AsmExpression { private val expr: AsmExpression = expr.optimize() - - override fun tryEvaluate(): T? = context { - unaryOperation( - name, - expr.tryEvaluate() ?: return@context null - ) - } + override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) } override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() - gen.visitStringConstant(name) + + if (!hasSpecific(context, name, 1)) + gen.visitStringConstant(name) + expr.invoke(gen) + if (gen.tryInvokeSpecific(context, name, 1)) + return + gen.visitAlgebraOperation( owner = AsmGenerationContext.ALGEBRA_CLASS, method = "unaryOperation", @@ -55,10 +82,16 @@ internal class AsmBinaryOperation( override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() - gen.visitStringConstant(name) + + if (!hasSpecific(context, name, 1)) + gen.visitStringConstant(name) + first.invoke(gen) second.invoke(gen) + if (gen.tryInvokeSpecific(context, name, 1)) + return + gen.visitAlgebraOperation( owner = AsmGenerationContext.ALGEBRA_CLASS, method = "binaryOperation", @@ -79,7 +112,11 @@ internal class AsmConstantExpression(private val value: T) : AsmExpression override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) } -internal class AsmConstProductExpression(private val context: Space, expr: AsmExpression, private val const: Number) : +internal class AsmConstProductExpression( + private val context: Space, + expr: AsmExpression, + private val const: Number +) : AsmExpression { private val expr: AsmExpression = expr.optimize() @@ -100,7 +137,7 @@ internal class AsmConstProductExpression(private val context: Space, expr: internal abstract class FunctionalCompiledExpression internal constructor( @JvmField protected val algebra: Algebra, - @JvmField protected val constants: MutableList + @JvmField protected val constants: Array ) : Expression { abstract override fun invoke(arguments: Map): T } diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt index 3e47da5e3..0375b7fc4 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt @@ -46,7 +46,7 @@ class AsmGenerationContext( ) asmCompiledClassWriter.run { - visitMethod(Opcodes.ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run { + visitMethod(Opcodes.ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", null, null).run { val thisVar = 0 val algebraVar = 1 val constantsVar = 2 @@ -60,7 +60,7 @@ class AsmGenerationContext( Opcodes.INVOKESPECIAL, FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "", - "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", + "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", false ) @@ -80,7 +80,7 @@ class AsmGenerationContext( algebraVar ) - visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS;", l0, l2, constantsVar) + visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar) visitMaxs(3, 3) visitEnd() } @@ -170,7 +170,7 @@ class AsmGenerationContext( .defineClass(className, asmCompiledClassWriter.toByteArray()) .constructors .first() - .newInstance(algebra, constants) as FunctionalCompiledExpression + .newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression generatedInstance = new return new @@ -184,9 +184,9 @@ class AsmGenerationContext( invokeMethodVisitor.run { visitLoadThis() - visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;") + visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "[L$OBJECT_CLASS;") visitLdcOrIConstInsn(idx) - visitMethodInsn(Opcodes.INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true) + visitInsn(Opcodes.AALOAD) invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type) } } @@ -273,8 +273,9 @@ class AsmGenerationContext( java.lang.Double::class.java to "D" ) - internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/asm/FunctionalCompiledExpression" - internal const val LIST_CLASS = "java/util/List" + internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = + "scientifik/kmath/expressions/asm/FunctionalCompiledExpression" + internal const val MAP_CLASS = "java/util/Map" internal const val OBJECT_CLASS = "java/lang/Object" internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt index d55555567..663bd55ba 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt @@ -1,10 +1,8 @@ package scientifik.kmath.expressions.asm import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space +import scientifik.kmath.expressions.ExpressionContext +import scientifik.kmath.operations.* @PublishedApi internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { @@ -20,7 +18,11 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String } -inline fun asm(i: I, algebra: Algebra, block: I.() -> AsmExpression): Expression { +inline fun >> asm( + i: E, + algebra: Algebra, + block: E.() -> AsmExpression +): Expression { val expression = i.block().optimize() val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression)) expression.invoke(ctx) diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt index db57a690c..bcd182dc3 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt @@ -1,9 +1,7 @@ package scientifik.kmath.expressions.asm -import scientifik.kmath.expressions.asm.AsmConstantExpression -import scientifik.kmath.expressions.asm.AsmExpression - -fun AsmExpression.optimize(): AsmExpression { +@PublishedApi +internal fun AsmExpression.optimize(): AsmExpression { val a = tryEvaluate() return if (a == null) this else AsmConstantExpression(a) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index 73745f78f..8d8bd3a7a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -26,11 +26,9 @@ interface ExpressionContext : Algebra { fun const(value: T): E } -fun ExpressionContext.produce(node: SyntaxTreeNode): E { - return when (node) { - is NumberNode -> error("Single number nodes are not supported") - is SingularNode -> variable(node.value) - is UnaryNode -> unaryOperation(node.operation, produce(node.value)) - is BinaryNode -> binaryOperation(node.operation, produce(node.left), produce(node.right)) - } -} \ No newline at end of file +fun ExpressionContext.produce(node: SyntaxTreeNode): E = when (node) { + is NumberNode -> error("Single number nodes are not supported") + is SingularNode -> variable(node.value) + is UnaryNode -> unaryOperation(node.operation, produce(node.value)) + is BinaryNode -> binaryOperation(node.operation, produce(node.left), produce(node.right)) +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index b36ea8b52..d20a5dbbf 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -59,10 +59,10 @@ open class FunctionalExpressionSpace(val space: Space) : FunctionalBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b) override fun multiply(a: Expression, k: Number): Expression = FunctionalConstProductExpression(space, a, k) - operator fun Expression.plus(arg: T) = this + const(arg) - operator fun Expression.minus(arg: T) = this - const(arg) - operator fun T.plus(arg: Expression) = arg + this - operator fun T.minus(arg: Expression) = arg - this + operator fun Expression.plus(arg: T): Expression = this + const(arg) + operator fun Expression.minus(arg: T): Expression = this - const(arg) + operator fun T.plus(arg: Expression): Expression = arg + this + operator fun T.minus(arg: Expression): Expression = arg - this } open class FunctionalExpressionRing(val ring: Ring) : FunctionalExpressionSpace(ring), Ring> { @@ -80,8 +80,8 @@ open class FunctionalExpressionRing(val ring: Ring) : FunctionalExpression override fun multiply(a: Expression, b: Expression): Expression = FunctionalBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b) - operator fun Expression.times(arg: T) = this * const(arg) - operator fun T.times(arg: Expression) = arg * this + operator fun Expression.times(arg: T): Expression = this * const(arg) + operator fun T.times(arg: Expression): Expression = arg * this } open class FunctionalExpressionField(val field: Field) : @@ -97,6 +97,6 @@ open class FunctionalExpressionField(val field: Field) : override fun divide(a: Expression, b: Expression): Expression = FunctionalBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b) - operator fun Expression.div(arg: T) = this / const(arg) - operator fun T.div(arg: Expression) = arg / this + operator fun Expression.div(arg: T): Expression = this / const(arg) + operator fun T.div(arg: Expression): Expression = arg / this }