diff --git a/kmath-asm/build.gradle.kts b/kmath-asm/build.gradle.kts new file mode 100644 index 000000000..ab4684cdc --- /dev/null +++ b/kmath-asm/build.gradle.kts @@ -0,0 +1,9 @@ +plugins { + id("scientifik.jvm") +} + +dependencies { + api(project(path = ":kmath-core")) + api("org.ow2.asm:asm:8.0.1") + api("org.ow2.asm:asm-commons:8.0.1") +} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt new file mode 100644 index 000000000..b9c357671 --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -0,0 +1,385 @@ +package scientifik.kmath.expressions + +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.Label +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.Opcodes.* +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Space +import java.io.File + +abstract class AsmCompiled(@JvmField val algebra: Algebra, @JvmField val constants: MutableList) { + abstract fun evaluate(arguments: Map): T +} + +class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra, private val className: String) { + private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { + internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) + } + + private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) + + @Suppress("PrivatePropertyName") + private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') + + @Suppress("PrivatePropertyName") + private val T_CLASS: String = classOfT.name.replace('.', '/') + private val constants: MutableList = mutableListOf() + private val asmCompiledClassWriter = ClassWriter(0) + private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') + private val evaluateMethodVisitor: MethodVisitor + private val evaluateThisVar: Int = 0 + private val evaluateArgumentsVar: Int = 1 + private var evaluateL0: Label + private lateinit var evaluateL1: Label + var maxStack: Int = 0 + + init { + asmCompiledClassWriter.visit( + V1_8, + ACC_PUBLIC or ACC_FINAL or ACC_SUPER, + slashesClassName, + "L$ASM_COMPILED_CLASS;", + ASM_COMPILED_CLASS, + arrayOf() + ) + + asmCompiledClassWriter.run { + visitMethod(ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run { + val thisVar = 0 + val algebraVar = 1 + val constantsVar = 2 + val l0 = Label() + visitLabel(l0) + visitVarInsn(ALOAD, thisVar) + visitVarInsn(ALOAD, algebraVar) + visitVarInsn(ALOAD, constantsVar) + + visitMethodInsn( + INVOKESPECIAL, + ASM_COMPILED_CLASS, + "", + "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", + false + ) + + val l1 = Label() + visitLabel(l1) + visitInsn(RETURN) + val l2 = Label() + visitLabel(l2) + visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) + + visitLocalVariable( + "algebra", + "L$ALGEBRA_CLASS;", + "L$ALGEBRA_CLASS;", + l0, + l2, + algebraVar + ) + + visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS;", l0, l2, constantsVar) + visitMaxs(3, 3) + visitEnd() + } + + evaluateMethodVisitor = visitMethod( + ACC_PUBLIC or ACC_FINAL, + "evaluate", + "(L$MAP_CLASS;)L$T_CLASS;", + "(L$MAP_CLASS;)L$T_CLASS;", + null + ) + + evaluateMethodVisitor.run { + visitCode() + evaluateL0 = Label() + visitLabel(evaluateL0) + } + } + } + + @Suppress("UNCHECKED_CAST") + fun generate(): AsmCompiled { + evaluateMethodVisitor.run { + visitInsn(ARETURN) + evaluateL1 = Label() + visitLabel(evaluateL1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + evaluateL0, + evaluateL1, + evaluateThisVar + ) + + visitLocalVariable( + "arguments", + "L$MAP_CLASS;", + "L$MAP_CLASS;", + evaluateL0, + evaluateL1, + evaluateArgumentsVar + ) + + visitMaxs(maxStack + 1, 2) + visitEnd() + } + + asmCompiledClassWriter.visitMethod( + ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, + "evaluate", + "(L$MAP_CLASS;)L$OBJECT_CLASS;", + null, + null + ).run { + val thisVar = 0 + visitCode() + val l0 = Label() + visitLabel(l0) + visitVarInsn(ALOAD, 0) + visitVarInsn(ALOAD, 1) + visitMethodInsn(INVOKEVIRTUAL, slashesClassName, "evaluate", "(L$MAP_CLASS;)L$T_CLASS;", false) + visitInsn(ARETURN) + val l1 = Label() + visitLabel(l1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + l0, + l1, + thisVar + ) + + visitMaxs(2, 2) + visitEnd() + } + + asmCompiledClassWriter.visitEnd() + + return classLoader + .defineClass(className, asmCompiledClassWriter.toByteArray()) + .constructors + .first() + .newInstance(algebra, constants) as AsmCompiled + } + + fun visitLoadFromConstants(value: T) { + val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex + maxStack++ + + evaluateMethodVisitor.run { + visitLoadThis() + visitFieldInsn(GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;") + visitLdcOrIConstInsn(idx) + visitMethodInsn(INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true) + visitCastToT() + } + } + + private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar) + + fun visitNumberConstant(value: Number): Unit = evaluateMethodVisitor.visitLdcInsn(value) + + fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run { + maxStack++ + visitVarInsn(ALOAD, evaluateArgumentsVar) + + if (defaultValue != null) { + visitLdcInsn(name) + visitLoadFromConstants(defaultValue) + + visitMethodInsn( + INVOKEINTERFACE, + MAP_CLASS, + "getOrDefault", + "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;", + true + ) + + visitCastToT() + return + } + + visitLdcInsn(name) + visitMethodInsn(INVOKEINTERFACE, MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true) + visitCastToT() + } + + fun visitLoadAlgebra() { + evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar) + + evaluateMethodVisitor.visitFieldInsn( + GETFIELD, + ASM_COMPILED_CLASS, "algebra", "L$ALGEBRA_CLASS;" + ) + + evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_ALGEBRA_CLASS) + } + + fun visitInvokeAlgebraOperation(owner: String, method: String, descriptor: String) { + maxStack++ + evaluateMethodVisitor.visitMethodInsn(INVOKEINTERFACE, owner, method, descriptor, true) + visitCastToT() + } + + fun visitCastToT() { + evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS) + } + + companion object { + const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled" + const val LIST_CLASS = "java/util/List" + const val MAP_CLASS = "java/util/Map" + const val OBJECT_CLASS = "java/lang/Object" + const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" + const val SPACE_CLASS = "scientifik/kmath/operations/Space" + const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" + const val FIELD_CLASS = "scientifik/kmath/operations/Field" + const val STRING_CLASS = "java/lang/String" + } +} + +interface AsmExpression { + fun invoke(gen: AsmGenerationContext) +} + +internal class AsmVariableExpression(val name: String, val default: T? = null) : + AsmExpression { + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadFromVariables(name, default) + } +} + +internal class AsmConstantExpression(val value: T) : + AsmExpression { + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadFromConstants(value) + } +} + +internal class AsmSumExpression( + val first: AsmExpression, + val second: AsmExpression +) : AsmExpression { + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + first.invoke(gen) + second.invoke(gen) + + gen.visitInvokeAlgebraOperation( + owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, + method = "add", + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmProductExpression( + val first: AsmExpression, + val second: AsmExpression +) : AsmExpression { + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + first.invoke(gen) + second.invoke(gen) + + gen.visitInvokeAlgebraOperation( + owner = AsmGenerationContext.SPACE_CLASS, + method = "times", + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmConstProductExpression( + val expr: AsmExpression, + val const: Number +) : AsmExpression { + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + expr.invoke(gen) + gen.visitNumberConstant(const) + + gen.visitInvokeAlgebraOperation( + owner = AsmGenerationContext.SPACE_CLASS, + method = "multiply", + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmDivExpression( + val expr: AsmExpression, + val second: AsmExpression +) : AsmExpression { + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + expr.invoke(gen) + second.invoke(gen) + + gen.visitInvokeAlgebraOperation( + owner = AsmGenerationContext.FIELD_CLASS, + method = "divide", + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + + +open class AsmFunctionalExpressionSpace( + val space: Space, + one: T +) : Space>, + ExpressionSpace> { + override val zero: AsmExpression = + AsmConstantExpression(space.zero) + + override fun const(value: T): AsmExpression = + AsmConstantExpression(value) + + override fun variable(name: String, default: T?): AsmExpression = + AsmVariableExpression( + name, + default + ) + + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmSumExpression(a, b) + + override fun multiply(a: AsmExpression, k: Number): AsmExpression = + AsmConstProductExpression(a, k) + + + operator fun AsmExpression.plus(arg: T) = this + const(arg) + operator fun AsmExpression.minus(arg: T) = this - const(arg) + + operator fun T.plus(arg: AsmExpression) = arg + this + operator fun T.minus(arg: AsmExpression) = arg - this +} + +class AsmFunctionalExpressionField(val field: Field) : ExpressionField>, + AsmFunctionalExpressionSpace(field, field.one) { + override val one: AsmExpression + get() = const(this.field.one) + + override fun const(value: Double): AsmExpression = const(field.run { one * value }) + + override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmProductExpression(a, b) + + override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmDivExpression(a, b) + + operator fun AsmExpression.times(arg: T) = this * const(arg) + operator fun AsmExpression.div(arg: T) = this / const(arg) + + operator fun T.times(arg: AsmExpression) = arg * this + operator fun T.div(arg: AsmExpression) = arg / this +} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt new file mode 100644 index 000000000..fdbc1062e --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt @@ -0,0 +1,17 @@ +package scientifik.kmath.expressions + +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.Opcodes + +fun MethodVisitor.visitLdcOrIConstInsn(value: Int) { + when (value) { + -1 -> visitInsn(Opcodes.ICONST_M1) + 0 -> visitInsn(Opcodes.ICONST_0) + 1 -> visitInsn(Opcodes.ICONST_1) + 2 -> visitInsn(Opcodes.ICONST_2) + 3 -> visitInsn(Opcodes.ICONST_3) + 4 -> visitInsn(Opcodes.ICONST_4) + 5 -> visitInsn(Opcodes.ICONST_5) + else -> visitLdcInsn(value) + } +} diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt new file mode 100644 index 000000000..e95f8df76 --- /dev/null +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -0,0 +1,23 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.RealField +import kotlin.test.Test +import kotlin.test.assertEquals + +class AsmTest { + @Test + fun test() { + val expr = AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x")) + + val gen = AsmGenerationContext( + java.lang.Double::class.java, + RealField, + "MyAsmCompiled" + ) + + expr.invoke(gen) + val compiled = gen.generate() + val value = compiled.evaluate(mapOf("x" to 25.0)) + assertEquals(26.0, value) + } +} diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index 092f3deb7..18c0cc771 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -8,4 +8,4 @@ kotlin.sourceSets { api(project(":kmath-memory")) } } -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index 94bc81413..7b781e4e1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -15,7 +15,7 @@ operator fun Expression.invoke(vararg pairs: Pair): T = invoke /** * A context for expression construction */ -interface ExpressionContext> { +interface ExpressionContext { /** * Introduce a variable into expression context */ @@ -29,11 +29,11 @@ interface ExpressionContext> { fun produce(node: SyntaxTreeNode): E } -interface ExpressionSpace> : Space, ExpressionContext { +interface ExpressionSpace : Space, ExpressionContext { - open fun produceSingular(value: String): E = variable(value) + fun produceSingular(value: String): E = variable(value) - open fun produceUnary(operation: String, value: E): E { + fun produceUnary(operation: String, value: E): E { return when (operation) { UnaryNode.PLUS_OPERATION -> value UnaryNode.MINUS_OPERATION -> -value @@ -41,7 +41,7 @@ interface ExpressionSpace> : Space, ExpressionContext left + right BinaryNode.MINUS_OPERATION -> left - right @@ -75,7 +75,7 @@ interface ExpressionSpace> : Space, ExpressionContext> : Field, ExpressionSpace { +interface ExpressionField : Field, ExpressionSpace { fun const(value: Double): E = one.times(value) override fun produce(node: SyntaxTreeNode): E { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt similarity index 100% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/functionalExpressions.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt diff --git a/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt b/kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt similarity index 100% rename from kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/bigNumbers.kt rename to kmath-core/src/jvmMain/kotlin/scientifik/kmath/operations/BigNumbers.kt diff --git a/settings.gradle.kts b/settings.gradle.kts index 57173250b..fcaa72698 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -44,5 +44,6 @@ include( ":kmath-dimensions", ":kmath-for-real", ":kmath-geometry", - ":examples" + ":examples", + ":kmath-asm" )