diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt index 14456a426..69e49f8b6 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt @@ -1,12 +1,13 @@ package scientifik.kmath.asm +import scientifik.kmath.asm.internal.AsmGenerationContext import scientifik.kmath.ast.MST import scientifik.kmath.ast.evaluate import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.* @PublishedApi -internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { +internal fun buildName(expression: AsmNode<*>, collision: Int = 0): String { val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" try { @@ -19,15 +20,16 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String } @PublishedApi -internal inline fun AsmExpression.compile(algebra: Algebra): Expression { - val ctx = AsmGenerationContext(T::class.java, algebra, buildName(this)) +internal inline fun AsmNode.compile(algebra: Algebra): Expression { + val ctx = + AsmGenerationContext(T::class.java, algebra, buildName(this)) compile(ctx) return ctx.generate() } inline fun , E : AsmExpressionAlgebra> A.asm( expressionAlgebra: E, - block: E.() -> AsmExpression + block: E.() -> AsmNode ): Expression = expressionAlgebra.block().compile(expressionAlgebra.algebra) inline fun , E : AsmExpressionAlgebra> A.asm( @@ -35,19 +37,19 @@ inline fun , E : AsmExpressionAlgebra> A. ast: MST ): Expression = asm(expressionAlgebra) { evaluate(ast) } -inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmExpression): Expression where A : NumericAlgebra, A : Space = +inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmNode): Expression where A : NumericAlgebra, A : Space = AsmExpressionSpace(this).let { it.block().compile(it.algebra) } inline fun A.asmSpace(ast: MST): Expression where A : NumericAlgebra, A : Space = asmSpace { evaluate(ast) } -inline fun A.asmRing(block: AsmExpressionRing.() -> AsmExpression): Expression where A : NumericAlgebra, A : Ring = +inline fun A.asmRing(block: AsmExpressionRing.() -> AsmNode): Expression where A : NumericAlgebra, A : Ring = AsmExpressionRing(this).let { it.block().compile(it.algebra) } inline fun A.asmRing(ast: MST): Expression where A : NumericAlgebra, A : Ring = asmRing { evaluate(ast) } -inline fun A.asmField(block: AsmExpressionField.() -> AsmExpression): Expression where A : NumericAlgebra, A : Field = +inline fun A.asmField(block: AsmExpressionField.() -> AsmNode): Expression where A : NumericAlgebra, A : Field = AsmExpressionField(this).let { it.block().compile(it.algebra) } inline fun A.asmField(ast: MST): Expression where A : NumericAlgebra, A : Field = diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index cd1815f7c..8f16a2710 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -1,5 +1,6 @@ package scientifik.kmath.asm +import scientifik.kmath.asm.internal.AsmGenerationContext import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.optimize import scientifik.kmath.asm.internal.tryInvokeSpecific @@ -12,25 +13,26 @@ import scientifik.kmath.operations.* * * @param T the type the stored function returns. */ -interface AsmExpression { +abstract class AsmNode internal constructor() { /** * Tries to evaluate this function without its variables. This method is intended for optimization. * * @return `null` if the function depends on its variables, the value if the function is a constant. */ - fun tryEvaluate(): T? = null + internal open fun tryEvaluate(): T? = null /** * Compiles this declaration. * * @param gen the target [AsmGenerationContext]. */ - fun compile(gen: AsmGenerationContext) + @PublishedApi + internal abstract fun compile(gen: AsmGenerationContext) } -internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmExpression) : - AsmExpression { - private val expr: AsmExpression = expr.optimize() +internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmNode) : + AsmNode() { + private val expr: AsmNode = expr.optimize() override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) } override fun compile(gen: AsmGenerationContext) { @@ -57,11 +59,11 @@ internal class AsmUnaryOperation(private val context: Algebra, private val internal class AsmBinaryOperation( private val context: Algebra, private val name: String, - first: AsmExpression, - second: AsmExpression -) : AsmExpression { - private val first: AsmExpression = first.optimize() - private val second: AsmExpression = second.optimize() + first: AsmNode, + second: AsmNode +) : AsmNode() { + private val first: AsmNode = first.optimize() + private val second: AsmNode = second.optimize() override fun tryEvaluate(): T? = context { binaryOperation( @@ -95,23 +97,22 @@ internal class AsmBinaryOperation( } internal class AsmVariableExpression(private val name: String, private val default: T? = null) : - AsmExpression { + AsmNode() { override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) } internal class AsmConstantExpression(private val value: T) : - AsmExpression { + AsmNode() { override fun tryEvaluate(): T = value override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) } internal class AsmConstProductExpression( private val context: Space, - expr: AsmExpression, + expr: AsmNode, private val const: Number -) : - AsmExpression { - private val expr: AsmExpression = expr.optimize() +) : AsmNode() { + private val expr: AsmNode = expr.optimize() override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const } @@ -131,7 +132,7 @@ internal class AsmConstProductExpression( } internal class AsmNumberExpression(private val context: NumericAlgebra, private val value: Number) : - AsmExpression { + AsmNode() { override fun tryEvaluate(): T? = context.number(value) override fun compile(gen: AsmGenerationContext): Unit = gen.visitNumberConstant(value) @@ -145,10 +146,10 @@ internal abstract class FunctionalCompiledExpression internal constructor( } /** - * A context class for [AsmExpression] construction. + * A context class for [AsmNode] construction. */ -interface AsmExpressionAlgebra> : NumericAlgebra>, - ExpressionAlgebra> { +interface AsmExpressionAlgebra> : NumericAlgebra>, + ExpressionAlgebra> { /** * The algebra to provide for AsmExpressions built. */ @@ -157,108 +158,108 @@ interface AsmExpressionAlgebra> : NumericAlgebra = AsmNumberExpression(algebra, value) + override fun number(value: Number): AsmNode = AsmNumberExpression(algebra, value) /** * Builds an AsmExpression of constant expression which does not depend on arguments. */ - override fun const(value: T): AsmExpression = AsmConstantExpression(value) + override fun const(value: T): AsmNode = AsmConstantExpression(value) /** * Builds an AsmExpression to access a variable. */ - override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) + override fun variable(name: String, default: T?): AsmNode = AsmVariableExpression(name, default) /** * Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right]. */ - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = AsmBinaryOperation(algebra, operation, left, right) /** * Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg]. */ - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = AsmUnaryOperation(algebra, operation, arg) } /** - * A context class for [AsmExpression] construction for [Space] algebras. + * A context class for [AsmNode] construction for [Space] algebras. */ open class AsmExpressionSpace(override val algebra: A) : AsmExpressionAlgebra, - Space> where A : Space, A : NumericAlgebra { - override val zero: AsmExpression + Space> where A : Space, A : NumericAlgebra { + override val zero: AsmNode get() = const(algebra.zero) /** * Builds an AsmExpression of addition of two another expressions. */ - override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = + override fun add(a: AsmNode, b: AsmNode): AsmNode = AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) /** * Builds an AsmExpression of multiplication of expression by number. */ - override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(algebra, a, k) + override fun multiply(a: AsmNode, k: Number): AsmNode = AsmConstProductExpression(algebra, a, k) - operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) - operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) - operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this - operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this + operator fun AsmNode.plus(arg: T): AsmNode = this + const(arg) + operator fun AsmNode.minus(arg: T): AsmNode = this - const(arg) + operator fun T.plus(arg: AsmNode): AsmNode = arg + this + operator fun T.minus(arg: AsmNode): AsmNode = arg - this - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = super.binaryOperation(operation, left, right) } /** - * A context class for [AsmExpression] construction for [Ring] algebras. + * A context class for [AsmNode] construction for [Ring] algebras. */ open class AsmExpressionRing(override val algebra: A) : AsmExpressionSpace(algebra), - Ring> where A : Ring, A : NumericAlgebra { - override val one: AsmExpression + Ring> where A : Ring, A : NumericAlgebra { + override val one: AsmNode get() = const(algebra.one) /** * Builds an AsmExpression of multiplication of two expressions. */ - override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = + override fun multiply(a: AsmNode, b: AsmNode): AsmNode = AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) - operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) - operator fun T.times(arg: AsmExpression): AsmExpression = arg * this + operator fun AsmNode.times(arg: T): AsmNode = this * const(arg) + operator fun T.times(arg: AsmNode): AsmNode = arg * this - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = super.binaryOperation(operation, left, right) - override fun number(value: Number): AsmExpression = super.number(value) + override fun number(value: Number): AsmNode = super.number(value) } /** - * A context class for [AsmExpression] construction for [Field] algebras. + * A context class for [AsmNode] construction for [Field] algebras. */ open class AsmExpressionField(override val algebra: A) : AsmExpressionRing(algebra), - Field> where A : Field, A : NumericAlgebra { + Field> where A : Field, A : NumericAlgebra { /** * Builds an AsmExpression of division an expression by another one. */ - override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = + override fun divide(a: AsmNode, b: AsmNode): AsmNode = AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) - operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) - operator fun T.div(arg: AsmExpression): AsmExpression = arg / this + operator fun AsmNode.div(arg: T): AsmNode = this / const(arg) + operator fun T.div(arg: AsmNode): AsmNode = arg / this - override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = super.binaryOperation(operation, left, right) - override fun number(value: Number): AsmExpression = super.number(value) + override fun number(value: Number): AsmNode = super.number(value) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt similarity index 96% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt index be8da3eff..fa0f06579 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmGenerationContext.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt @@ -1,17 +1,15 @@ -package scientifik.kmath.asm +package scientifik.kmath.asm.internal import org.objectweb.asm.ClassWriter import org.objectweb.asm.Label import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes -import scientifik.kmath.asm.AsmGenerationContext.ClassLoader -import scientifik.kmath.asm.internal.visitLdcOrDConstInsn -import scientifik.kmath.asm.internal.visitLdcOrFConstInsn -import scientifik.kmath.asm.internal.visitLdcOrIConstInsn +import scientifik.kmath.asm.FunctionalCompiledExpression +import scientifik.kmath.asm.internal.AsmGenerationContext.ClassLoader import scientifik.kmath.operations.Algebra /** - * AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java + * AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java * expression. This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new * class. * @@ -19,7 +17,8 @@ import scientifik.kmath.operations.Algebra * @param algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class. */ -class AsmGenerationContext @PublishedApi internal constructor( +@PublishedApi +internal class AsmGenerationContext @PublishedApi internal constructor( private val classOfT: Class<*>, private val algebra: Algebra, private val className: String diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt index cd41ff07c..815f292ce 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt @@ -2,8 +2,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.Opcodes import scientifik.kmath.asm.AsmConstantExpression -import scientifik.kmath.asm.AsmExpression -import scientifik.kmath.asm.AsmGenerationContext +import scientifik.kmath.asm.AsmNode import scientifik.kmath.operations.Algebra private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") @@ -44,7 +43,7 @@ internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, } @PublishedApi -internal fun AsmExpression.optimize(): AsmExpression { +internal fun AsmNode.optimize(): AsmNode { val a = tryEvaluate() return if (a == null) this else AsmConstantExpression(a) }