diff --git a/kmath-ast/README.md b/kmath-ast/README.md new file mode 100644 index 000000000..f9dbd4663 --- /dev/null +++ b/kmath-ast/README.md @@ -0,0 +1,4 @@ +# AST based expression representation and operations + +## Dynamic expression code generation +Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). \ No newline at end of file diff --git a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt index 3375ecdf3..900b9297a 100644 --- a/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/scientifik/kmath/ast/MST.kt @@ -11,7 +11,7 @@ sealed class MST { /** * A node containing unparsed string */ - data class Singular(val value: String) : MST() + data class Symbolic(val value: String) : MST() /** * A node containing a number @@ -43,7 +43,7 @@ sealed class MST { fun NumericAlgebra.evaluate(node: MST): T { return when (node) { is MST.Numeric -> number(node.value) - is MST.Singular -> symbol(node.value) + is MST.Symbolic -> symbol(node.value) is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) is MST.Binary -> when { node.left is MST.Numeric && node.right is MST.Numeric -> { diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt deleted file mode 100644 index 69e49f8b6..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt +++ /dev/null @@ -1,56 +0,0 @@ -package scientifik.kmath.asm - -import scientifik.kmath.asm.internal.AsmGenerationContext -import scientifik.kmath.ast.MST -import scientifik.kmath.ast.evaluate -import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.* - -@PublishedApi -internal fun buildName(expression: AsmNode<*>, collision: Int = 0): String { - val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" - - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name - } - - return buildName(expression, collision + 1) -} - -@PublishedApi -internal inline fun AsmNode.compile(algebra: Algebra): Expression { - val ctx = - AsmGenerationContext(T::class.java, algebra, buildName(this)) - compile(ctx) - return ctx.generate() -} - -inline fun , E : AsmExpressionAlgebra> A.asm( - expressionAlgebra: E, - block: E.() -> AsmNode -): Expression = expressionAlgebra.block().compile(expressionAlgebra.algebra) - -inline fun , E : AsmExpressionAlgebra> A.asm( - expressionAlgebra: E, - ast: MST -): Expression = asm(expressionAlgebra) { evaluate(ast) } - -inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmNode): Expression where A : NumericAlgebra, A : Space = - AsmExpressionSpace(this).let { it.block().compile(it.algebra) } - -inline fun A.asmSpace(ast: MST): Expression where A : NumericAlgebra, A : Space = - asmSpace { evaluate(ast) } - -inline fun A.asmRing(block: AsmExpressionRing.() -> AsmNode): Expression where A : NumericAlgebra, A : Ring = - AsmExpressionRing(this).let { it.block().compile(it.algebra) } - -inline fun A.asmRing(ast: MST): Expression where A : NumericAlgebra, A : Ring = - asmRing { evaluate(ast) } - -inline fun A.asmField(block: AsmExpressionField.() -> AsmNode): Expression where A : NumericAlgebra, A : Field = - AsmExpressionField(this).let { it.block().compile(it.algebra) } - -inline fun A.asmField(ast: MST): Expression where A : NumericAlgebra, A : Field = - asmRing { evaluate(ast) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt index 8f16a2710..a8b3ff976 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -1,19 +1,24 @@ package scientifik.kmath.asm -import scientifik.kmath.asm.internal.AsmGenerationContext +import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.asm.internal.hasSpecific import scientifik.kmath.asm.internal.optimize import scientifik.kmath.asm.internal.tryInvokeSpecific import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.ExpressionAlgebra import scientifik.kmath.operations.* +import kotlin.reflect.KClass /** - * A function declaration that could be compiled to [AsmGenerationContext]. + * A function declaration that could be compiled to [AsmBuilder]. * * @param T the type the stored function returns. */ -abstract class AsmNode internal constructor() { +sealed class AsmExpression: Expression { + abstract val type: KClass + + abstract val algebra: Algebra + /** * Tries to evaluate this function without its variables. This method is intended for optimization. * @@ -24,118 +29,144 @@ abstract class AsmNode internal constructor() { /** * Compiles this declaration. * - * @param gen the target [AsmGenerationContext]. + * @param gen the target [AsmBuilder]. */ - @PublishedApi - internal abstract fun compile(gen: AsmGenerationContext) + internal abstract fun appendTo(gen: AsmBuilder) + + /** + * Compile and cache the expression + */ + private val compiledExpression by lazy{ + val builder = AsmBuilder(type.java, algebra, buildName(this)) + this.appendTo(builder) + builder.generate() + } + + override fun invoke(arguments: Map): T = compiledExpression.invoke(arguments) } -internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmNode) : - AsmNode() { - private val expr: AsmNode = expr.optimize() - override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) } +internal class AsmUnaryOperation( + override val type: KClass, + override val algebra: Algebra, + private val name: String, + expr: AsmExpression +) : AsmExpression() { + private val expr: AsmExpression = expr.optimize() + override fun tryEvaluate(): T? = algebra { unaryOperation(name, expr.tryEvaluate() ?: return@algebra null) } - override fun compile(gen: AsmGenerationContext) { + override fun appendTo(gen: AsmBuilder) { gen.visitLoadAlgebra() - if (!hasSpecific(context, name, 1)) + if (!hasSpecific(algebra, name, 1)) gen.visitStringConstant(name) - expr.compile(gen) + expr.appendTo(gen) - if (gen.tryInvokeSpecific(context, name, 1)) + if (gen.tryInvokeSpecific(algebra, name, 1)) return gen.visitAlgebraOperation( - owner = AsmGenerationContext.ALGEBRA_CLASS, + owner = AsmBuilder.ALGEBRA_CLASS, method = "unaryOperation", - descriptor = "(L${AsmGenerationContext.STRING_CLASS};" + - "L${AsmGenerationContext.OBJECT_CLASS};)" + - "L${AsmGenerationContext.OBJECT_CLASS};" + descriptor = "(L${AsmBuilder.STRING_CLASS};" + + "L${AsmBuilder.OBJECT_CLASS};)" + + "L${AsmBuilder.OBJECT_CLASS};" ) } } -internal class AsmBinaryOperation( - private val context: Algebra, +internal class AsmBinaryOperation( + override val type: KClass, + override val algebra: Algebra, private val name: String, - first: AsmNode, - second: AsmNode -) : AsmNode() { - private val first: AsmNode = first.optimize() - private val second: AsmNode = second.optimize() + first: AsmExpression, + second: AsmExpression +) : AsmExpression() { + private val first: AsmExpression = first.optimize() + private val second: AsmExpression = second.optimize() - override fun tryEvaluate(): T? = context { + override fun tryEvaluate(): T? = algebra { binaryOperation( name, - first.tryEvaluate() ?: return@context null, - second.tryEvaluate() ?: return@context null + first.tryEvaluate() ?: return@algebra null, + second.tryEvaluate() ?: return@algebra null ) } - override fun compile(gen: AsmGenerationContext) { + override fun appendTo(gen: AsmBuilder) { gen.visitLoadAlgebra() - if (!hasSpecific(context, name, 2)) + if (!hasSpecific(algebra, name, 2)) gen.visitStringConstant(name) - first.compile(gen) - second.compile(gen) + first.appendTo(gen) + second.appendTo(gen) - if (gen.tryInvokeSpecific(context, name, 2)) + if (gen.tryInvokeSpecific(algebra, name, 2)) return gen.visitAlgebraOperation( - owner = AsmGenerationContext.ALGEBRA_CLASS, + owner = AsmBuilder.ALGEBRA_CLASS, method = "binaryOperation", - descriptor = "(L${AsmGenerationContext.STRING_CLASS};" + - "L${AsmGenerationContext.OBJECT_CLASS};" + - "L${AsmGenerationContext.OBJECT_CLASS};)" + - "L${AsmGenerationContext.OBJECT_CLASS};" + descriptor = "(L${AsmBuilder.STRING_CLASS};" + + "L${AsmBuilder.OBJECT_CLASS};" + + "L${AsmBuilder.OBJECT_CLASS};)" + + "L${AsmBuilder.OBJECT_CLASS};" ) } } -internal class AsmVariableExpression(private val name: String, private val default: T? = null) : - AsmNode() { - override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) +internal class AsmVariableExpression( + override val type: KClass, + override val algebra: Algebra, + private val name: String, + private val default: T? = null +) : AsmExpression() { + override fun appendTo(gen: AsmBuilder): Unit = gen.visitLoadFromVariables(name, default) } -internal class AsmConstantExpression(private val value: T) : - AsmNode() { +internal class AsmConstantExpression( + override val type: KClass, + override val algebra: Algebra, + private val value: T +) : AsmExpression() { override fun tryEvaluate(): T = value - override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) + override fun appendTo(gen: AsmBuilder): Unit = gen.visitLoadFromConstants(value) } -internal class AsmConstProductExpression( - private val context: Space, - expr: AsmNode, +internal class AsmConstProductExpression( + override val type: KClass, + override val algebra: Space, + expr: AsmExpression, private val const: Number -) : AsmNode() { - private val expr: AsmNode = expr.optimize() +) : AsmExpression() { + private val expr: AsmExpression = expr.optimize() - override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const } + override fun tryEvaluate(): T? = algebra { (expr.tryEvaluate() ?: return@algebra null) * const } - override fun compile(gen: AsmGenerationContext) { + override fun appendTo(gen: AsmBuilder) { gen.visitLoadAlgebra() gen.visitNumberConstant(const) - expr.compile(gen) + expr.appendTo(gen) gen.visitAlgebraOperation( - owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, + owner = AsmBuilder.SPACE_OPERATIONS_CLASS, method = "multiply", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};" + - "L${AsmGenerationContext.NUMBER_CLASS};)" + - "L${AsmGenerationContext.OBJECT_CLASS};" + descriptor = "(L${AsmBuilder.OBJECT_CLASS};" + + "L${AsmBuilder.NUMBER_CLASS};)" + + "L${AsmBuilder.OBJECT_CLASS};" ) } } -internal class AsmNumberExpression(private val context: NumericAlgebra, private val value: Number) : - AsmNode() { - override fun tryEvaluate(): T? = context.number(value) +internal class AsmNumberExpression( + override val type: KClass, + override val algebra: NumericAlgebra, + private val value: Number +) : AsmExpression() { + override fun tryEvaluate(): T? = algebra.number(value) - override fun compile(gen: AsmGenerationContext): Unit = gen.visitNumberConstant(value) + override fun appendTo(gen: AsmBuilder): Unit = gen.visitNumberConstant(value) } internal abstract class FunctionalCompiledExpression internal constructor( @@ -146,120 +177,116 @@ internal abstract class FunctionalCompiledExpression internal constructor( } /** - * A context class for [AsmNode] construction. + * A context class for [AsmExpression] construction. + * + * @param algebra The algebra to provide for AsmExpressions built. */ -interface AsmExpressionAlgebra> : NumericAlgebra>, - ExpressionAlgebra> { - /** - * The algebra to provide for AsmExpressions built. - */ - val algebra: A +open class AsmExpressionAlgebra>(val type: KClass, val algebra: A) : + NumericAlgebra>, ExpressionAlgebra> { /** * Builds an AsmExpression to wrap a number. */ - override fun number(value: Number): AsmNode = AsmNumberExpression(algebra, value) + override fun number(value: Number): AsmExpression = AsmNumberExpression(type, algebra, value) /** * Builds an AsmExpression of constant expression which does not depend on arguments. */ - override fun const(value: T): AsmNode = AsmConstantExpression(value) + override fun const(value: T): AsmExpression = AsmConstantExpression(type, algebra, value) /** * Builds an AsmExpression to access a variable. */ - override fun variable(name: String, default: T?): AsmNode = AsmVariableExpression(name, default) + override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(type, algebra, name, default) /** * Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right]. */ - override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = - AsmBinaryOperation(algebra, operation, left, right) + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(type, algebra, operation, left, right) /** * Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg]. */ - override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = - AsmUnaryOperation(algebra, operation, arg) + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(type, algebra, operation, arg) } /** - * A context class for [AsmNode] construction for [Space] algebras. + * A context class for [AsmExpression] construction for [Space] algebras. */ -open class AsmExpressionSpace(override val algebra: A) : AsmExpressionAlgebra, - Space> where A : Space, A : NumericAlgebra { - override val zero: AsmNode - get() = const(algebra.zero) +open class AsmExpressionSpace(type: KClass, algebra: A) : AsmExpressionAlgebra(type, algebra), + Space> where A : Space, A : NumericAlgebra { + override val zero: AsmExpression get() = const(algebra.zero) /** * Builds an AsmExpression of addition of two another expressions. */ - override fun add(a: AsmNode, b: AsmNode): AsmNode = - AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(type, algebra, SpaceOperations.PLUS_OPERATION, a, b) /** * Builds an AsmExpression of multiplication of expression by number. */ - override fun multiply(a: AsmNode, k: Number): AsmNode = AsmConstProductExpression(algebra, a, k) + override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(type, algebra, a, k) - operator fun AsmNode.plus(arg: T): AsmNode = this + const(arg) - operator fun AsmNode.minus(arg: T): AsmNode = this - const(arg) - operator fun T.plus(arg: AsmNode): AsmNode = arg + this - operator fun T.minus(arg: AsmNode): AsmNode = arg - this + operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) + operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) + operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this + operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this - override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = super.binaryOperation(operation, left, right) } /** - * A context class for [AsmNode] construction for [Ring] algebras. + * A context class for [AsmExpression] construction for [Ring] algebras. */ -open class AsmExpressionRing(override val algebra: A) : AsmExpressionSpace(algebra), - Ring> where A : Ring, A : NumericAlgebra { - override val one: AsmNode - get() = const(algebra.one) +open class AsmExpressionRing(type: KClass, algebra: A) : AsmExpressionSpace(type, algebra), + Ring> where A : Ring, A : NumericAlgebra { + override val one: AsmExpression get() = const(algebra.one) /** * Builds an AsmExpression of multiplication of two expressions. */ - override fun multiply(a: AsmNode, b: AsmNode): AsmNode = - AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) + override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(type, algebra, RingOperations.TIMES_OPERATION, a, b) - operator fun AsmNode.times(arg: T): AsmNode = this * const(arg) - operator fun T.times(arg: AsmNode): AsmNode = arg * this + operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) + operator fun T.times(arg: AsmExpression): AsmExpression = arg * this - override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = super.binaryOperation(operation, left, right) - override fun number(value: Number): AsmNode = super.number(value) + override fun number(value: Number): AsmExpression = super.number(value) } /** - * A context class for [AsmNode] construction for [Field] algebras. + * A context class for [AsmExpression] construction for [Field] algebras. */ -open class AsmExpressionField(override val algebra: A) : - AsmExpressionRing(algebra), - Field> where A : Field, A : NumericAlgebra { +open class AsmExpressionField(type: KClass, algebra: A) : + AsmExpressionRing(type, algebra), + Field> where A : Field, A : NumericAlgebra { /** * Builds an AsmExpression of division an expression by another one. */ - override fun divide(a: AsmNode, b: AsmNode): AsmNode = - AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) + override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(type, algebra, FieldOperations.DIV_OPERATION, a, b) - operator fun AsmNode.div(arg: T): AsmNode = this / const(arg) - operator fun T.div(arg: AsmNode): AsmNode = arg / this + operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) + operator fun T.div(arg: AsmExpression): AsmExpression = arg / this - override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = super.unaryOperation(operation, arg) - override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = super.binaryOperation(operation, left, right) - override fun number(value: Number): AsmNode = super.number(value) + override fun number(value: Number): AsmExpression = super.number(value) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt new file mode 100644 index 000000000..cc3e36e94 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -0,0 +1,47 @@ +package scientifik.kmath.asm + +import scientifik.kmath.ast.MST +import scientifik.kmath.ast.evaluate +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.NumericAlgebra +import scientifik.kmath.operations.Ring +import scientifik.kmath.operations.Space + +internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { + val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(expression, collision + 1) +} + +inline fun , E : AsmExpressionAlgebra> A.asm( + expressionAlgebra: E, + block: E.() -> AsmExpression +): Expression = expressionAlgebra.block() + +inline fun NumericAlgebra.asm(ast: MST): Expression = + AsmExpressionAlgebra(T::class, this).evaluate(ast) + +inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmExpression): Expression where A : NumericAlgebra, A : Space = + AsmExpressionSpace(T::class, this).block() + +inline fun A.asmSpace(ast: MST): Expression where A : NumericAlgebra, A : Space = + asmSpace { evaluate(ast) } + +inline fun A.asmRing(block: AsmExpressionRing.() -> AsmExpression): Expression where A : NumericAlgebra, A : Ring = + AsmExpressionRing(T::class, this).block() + +inline fun A.asmRing(ast: MST): Expression where A : NumericAlgebra, A : Ring = + asmRing { evaluate(ast) } + +inline fun A.asmField(block: AsmExpressionField.() -> AsmExpression): Expression where A : NumericAlgebra, A : Field = + AsmExpressionField(T::class, this).block() + +inline fun A.asmField(ast: MST): Expression where A : NumericAlgebra, A : Field = + asmRing { evaluate(ast) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt similarity index 82% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index fa0f06579..eefde66e4 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -5,30 +5,29 @@ import org.objectweb.asm.Label import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes import scientifik.kmath.asm.FunctionalCompiledExpression -import scientifik.kmath.asm.internal.AsmGenerationContext.ClassLoader +import scientifik.kmath.asm.internal.AsmBuilder.AsmClassLoader import scientifik.kmath.operations.Algebra /** * AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java - * expression. This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new + * expression. This class uses [AsmClassLoader] for loading the generated class, then it is able to instantiate the new * class. * * @param T the type of AsmExpression to unwrap. * @param algebra the algebra the applied AsmExpressions use. * @param className the unique class name of new loaded class. */ -@PublishedApi -internal class AsmGenerationContext @PublishedApi internal constructor( +internal class AsmBuilder( private val classOfT: Class<*>, private val algebra: Algebra, private val className: String ) { - private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { + private class AsmClassLoader(parent: ClassLoader) : ClassLoader(parent) { internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) } - private val classLoader: ClassLoader = - ClassLoader(javaClass.classLoader) + private val classLoader: AsmClassLoader = + AsmClassLoader(javaClass.classLoader) @Suppress("PrivatePropertyName") private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') @@ -113,9 +112,8 @@ internal class AsmGenerationContext @PublishedApi internal constructor( } } - @PublishedApi @Suppress("UNCHECKED_CAST") - internal fun generate(): FunctionalCompiledExpression { + fun generate(): FunctionalCompiledExpression { generatedInstance?.let { return it } invokeMethodVisitor.run { @@ -188,7 +186,7 @@ internal class AsmGenerationContext @PublishedApi internal constructor( return new } - internal fun visitLoadFromConstants(value: T) { + fun visitLoadFromConstants(value: T) { if (classOfT in INLINABLE_NUMBERS) { visitNumberConstant(value as Number) visitCastToT() @@ -213,7 +211,7 @@ internal class AsmGenerationContext @PublishedApi internal constructor( private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) - internal fun visitNumberConstant(value: Number) { + fun visitNumberConstant(value: Number) { maxStack++ val clazz = value.javaClass val c = clazz.name.replace('.', '/') @@ -234,7 +232,7 @@ internal class AsmGenerationContext @PublishedApi internal constructor( visitLoadAnyFromConstants(value, c) } - internal fun visitLoadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { + fun visitLoadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { maxStack += 2 visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar) @@ -262,7 +260,7 @@ internal class AsmGenerationContext @PublishedApi internal constructor( visitCastToT() } - internal fun visitLoadAlgebra() { + fun visitLoadAlgebra() { maxStack++ invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) @@ -274,7 +272,7 @@ internal class AsmGenerationContext @PublishedApi internal constructor( invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS) } - internal fun visitAlgebraOperation( + fun visitAlgebraOperation( owner: String, method: String, descriptor: String, @@ -288,32 +286,28 @@ internal class AsmGenerationContext @PublishedApi internal constructor( private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS) - internal fun visitStringConstant(string: String) { + fun visitStringConstant(string: String) { invokeMethodVisitor.visitLdcInsn(string) } - internal companion object { - private val SIGNATURE_LETTERS by lazy { - mapOf( - java.lang.Byte::class.java to "B", - java.lang.Short::class.java to "S", - java.lang.Integer::class.java to "I", - java.lang.Long::class.java to "J", - java.lang.Float::class.java to "F", - java.lang.Double::class.java to "D" - ) - } + companion object { + private val SIGNATURE_LETTERS = mapOf( + java.lang.Byte::class.java to "B", + java.lang.Short::class.java to "S", + java.lang.Integer::class.java to "I", + java.lang.Long::class.java to "J", + java.lang.Float::class.java to "F", + java.lang.Double::class.java to "D" + ) - private val INLINABLE_NUMBERS by lazy { SIGNATURE_LETTERS.keys } + private val INLINABLE_NUMBERS = SIGNATURE_LETTERS.keys - internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = - "scientifik/kmath/asm/FunctionalCompiledExpression" - - internal const val MAP_CLASS = "java/util/Map" - internal const val OBJECT_CLASS = "java/lang/Object" - internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" - internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" - internal const val STRING_CLASS = "java/lang/String" - internal const val NUMBER_CLASS = "java/lang/Number" + const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/asm/FunctionalCompiledExpression" + const val MAP_CLASS = "java/util/Map" + const val OBJECT_CLASS = "java/lang/Object" + const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" + const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" + const val STRING_CLASS = "java/lang/String" + const val NUMBER_CLASS = "java/lang/Number" } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt index 815f292ce..d9f950ceb 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt @@ -2,7 +2,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.Opcodes import scientifik.kmath.asm.AsmConstantExpression -import scientifik.kmath.asm.AsmNode +import scientifik.kmath.asm.AsmExpression import scientifik.kmath.operations.Algebra private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") @@ -16,7 +16,7 @@ internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boo return true } -internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { +internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name context::class.java.methods.find { it.name == aName && it.parameters.size == arity } @@ -26,9 +26,9 @@ internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, val sig = buildString { append('(') - repeat(arity) { append("L${AsmGenerationContext.OBJECT_CLASS};") } + repeat(arity) { append("L${AsmBuilder.OBJECT_CLASS};") } append(')') - append("L${AsmGenerationContext.OBJECT_CLASS};") + append("L${AsmBuilder.OBJECT_CLASS};") } visitAlgebraOperation( @@ -42,8 +42,7 @@ internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, return true } -@PublishedApi -internal fun AsmNode.optimize(): AsmNode { +internal fun AsmExpression.optimize(): AsmExpression { val a = tryEvaluate() - return if (a == null) this else AsmConstantExpression(a) + return if (a == null) this else AsmConstantExpression(type, algebra, a) } diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt similarity index 98% rename from kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt rename to kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index 07f895c40..06f776597 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -1,4 +1,4 @@ -package scietifik.kmath.ast.asm +package scietifik.kmath.asm import scientifik.kmath.asm.asmField import scientifik.kmath.asm.asmRing diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt similarity index 94% rename from kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt rename to kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt index 2839b8a25..40d990537 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmExpressions.kt @@ -1,4 +1,4 @@ -package scietifik.kmath.ast.asm +package scietifik.kmath.asm import scientifik.kmath.asm.asmField import scientifik.kmath.expressions.invoke diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt similarity index 100% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt index b933f6a17..9eae60efc 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/expressions/ExpressionFieldTest.kt @@ -31,7 +31,7 @@ class ExpressionFieldTest { @Test fun separateContext() { - fun FunctionalExpressionField.expression(): Expression { + fun FunctionalExpressionField.expression(): Expression { val x = variable("x") return x * x + 2 * x + one } @@ -42,7 +42,7 @@ class ExpressionFieldTest { @Test fun valueExpression() { - val expressionBuilder: FunctionalExpressionField.() -> Expression = { + val expressionBuilder: FunctionalExpressionField.() -> Expression = { val x = variable("x") x * x + 2 * x + one }