diff --git a/kmath-ast/README.md b/kmath-ast/README.md index f9dbd4663..b5ca5886f 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -1,4 +1,61 @@ -# AST based expression representation and operations +# AST-based expression representation and operations (`kmath-ast`) -## Dynamic expression code generation -Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). \ No newline at end of file +This subproject implements the following features: + +- Expression Language and its parser. +- MST as expression language's syntax intermediate representation. +- Type-safe builder of MST. +- Evaluating expressions by traversing MST. + +## Dynamic expression code generation with OW2 ASM + +`kmath-ast` JVM module supports runtime code generation to eliminate overhead of tree traversal. Code generator builds +a special implementation of `Expression` with implemented `invoke` function. + +For example, the following builder: + +```kotlin + RealField.mstInField { symbol("x") + 2 }.compile() +``` + +… leads to generation of bytecode, which can be decompiled to the following Java class: + +```java +package scientifik.kmath.asm.generated; + +import java.util.Map; +import scientifik.kmath.asm.internal.AsmCompiledExpression; +import scientifik.kmath.operations.Algebra; +import scientifik.kmath.operations.RealField; + +// The class's name is build with MST's hash-code and collision fixing number. +public final class AsmCompiledExpression_45045_0 extends AsmCompiledExpression { + // Plain constructor + public AsmCompiledExpression_45045_0(Algebra algebra, Object[] constants) { + super(algebra, constants); + } + + // The actual dynamic code: + public final Double invoke(Map arguments) { + return (Double)((RealField)super.algebra).add((Double)arguments.get("x"), (Double)2.0D); + } +} +``` + +### Example Usage + +This API is an extension to MST and MSTExpression APIs. You may optimize both MST and MSTExpression: + +```kotlin +RealField.mstInField { symbol("x") + 2 }.compile() +RealField.expression("2+2".parseMath()) +``` + +### Known issues + +- Using numeric algebras causes boxing and calling bridge methods. +- The same classes may be generated and loaded twice, so it is recommended to cache compiled expressions to avoid +class loading overhead. +- This API is not supported by non-dynamic JVM implementations (like TeaVM and GraalVM) because of using class loaders. + +Contributed by [Iaroslav Postovalov](https://github.com/CommanderTvis). diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index 2345757f1..a09a7ebe3 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -20,7 +20,8 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< is MST.Symbolic -> loadVariable(node.value) is MST.Numeric -> { val constant = if (algebra is NumericAlgebra) - algebra.number(node.value) else + algebra.number(node.value) + else error("Number literals are not supported in $algebra") loadTConstant(constant) @@ -72,9 +73,13 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< /** * Compile an [MST] to ASM using given algebra */ -inline fun Algebra.expresion(mst: MST): Expression = mst.compileWith(T::class, this) +inline fun Algebra.expression(mst: MST): Expression = mst.compileWith(T::class, this) /** * Optimize performance of an [MSTExpression] using ASM codegen */ -inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) \ No newline at end of file +inline fun MSTExpression.compile(): Expression = mst.compileWith(T::class, algebra) + +fun main() { + RealField.mstInField { symbol("x") + 2 }.compile() +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 351d5b754..77a4ee9e9 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -6,7 +6,6 @@ import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.operations.Algebra -import java.io.File /** * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. @@ -22,10 +21,16 @@ internal class AsmBuilder internal constructor( private val className: String, private val invokeLabel0Visitor: AsmBuilder.() -> Unit ) { + /** + * Internal classloader of [AsmBuilder] with alias to define class from byte array. + */ private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) } + /** + * The instance of [ClassLoader] used by this builder. + */ private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) @@ -36,16 +41,41 @@ internal class AsmBuilder internal constructor( private val T_CLASS: String = classOfT.name.replace('.', '/') private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') + + /** + * Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass. + */ private val invokeThisVar: Int = 0 + + /** + * Index of `arguments` variable in invoke method of [AsmCompiledExpression] built subclass. + */ private val invokeArgumentsVar: Int = 1 + + /** + * List of constants to provide to [AsmCompiledExpression] subclass. + */ private val constants: MutableList = mutableListOf() + + /** + * Method visitor of `invoke` method of [AsmCompiledExpression] subclass. + */ private lateinit var invokeMethodVisitor: MethodVisitor private var primitiveMode = false internal var primitiveType = OBJECT_CLASS internal var primitiveTypeSig = "L$OBJECT_CLASS;" internal var primitiveTypeReturnSig = "L$OBJECT_CLASS;" + + /** + * The cache of [AsmCompiledExpression] subclass built by this builder. + */ private var generatedInstance: AsmCompiledExpression? = null + /** + * Subclasses, loads and instantiates the [AsmCompiledExpression] for given parameters. + * + * The built instance is cached. + */ @Suppress("UNCHECKED_CAST") fun getInstance(): AsmCompiledExpression { generatedInstance?.let { return it } @@ -187,8 +217,6 @@ internal class AsmBuilder internal constructor( visitEnd() } - File("dump.class").writeBytes(classWriter.toByteArray()) - val new = classLoader .defineClass(className, classWriter.toByteArray()) .constructors @@ -199,6 +227,9 @@ internal class AsmBuilder internal constructor( return new } + /** + * Loads a constant from + */ internal fun loadTConstant(value: T) { if (primitiveMode) { loadNumberConstant(value as Number) @@ -214,6 +245,9 @@ internal class AsmBuilder internal constructor( loadConstant(value as Any, T_CLASS) } + /** + * Loads an object constant [value] stored in [AsmCompiledExpression.constants] and casts it to [type]. + */ private fun unboxInPrimitiveMode() { if (!primitiveMode) return @@ -241,13 +275,17 @@ internal class AsmBuilder internal constructor( visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;") visitLdcOrIntConstant(idx) visitGetObjectArrayElement() - visitCheckCast(type) - unboxInPrimitiveMode() + invokeMethodVisitor.visitCheckCast(type) } } private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) + /** + * Either loads a numeric constant [value] from [AsmCompiledExpression.constants] field or boxes a primitive + * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded + * from it). + */ private fun loadNumberConstant(value: Number) { val clazz = value.javaClass val c = clazz.name.replace('.', '/') @@ -255,9 +293,12 @@ internal class AsmBuilder internal constructor( if (sigLetter != null) { when (value) { + is Byte -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) + is Short -> invokeMethodVisitor.visitLdcOrIntConstant(value.toInt()) is Int -> invokeMethodVisitor.visitLdcOrIntConstant(value) is Double -> invokeMethodVisitor.visitLdcOrDoubleConstant(value) is Float -> invokeMethodVisitor.visitLdcOrFloatConstant(value) + is Long -> invokeMethodVisitor.visitLdcOrLongConstant(value) else -> invokeMethodVisitor.visitLdcInsn(value) } @@ -270,6 +311,9 @@ internal class AsmBuilder internal constructor( loadConstant(value, c) } + /** + * Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided. + */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { visitLoadObjectVar(invokeArgumentsVar) @@ -299,6 +343,9 @@ internal class AsmBuilder internal constructor( unboxInPrimitiveMode() } + /** + * Loads algebra from according field of [AsmCompiledExpression] and casts it to class of [algebra] provided. + */ internal fun loadAlgebra() { loadThis() @@ -311,6 +358,13 @@ internal class AsmBuilder internal constructor( invokeMethodVisitor.visitCheckCast(T_ALGEBRA_CLASS) } + /** + * Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is + * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interface. [loadAlgebra] should be + * called before the arguments and this operation. + * + * The result is casted to [T] automatically. + */ internal fun invokeAlgebraOperation( owner: String, method: String, @@ -323,9 +377,15 @@ internal class AsmBuilder internal constructor( unboxInPrimitiveMode() } + /** + * Writes a LDC Instruction with string constant provided. + */ internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) internal companion object { + /** + * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. + */ private val SIGNATURE_LETTERS: Map, String> by lazy { mapOf( java.lang.Byte::class.java to "B", @@ -348,6 +408,9 @@ internal class AsmBuilder internal constructor( ) } + /** + * Provides boxed number types values of which can be stored in JVM bytecode constant pool. + */ private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt index 0e113099a..a5236927c 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt @@ -3,6 +3,13 @@ package scientifik.kmath.asm.internal import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra +/** + * [Expression] partial implementation to have it subclassed by actual implementations. Provides unified storage for + * objects needed to implement the expression. + * + * @property algebra the algebra to delegate calls. + * @property constants the constants array to have persistent objects to reference in [invoke]. + */ internal abstract class AsmCompiledExpression internal constructor( @JvmField protected val algebra: Algebra, @JvmField protected val constants: Array diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt index 65652d72b..66bd039c3 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt @@ -2,7 +2,13 @@ package scientifik.kmath.asm.internal import scientifik.kmath.ast.MST -internal fun buildName(mst: MST, collision: Int = 0): String { +/** + * Creates a class name for [AsmCompiledExpression] subclassed to implement [mst] provided. + * + * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there + * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. + */ +internal tailrec fun buildName(mst: MST, collision: Int = 0): String { val name = "scientifik.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" try { diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt index 70fb3bd44..6ecce8833 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt @@ -11,6 +11,8 @@ internal fun MethodVisitor.visitLdcOrIntConstant(value: Int): Unit = when (value 3 -> visitInsn(ICONST_3) 4 -> visitInsn(ICONST_4) 5 -> visitInsn(ICONST_5) + in -128..127 -> visitIntInsn(BIPUSH, value) + in -32768..32767 -> visitIntInsn(SIPUSH, value) else -> visitLdcInsn(value) } @@ -20,6 +22,12 @@ internal fun MethodVisitor.visitLdcOrDoubleConstant(value: Double): Unit = when else -> visitLdcInsn(value) } +internal fun MethodVisitor.visitLdcOrLongConstant(value: Long): Unit = when (value) { + 0L -> visitInsn(LCONST_0) + 1L -> visitInsn(LCONST_1) + else -> visitLdcInsn(value) +} + internal fun MethodVisitor.visitLdcOrFloatConstant(value: Float): Unit = when (value) { 0f -> visitInsn(FCONST_0) 1f -> visitInsn(FCONST_1) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index f279c94f4..51d056f8c 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -3,21 +3,38 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.Opcodes import scientifik.kmath.operations.Algebra -private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") +private val methodNameAdapters: Map by lazy { + hashMapOf( + "+" to "add", + "*" to "multiply", + "/" to "divide" + ) +} +/** + * Checks if the target [context] for code generation contains a method with needed [name] and [arity]. + * + * @return `true` if contains, else `false`. + */ internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name - context::class.java.methods.find { it.name == aName && it.parameters.size == arity } + context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } ?: return false return true } +/** + * Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts + * [AsmBuilder.invokeAlgebraOperation] of this method. + * + * @return `true` if contains, else `false`. + */ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { val aName = methodNameAdapters[name] ?: name - context::class.java.methods.find { it.name == aName && it.parameters.size == arity } + context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } ?: return false val owner = context::class.java.name.replace('.', '/') @@ -33,8 +50,7 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri owner = owner, method = aName, descriptor = sig, - opcode = Opcodes.INVOKEVIRTUAL, - isInterface = false + opcode = Opcodes.INVOKEVIRTUAL ) return true diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt index c8d6e8eb0..a8a26aa33 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -36,12 +36,10 @@ internal class FunctionalConstProductExpression( /** * A context class for [Expression] construction. + * + * @param algebra The algebra to provide for Expressions built. */ -interface FunctionalExpressionAlgebra> : ExpressionAlgebra> { - /** - * The algebra to provide for Expressions built. - */ - val algebra: A +abstract class FunctionalExpressionAlgebra>(val algebra: A) : ExpressionAlgebra> { /** * Builds an Expression of constant expression which does not depend on arguments. @@ -69,8 +67,8 @@ interface FunctionalExpressionAlgebra> : ExpressionAlgebra>(override val algebra: A) : - FunctionalExpressionAlgebra, Space> { +open class FunctionalExpressionSpace>(algebra: A) : + FunctionalExpressionAlgebra(algebra), Space> { override val zero: Expression get() = const(algebra.zero) @@ -98,7 +96,7 @@ open class FunctionalExpressionSpace>(override val algebra: A) : super.binaryOperation(operation, left, right) } -open class FunctionalExpressionRing(override val algebra: A) : FunctionalExpressionSpace(algebra), +open class FunctionalExpressionRing(algebra: A) : FunctionalExpressionSpace(algebra), Ring> where A : Ring, A : NumericAlgebra { override val one: Expression get() = const(algebra.one) @@ -119,7 +117,7 @@ open class FunctionalExpressionRing(override val algebra: A) : FunctionalE super.binaryOperation(operation, left, right) } -open class FunctionalExpressionField(override val algebra: A) : +open class FunctionalExpressionField(algebra: A) : FunctionalExpressionRing(algebra), Field> where A : Field, A : NumericAlgebra { /** @@ -137,3 +135,12 @@ open class FunctionalExpressionField(override val algebra: A) : override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = super.binaryOperation(operation, left, right) } + +inline fun > A.expressionInSpace(block: FunctionalExpressionSpace.() -> Expression): Expression = + FunctionalExpressionSpace(this).block() + +inline fun > A.expressionInRing(block: FunctionalExpressionRing.() -> Expression): Expression = + FunctionalExpressionRing(this).block() + +inline fun > A.expressionInField(block: FunctionalExpressionField.() -> Expression): Expression = + FunctionalExpressionField(this).block() \ No newline at end of file