diff --git a/kmath-ast/build.gradle.kts b/kmath-ast/build.gradle.kts index 65a1ac12c..9764fccc3 100644 --- a/kmath-ast/build.gradle.kts +++ b/kmath-ast/build.gradle.kts @@ -2,7 +2,7 @@ plugins { id("scientifik.mpp") } -repositories{ +repositories { maven("https://dl.bintray.com/hotkeytlt/maven") } @@ -14,13 +14,18 @@ kotlin.sourceSets { implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3") } } - jvmMain{ - dependencies{ + + jvmMain { + dependencies { implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3") + implementation("org.ow2.asm:asm:8.0.1") + implementation("org.ow2.asm:asm-commons:8.0.1") + implementation(kotlin("reflect")) } } - jsMain{ - dependencies{ + + jsMain { + dependencies { implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3") } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt new file mode 100644 index 000000000..69e49f8b6 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmBuilders.kt @@ -0,0 +1,56 @@ +package scientifik.kmath.asm + +import scientifik.kmath.asm.internal.AsmGenerationContext +import scientifik.kmath.ast.MST +import scientifik.kmath.ast.evaluate +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.* + +@PublishedApi +internal fun buildName(expression: AsmNode<*>, collision: Int = 0): String { + val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(expression, collision + 1) +} + +@PublishedApi +internal inline fun AsmNode.compile(algebra: Algebra): Expression { + val ctx = + AsmGenerationContext(T::class.java, algebra, buildName(this)) + compile(ctx) + return ctx.generate() +} + +inline fun , E : AsmExpressionAlgebra> A.asm( + expressionAlgebra: E, + block: E.() -> AsmNode +): Expression = expressionAlgebra.block().compile(expressionAlgebra.algebra) + +inline fun , E : AsmExpressionAlgebra> A.asm( + expressionAlgebra: E, + ast: MST +): Expression = asm(expressionAlgebra) { evaluate(ast) } + +inline fun A.asmSpace(block: AsmExpressionSpace.() -> AsmNode): Expression where A : NumericAlgebra, A : Space = + AsmExpressionSpace(this).let { it.block().compile(it.algebra) } + +inline fun A.asmSpace(ast: MST): Expression where A : NumericAlgebra, A : Space = + asmSpace { evaluate(ast) } + +inline fun A.asmRing(block: AsmExpressionRing.() -> AsmNode): Expression where A : NumericAlgebra, A : Ring = + AsmExpressionRing(this).let { it.block().compile(it.algebra) } + +inline fun A.asmRing(ast: MST): Expression where A : NumericAlgebra, A : Ring = + asmRing { evaluate(ast) } + +inline fun A.asmField(block: AsmExpressionField.() -> AsmNode): Expression where A : NumericAlgebra, A : Field = + AsmExpressionField(this).let { it.block().compile(it.algebra) } + +inline fun A.asmField(ast: MST): Expression where A : NumericAlgebra, A : Field = + asmRing { evaluate(ast) } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt new file mode 100644 index 000000000..8f16a2710 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt @@ -0,0 +1,265 @@ +package scientifik.kmath.asm + +import scientifik.kmath.asm.internal.AsmGenerationContext +import scientifik.kmath.asm.internal.hasSpecific +import scientifik.kmath.asm.internal.optimize +import scientifik.kmath.asm.internal.tryInvokeSpecific +import scientifik.kmath.expressions.Expression +import scientifik.kmath.expressions.ExpressionAlgebra +import scientifik.kmath.operations.* + +/** + * A function declaration that could be compiled to [AsmGenerationContext]. + * + * @param T the type the stored function returns. + */ +abstract class AsmNode internal constructor() { + /** + * Tries to evaluate this function without its variables. This method is intended for optimization. + * + * @return `null` if the function depends on its variables, the value if the function is a constant. + */ + internal open fun tryEvaluate(): T? = null + + /** + * Compiles this declaration. + * + * @param gen the target [AsmGenerationContext]. + */ + @PublishedApi + internal abstract fun compile(gen: AsmGenerationContext) +} + +internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmNode) : + AsmNode() { + private val expr: AsmNode = expr.optimize() + override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) } + + override fun compile(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + + if (!hasSpecific(context, name, 1)) + gen.visitStringConstant(name) + + expr.compile(gen) + + if (gen.tryInvokeSpecific(context, name, 1)) + return + + gen.visitAlgebraOperation( + owner = AsmGenerationContext.ALGEBRA_CLASS, + method = "unaryOperation", + descriptor = "(L${AsmGenerationContext.STRING_CLASS};" + + "L${AsmGenerationContext.OBJECT_CLASS};)" + + "L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmBinaryOperation( + private val context: Algebra, + private val name: String, + first: AsmNode, + second: AsmNode +) : AsmNode() { + private val first: AsmNode = first.optimize() + private val second: AsmNode = second.optimize() + + override fun tryEvaluate(): T? = context { + binaryOperation( + name, + first.tryEvaluate() ?: return@context null, + second.tryEvaluate() ?: return@context null + ) + } + + override fun compile(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + + if (!hasSpecific(context, name, 2)) + gen.visitStringConstant(name) + + first.compile(gen) + second.compile(gen) + + if (gen.tryInvokeSpecific(context, name, 2)) + return + + gen.visitAlgebraOperation( + owner = AsmGenerationContext.ALGEBRA_CLASS, + method = "binaryOperation", + descriptor = "(L${AsmGenerationContext.STRING_CLASS};" + + "L${AsmGenerationContext.OBJECT_CLASS};" + + "L${AsmGenerationContext.OBJECT_CLASS};)" + + "L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmVariableExpression(private val name: String, private val default: T? = null) : + AsmNode() { + override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) +} + +internal class AsmConstantExpression(private val value: T) : + AsmNode() { + override fun tryEvaluate(): T = value + override fun compile(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) +} + +internal class AsmConstProductExpression( + private val context: Space, + expr: AsmNode, + private val const: Number +) : AsmNode() { + private val expr: AsmNode = expr.optimize() + + override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const } + + override fun compile(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + gen.visitNumberConstant(const) + expr.compile(gen) + + gen.visitAlgebraOperation( + owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, + method = "multiply", + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};" + + "L${AsmGenerationContext.NUMBER_CLASS};)" + + "L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmNumberExpression(private val context: NumericAlgebra, private val value: Number) : + AsmNode() { + override fun tryEvaluate(): T? = context.number(value) + + override fun compile(gen: AsmGenerationContext): Unit = gen.visitNumberConstant(value) +} + +internal abstract class FunctionalCompiledExpression internal constructor( + @JvmField protected val algebra: Algebra, + @JvmField protected val constants: Array +) : Expression { + abstract override fun invoke(arguments: Map): T +} + +/** + * A context class for [AsmNode] construction. + */ +interface AsmExpressionAlgebra> : NumericAlgebra>, + ExpressionAlgebra> { + /** + * The algebra to provide for AsmExpressions built. + */ + val algebra: A + + /** + * Builds an AsmExpression to wrap a number. + */ + override fun number(value: Number): AsmNode = AsmNumberExpression(algebra, value) + + /** + * Builds an AsmExpression of constant expression which does not depend on arguments. + */ + override fun const(value: T): AsmNode = AsmConstantExpression(value) + + /** + * Builds an AsmExpression to access a variable. + */ + override fun variable(name: String, default: T?): AsmNode = AsmVariableExpression(name, default) + + /** + * Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right]. + */ + override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + AsmBinaryOperation(algebra, operation, left, right) + + /** + * Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg]. + */ + override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + AsmUnaryOperation(algebra, operation, arg) +} + +/** + * A context class for [AsmNode] construction for [Space] algebras. + */ +open class AsmExpressionSpace(override val algebra: A) : AsmExpressionAlgebra, + Space> where A : Space, A : NumericAlgebra { + override val zero: AsmNode + get() = const(algebra.zero) + + /** + * Builds an AsmExpression of addition of two another expressions. + */ + override fun add(a: AsmNode, b: AsmNode): AsmNode = + AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) + + /** + * Builds an AsmExpression of multiplication of expression by number. + */ + override fun multiply(a: AsmNode, k: Number): AsmNode = AsmConstProductExpression(algebra, a, k) + + operator fun AsmNode.plus(arg: T): AsmNode = this + const(arg) + operator fun AsmNode.minus(arg: T): AsmNode = this - const(arg) + operator fun T.plus(arg: AsmNode): AsmNode = arg + this + operator fun T.minus(arg: AsmNode): AsmNode = arg - this + + override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + super.binaryOperation(operation, left, right) +} + +/** + * A context class for [AsmNode] construction for [Ring] algebras. + */ +open class AsmExpressionRing(override val algebra: A) : AsmExpressionSpace(algebra), + Ring> where A : Ring, A : NumericAlgebra { + override val one: AsmNode + get() = const(algebra.one) + + /** + * Builds an AsmExpression of multiplication of two expressions. + */ + override fun multiply(a: AsmNode, b: AsmNode): AsmNode = + AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) + + operator fun AsmNode.times(arg: T): AsmNode = this * const(arg) + operator fun T.times(arg: AsmNode): AsmNode = arg * this + + override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + super.binaryOperation(operation, left, right) + + override fun number(value: Number): AsmNode = super.number(value) +} + +/** + * A context class for [AsmNode] construction for [Field] algebras. + */ +open class AsmExpressionField(override val algebra: A) : + AsmExpressionRing(algebra), + Field> where A : Field, A : NumericAlgebra { + /** + * Builds an AsmExpression of division an expression by another one. + */ + override fun divide(a: AsmNode, b: AsmNode): AsmNode = + AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) + + operator fun AsmNode.div(arg: T): AsmNode = this / const(arg) + operator fun T.div(arg: AsmNode): AsmNode = arg / this + + override fun unaryOperation(operation: String, arg: AsmNode): AsmNode = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: AsmNode, right: AsmNode): AsmNode = + super.binaryOperation(operation, left, right) + + override fun number(value: Number): AsmNode = super.number(value) +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt new file mode 100644 index 000000000..fa0f06579 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmGenerationContext.kt @@ -0,0 +1,319 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.Label +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.Opcodes +import scientifik.kmath.asm.FunctionalCompiledExpression +import scientifik.kmath.asm.internal.AsmGenerationContext.ClassLoader +import scientifik.kmath.operations.Algebra + +/** + * AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java + * expression. This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new + * class. + * + * @param T the type of AsmExpression to unwrap. + * @param algebra the algebra the applied AsmExpressions use. + * @param className the unique class name of new loaded class. + */ +@PublishedApi +internal class AsmGenerationContext @PublishedApi internal constructor( + private val classOfT: Class<*>, + private val algebra: Algebra, + private val className: String +) { + private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { + internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) + } + + private val classLoader: ClassLoader = + ClassLoader(javaClass.classLoader) + + @Suppress("PrivatePropertyName") + private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') + + @Suppress("PrivatePropertyName") + private val T_CLASS: String = classOfT.name.replace('.', '/') + + private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') + private val invokeThisVar: Int = 0 + private val invokeArgumentsVar: Int = 1 + private var maxStack: Int = 0 + private val constants: MutableList = mutableListOf() + private val asmCompiledClassWriter: ClassWriter = ClassWriter(0) + private val invokeMethodVisitor: MethodVisitor + private val invokeL0: Label + private lateinit var invokeL1: Label + private var generatedInstance: FunctionalCompiledExpression? = null + + init { + asmCompiledClassWriter.visit( + Opcodes.V1_8, + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, + slashesClassName, + "L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS;", + FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + arrayOf() + ) + + asmCompiledClassWriter.run { + visitMethod(Opcodes.ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", null, null).run { + val thisVar = 0 + val algebraVar = 1 + val constantsVar = 2 + val l0 = Label() + visitLabel(l0) + visitVarInsn(Opcodes.ALOAD, thisVar) + visitVarInsn(Opcodes.ALOAD, algebraVar) + visitVarInsn(Opcodes.ALOAD, constantsVar) + + visitMethodInsn( + Opcodes.INVOKESPECIAL, + FUNCTIONAL_COMPILED_EXPRESSION_CLASS, + "", + "(L$ALGEBRA_CLASS;[L$OBJECT_CLASS;)V", + false + ) + + val l1 = Label() + visitLabel(l1) + visitInsn(Opcodes.RETURN) + val l2 = Label() + visitLabel(l2) + visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) + + visitLocalVariable( + "algebra", + "L$ALGEBRA_CLASS;", + "L$ALGEBRA_CLASS;", + l0, + l2, + algebraVar + ) + + visitLocalVariable("constants", "[L$OBJECT_CLASS;", null, l0, l2, constantsVar) + visitMaxs(3, 3) + visitEnd() + } + + invokeMethodVisitor = visitMethod( + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, + "invoke", + "(L$MAP_CLASS;)L$T_CLASS;", + "(L$MAP_CLASS;)L$T_CLASS;", + null + ) + + invokeMethodVisitor.run { + visitCode() + invokeL0 = Label() + visitLabel(invokeL0) + } + } + } + + @PublishedApi + @Suppress("UNCHECKED_CAST") + internal fun generate(): FunctionalCompiledExpression { + generatedInstance?.let { return it } + + invokeMethodVisitor.run { + visitInsn(Opcodes.ARETURN) + invokeL1 = Label() + visitLabel(invokeL1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + invokeL0, + invokeL1, + invokeThisVar + ) + + visitLocalVariable( + "arguments", + "L$MAP_CLASS;", + "L$MAP_CLASS;", + invokeL0, + invokeL1, + invokeArgumentsVar + ) + + visitMaxs(maxStack + 1, 2) + visitEnd() + } + + asmCompiledClassWriter.visitMethod( + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + "invoke", + "(L$MAP_CLASS;)L$OBJECT_CLASS;", + null, + null + ).run { + val thisVar = 0 + visitCode() + val l0 = Label() + visitLabel(l0) + visitVarInsn(Opcodes.ALOAD, 0) + visitVarInsn(Opcodes.ALOAD, 1) + visitMethodInsn(Opcodes.INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false) + visitInsn(Opcodes.ARETURN) + val l1 = Label() + visitLabel(l1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + l0, + l1, + thisVar + ) + + visitMaxs(2, 2) + visitEnd() + } + + asmCompiledClassWriter.visitEnd() + + val new = classLoader + .defineClass(className, asmCompiledClassWriter.toByteArray()) + .constructors + .first() + .newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression + + generatedInstance = new + return new + } + + internal fun visitLoadFromConstants(value: T) { + if (classOfT in INLINABLE_NUMBERS) { + visitNumberConstant(value as Number) + visitCastToT() + return + } + + visitLoadAnyFromConstants(value as Any, T_CLASS) + } + + private fun visitLoadAnyFromConstants(value: Any, type: String) { + val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex + maxStack++ + + invokeMethodVisitor.run { + visitLoadThis() + visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "[L$OBJECT_CLASS;") + visitLdcOrIConstInsn(idx) + visitInsn(Opcodes.AALOAD) + invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type) + } + } + + private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) + + internal fun visitNumberConstant(value: Number) { + maxStack++ + val clazz = value.javaClass + val c = clazz.name.replace('.', '/') + val sigLetter = SIGNATURE_LETTERS[clazz] + + if (sigLetter != null) { + when (value) { + is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value) + is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value) + is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value) + else -> invokeMethodVisitor.visitLdcInsn(value) + } + + invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) + return + } + + visitLoadAnyFromConstants(value, c) + } + + internal fun visitLoadFromVariables(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { + maxStack += 2 + visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar) + + if (defaultValue != null) { + visitLdcInsn(name) + visitLoadFromConstants(defaultValue) + + visitMethodInsn( + Opcodes.INVOKEINTERFACE, + MAP_CLASS, + "getOrDefault", + "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;", + true + ) + + visitCastToT() + return + } + + visitLdcInsn(name) + visitMethodInsn( + Opcodes.INVOKEINTERFACE, + MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true + ) + visitCastToT() + } + + internal fun visitLoadAlgebra() { + maxStack++ + invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) + + invokeMethodVisitor.visitFieldInsn( + Opcodes.GETFIELD, + FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;" + ) + + invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS) + } + + internal fun visitAlgebraOperation( + owner: String, + method: String, + descriptor: String, + opcode: Int = Opcodes.INVOKEINTERFACE, + isInterface: Boolean = true + ) { + maxStack++ + invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) + visitCastToT() + } + + private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS) + + internal fun visitStringConstant(string: String) { + invokeMethodVisitor.visitLdcInsn(string) + } + + internal companion object { + private val SIGNATURE_LETTERS by lazy { + mapOf( + java.lang.Byte::class.java to "B", + java.lang.Short::class.java to "S", + java.lang.Integer::class.java to "I", + java.lang.Long::class.java to "J", + java.lang.Float::class.java to "F", + java.lang.Double::class.java to "D" + ) + } + + private val INLINABLE_NUMBERS by lazy { SIGNATURE_LETTERS.keys } + + internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = + "scientifik/kmath/asm/FunctionalCompiledExpression" + + internal const val MAP_CLASS = "java/util/Map" + internal const val OBJECT_CLASS = "java/lang/Object" + internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" + internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" + internal const val STRING_CLASS = "java/lang/String" + internal const val NUMBER_CLASS = "java/lang/Number" + } +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt new file mode 100644 index 000000000..356de4765 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt @@ -0,0 +1,28 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.Opcodes.* + +internal fun MethodVisitor.visitLdcOrIConstInsn(value: Int) = when (value) { + -1 -> visitInsn(ICONST_M1) + 0 -> visitInsn(ICONST_0) + 1 -> visitInsn(ICONST_1) + 2 -> visitInsn(ICONST_2) + 3 -> visitInsn(ICONST_3) + 4 -> visitInsn(ICONST_4) + 5 -> visitInsn(ICONST_5) + else -> visitLdcInsn(value) +} + +internal fun MethodVisitor.visitLdcOrDConstInsn(value: Double) = when (value) { + 0.0 -> visitInsn(DCONST_0) + 1.0 -> visitInsn(DCONST_1) + else -> visitLdcInsn(value) +} + +internal fun MethodVisitor.visitLdcOrFConstInsn(value: Float) = when (value) { + 0f -> visitInsn(FCONST_0) + 1f -> visitInsn(FCONST_1) + 2f -> visitInsn(FCONST_2) + else -> visitLdcInsn(value) +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt new file mode 100644 index 000000000..815f292ce --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt @@ -0,0 +1,49 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.Opcodes +import scientifik.kmath.asm.AsmConstantExpression +import scientifik.kmath.asm.AsmNode +import scientifik.kmath.operations.Algebra + +private val methodNameAdapters: Map = mapOf("+" to "add", "*" to "multiply", "/" to "divide") + +internal fun hasSpecific(context: Algebra, name: String, arity: Int): Boolean { + val aName = methodNameAdapters[name] ?: name + + context::class.java.methods.find { it.name == aName && it.parameters.size == arity } + ?: return false + + return true +} + +internal fun AsmGenerationContext.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { + val aName = methodNameAdapters[name] ?: name + + context::class.java.methods.find { it.name == aName && it.parameters.size == arity } + ?: return false + + val owner = context::class.java.name.replace('.', '/') + + val sig = buildString { + append('(') + repeat(arity) { append("L${AsmGenerationContext.OBJECT_CLASS};") } + append(')') + append("L${AsmGenerationContext.OBJECT_CLASS};") + } + + visitAlgebraOperation( + owner = owner, + method = aName, + descriptor = sig, + opcode = Opcodes.INVOKEVIRTUAL, + isInterface = false + ) + + return true +} + +@PublishedApi +internal fun AsmNode.optimize(): AsmNode { + val a = tryEvaluate() + return if (a == null) this else AsmConstantExpression(a) +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt deleted file mode 100644 index c01648fb0..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/ast/asm.kt +++ /dev/null @@ -1,23 +0,0 @@ -package scientifik.kmath.ast - -import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.NumericAlgebra - -//TODO stubs for asm generation - -interface AsmExpression - -interface AsmExpressionAlgebra> : NumericAlgebra> { - val algebra: A -} - -fun AsmExpression.compile(): Expression = TODO() - -//TODO add converter for functional expressions - -inline fun > A.asm( - block: AsmExpressionAlgebra.() -> AsmExpression -): Expression = TODO() - -inline fun > A.asm(ast: MST): Expression = asm { evaluate(ast) } \ No newline at end of file diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt new file mode 100644 index 000000000..c92524b5d --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -0,0 +1,18 @@ +package scietifik.kmath.ast + +import scientifik.kmath.asm.asmField +import scientifik.kmath.ast.parseMath +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.Complex +import scientifik.kmath.operations.ComplexField +import kotlin.test.assertEquals +import kotlin.test.Test + +class AsmTest { + @Test + fun parsedExpression() { + val mst = "2+2*(2+2)".parseMath() + val res = ComplexField.asmField(mst)() + assertEquals(Complex(10.0, 0.0), res) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt index 6849d24b8..06546302e 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/ParserTest.kt @@ -1,17 +1,17 @@ package scietifik.kmath.ast -import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.Test import scientifik.kmath.ast.evaluate import scientifik.kmath.ast.parseMath import scientifik.kmath.operations.Complex import scientifik.kmath.operations.ComplexField +import kotlin.test.assertEquals +import kotlin.test.Test -internal class ParserTest{ +internal class ParserTest { @Test - fun parsedExpression(){ + fun parsedExpression() { val mst = "2+2*(2+2)".parseMath() val res = ComplexField.evaluate(mst) - assertEquals(Complex(10.0,0.0), res) + assertEquals(Complex(10.0, 0.0), res) } -} \ No newline at end of file +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt new file mode 100644 index 000000000..07f895c40 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmAlgebras.kt @@ -0,0 +1,75 @@ +package scietifik.kmath.ast.asm + +import scientifik.kmath.asm.asmField +import scientifik.kmath.asm.asmRing +import scientifik.kmath.asm.asmSpace +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +class TestAsmAlgebras { + @Test + fun space() { + val res = ByteRing.asmSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + 3.toByte() - (2.toByte() + (multiply( + add(const(1), const(1)), + 2 + ) + 1.toByte()) * 3.toByte() - 1.toByte()) + ), + + number(1) + ) + variable("x") + zero + }("x" to 2.toByte()) + + assertEquals(16, res) + } + + @Test + fun ring() { + val res = ByteRing.asmRing { + binaryOperation( + "+", + + unaryOperation( + "+", + (3.toByte() - (2.toByte() + (multiply( + add(const(1), const(1)), + 2 + ) + 1.toByte()))) * 3.0 - 1.toByte() + ), + + number(1) + ) * const(2) + }() + + assertEquals(24, res) + } + + @Test + fun field() { + val res = RealField.asmField { + divide(binaryOperation( + "+", + + unaryOperation( + "+", + (3.0 - (2.0 + (multiply( + add(const(1.0), const(1.0)), + 2 + ) + 1.0))) * 3 - 1.0 + ), + + number(1) + ) / 2, const(2.0)) * one + }() + + assertEquals(3.0, res) + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt new file mode 100644 index 000000000..2839b8a25 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/asm/TestAsmExpressions.kt @@ -0,0 +1,21 @@ +package scietifik.kmath.ast.asm + +import scientifik.kmath.asm.asmField +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +class TestAsmExpressions { + @Test + fun testUnaryOperationInvocation() { + val res = RealField.asmField { unaryOperation("+", variable("x")) }("x" to 2.0) + assertEquals(2.0, res) + } + + @Test + fun testConstProductInvocation() { + val res = RealField.asmField { variable("x") * 2 }("x" to 2.0) + assertEquals(4.0, res) + } +} diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index 092f3deb7..18c0cc771 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -8,4 +8,4 @@ kotlin.sourceSets { api(project(":kmath-memory")) } } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt new file mode 100644 index 000000000..cdd6f695b --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Builders.kt @@ -0,0 +1,23 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Ring +import scientifik.kmath.operations.Space + +/** + * Create a functional expression on this [Space] + */ +fun Space.buildExpression(block: FunctionalExpressionSpace>.() -> Expression): Expression = + FunctionalExpressionSpace(this).run(block) + +/** + * Create a functional expression on this [Ring] + */ +fun Ring.buildExpression(block: FunctionalExpressionRing>.() -> Expression): Expression = + FunctionalExpressionRing(this).run(block) + +/** + * Create a functional expression on this [Field] + */ +fun Field.buildExpression(block: FunctionalExpressionField>.() -> Expression): Expression = + FunctionalExpressionField(this).run(block) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt new file mode 100644 index 000000000..a441a80c0 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -0,0 +1,139 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.* + +internal class FunctionalUnaryOperation(val context: Algebra, val name: String, private val expr: Expression) : + Expression { + override fun invoke(arguments: Map): T = context.unaryOperation(name, expr.invoke(arguments)) +} + +internal class FunctionalBinaryOperation( + val context: Algebra, + val name: String, + val first: Expression, + val second: Expression +) : Expression { + override fun invoke(arguments: Map): T = + context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) +} + +internal class FunctionalVariableExpression(val name: String, val default: T? = null) : Expression { + override fun invoke(arguments: Map): T = + arguments[name] ?: default ?: error("Parameter not found: $name") +} + +internal class FunctionalConstantExpression(val value: T) : Expression { + override fun invoke(arguments: Map): T = value +} + +internal class FunctionalConstProductExpression( + val context: Space, + private val expr: Expression, + val const: Number +) : Expression { + override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) +} + +/** + * A context class for [Expression] construction. + */ +interface FunctionalExpressionAlgebra> : ExpressionAlgebra> { + /** + * The algebra to provide for Expressions built. + */ + val algebra: A + + /** + * Builds an Expression of constant expression which does not depend on arguments. + */ + override fun const(value: T): Expression = FunctionalConstantExpression(value) + + /** + * Builds an Expression to access a variable. + */ + override fun variable(name: String, default: T?): Expression = FunctionalVariableExpression(name, default) + + /** + * Builds an Expression of dynamic call of binary operation [operation] on [left] and [right]. + */ + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + FunctionalBinaryOperation(algebra, operation, left, right) + + /** + * Builds an Expression of dynamic call of unary operation with name [operation] on [arg]. + */ + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) +} + +/** + * A context class for [Expression] construction for [Space] algebras. + */ +open class FunctionalExpressionSpace(override val algebra: A) : FunctionalExpressionAlgebra, + Space> where A : Space { + override val zero: Expression + get() = const(algebra.zero) + + /** + * Builds an Expression of addition of two another expressions. + */ + override fun add(a: Expression, b: Expression): Expression = + FunctionalBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b) + + /** + * Builds an Expression of multiplication of expression by number. + */ + override fun multiply(a: Expression, k: Number): Expression = + FunctionalConstProductExpression(algebra, a, k) + + operator fun Expression.plus(arg: T): Expression = this + const(arg) + operator fun Expression.minus(arg: T): Expression = this - const(arg) + operator fun T.plus(arg: Expression): Expression = arg + this + operator fun T.minus(arg: Expression): Expression = arg - this + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) +} + +open class FunctionalExpressionRing(override val algebra: A) : FunctionalExpressionSpace(algebra), + Ring> where A : Ring, A : NumericAlgebra { + override val one: Expression + get() = const(algebra.one) + + /** + * Builds an Expression of multiplication of two expressions. + */ + override fun multiply(a: Expression, b: Expression): Expression = + FunctionalBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b) + + operator fun Expression.times(arg: T): Expression = this * const(arg) + operator fun T.times(arg: Expression): Expression = arg * this + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) +} + +open class FunctionalExpressionField(override val algebra: A) : + FunctionalExpressionRing(algebra), + Field> where A : Field, A : NumericAlgebra { + /** + * Builds an Expression of division an expression by another one. + */ + override fun divide(a: Expression, b: Expression): Expression = + FunctionalBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b) + + operator fun Expression.div(arg: T): Expression = this / const(arg) + operator fun T.div(arg: Expression): Expression = arg / this + + override fun unaryOperation(operation: String, arg: Expression): Expression = + super.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + super.binaryOperation(operation, left, right) +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt deleted file mode 100644 index 6d7676c25..000000000 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt +++ /dev/null @@ -1,94 +0,0 @@ -package scientifik.kmath.expressions - -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space - -internal class VariableExpression(val name: String, val default: T? = null) : Expression { - override fun invoke(arguments: Map): T = - arguments[name] ?: default ?: error("Parameter not found: $name") -} - -internal class ConstantExpression(val value: T) : Expression { - override fun invoke(arguments: Map): T = value -} - -internal class SumExpression( - val context: Space, - val first: Expression, - val second: Expression -) : Expression { - override fun invoke(arguments: Map): T = context.add(first.invoke(arguments), second.invoke(arguments)) -} - -internal class ProductExpression( - val context: Ring, - val first: Expression, - val second: Expression -) : Expression { - override fun invoke(arguments: Map): T = - context.multiply(first.invoke(arguments), second.invoke(arguments)) -} - -internal class ConstProductExpession(val context: Space, val expr: Expression, val const: Number) : - Expression { - override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) -} - -internal class DivExpession(val context: Field, val expr: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = context.divide(expr.invoke(arguments), second.invoke(arguments)) -} - -open class FunctionalExpressionSpace( - val space: Space -) : Space>, ExpressionAlgebra> { - - override val zero: Expression = ConstantExpression(space.zero) - - override fun const(value: T): Expression = ConstantExpression(value) - - override fun variable(name: String, default: T?): Expression = VariableExpression(name, default) - - override fun add(a: Expression, b: Expression): Expression = SumExpression(space, a, b) - - override fun multiply(a: Expression, k: Number): Expression = ConstProductExpession(space, a, k) - - - operator fun Expression.plus(arg: T) = this + const(arg) - operator fun Expression.minus(arg: T) = this - const(arg) - - operator fun T.plus(arg: Expression) = arg + this - operator fun T.minus(arg: Expression) = arg - this -} - -open class FunctionalExpressionField( - val field: Field -) : Field>, ExpressionAlgebra>, FunctionalExpressionSpace(field) { - - override val one: Expression = ConstantExpression(this.field.one) - - fun const(value: Double): Expression = const(field.run { one * value }) - - override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) - - override fun divide(a: Expression, b: Expression): Expression = DivExpession(field, a, b) - - operator fun Expression.times(arg: T) = this * const(arg) - operator fun Expression.div(arg: T) = this / const(arg) - - operator fun T.times(arg: Expression) = arg * this - operator fun T.div(arg: Expression) = arg / this -} - -/** - * Create a functional expression on this [Space] - */ -fun Space.buildExpression(block: FunctionalExpressionSpace.() -> Expression): Expression = - FunctionalExpressionSpace(this).run(block) - -/** - * Create a functional expression on this [Field] - */ -fun Field.buildExpression(block: FunctionalExpressionField.() -> Expression): Expression = - FunctionalExpressionField(this).run(block) \ No newline at end of file diff --git a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt similarity index 100% rename from kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt rename to kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt