diff --git a/kmath-asm/build.gradle.kts b/kmath-asm/build.gradle.kts index ab4684cdc..4104349f8 100644 --- a/kmath-asm/build.gradle.kts +++ b/kmath-asm/build.gradle.kts @@ -4,6 +4,6 @@ plugins { dependencies { api(project(path = ":kmath-core")) - api("org.ow2.asm:asm:8.0.1") - api("org.ow2.asm:asm-commons:8.0.1") + implementation("org.ow2.asm:asm:8.0.1") + implementation("org.ow2.asm:asm-commons:8.0.1") } diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index 8060674e5..f196efffe 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -8,11 +8,30 @@ import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space -abstract class AsmCompiled(@JvmField val algebra: Algebra, @JvmField val constants: MutableList) { - abstract fun evaluate(arguments: Map): T +abstract class AsmCompiledExpression internal constructor( + @JvmField val algebra: Algebra, + @JvmField val constants: MutableList +) : Expression { + abstract override fun invoke(arguments: Map): T } -class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra, private val className: String) { +internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { + val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(expression, collision + 1) +} + +class AsmGenerationContext( + classOfT: Class<*>, + private val algebra: Algebra, + private val className: String +) { private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) } @@ -26,30 +45,23 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra - private lateinit var asmCompiledClassWriter: ClassWriter - private lateinit var evaluateMethodVisitor: MethodVisitor - private lateinit var evaluateL0: Label - private lateinit var evaluateL1: Label + private val constants: MutableList = mutableListOf() + private val asmCompiledClassWriter: ClassWriter = ClassWriter(0) + private val invokeMethodVisitor: MethodVisitor + private val invokeL0: Label + private lateinit var invokeL1: Label + private var generatedInstance: AsmCompiledExpression? = null init { - start() - } - - fun start() { - constants = mutableListOf() - asmCompiledClassWriter = ClassWriter(0) - maxStack = 0 - asmCompiledClassWriter.visit( V1_8, ACC_PUBLIC or ACC_FINAL or ACC_SUPER, slashesClassName, - "L$ASM_COMPILED_CLASS;", - ASM_COMPILED_CLASS, + "L$ASM_COMPILED_EXPRESSION_CLASS;", + ASM_COMPILED_EXPRESSION_CLASS, arrayOf() ) @@ -66,7 +78,7 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", false @@ -93,45 +105,48 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra;)L$T_CLASS;", null ) - evaluateMethodVisitor.run { + invokeMethodVisitor.run { visitCode() - evaluateL0 = Label() - visitLabel(evaluateL0) + invokeL0 = Label() + visitLabel(invokeL0) } } } + @PublishedApi @Suppress("UNCHECKED_CAST") - fun generate(): AsmCompiled { - evaluateMethodVisitor.run { + internal fun generate(): AsmCompiledExpression { + generatedInstance?.let { return it } + + invokeMethodVisitor.run { visitInsn(ARETURN) - evaluateL1 = Label() - visitLabel(evaluateL1) + invokeL1 = Label() + visitLabel(invokeL1) visitLocalVariable( "this", "L$slashesClassName;", T_CLASS, - evaluateL0, - evaluateL1, - evaluateThisVar + invokeL0, + invokeL1, + invokeThisVar ) visitLocalVariable( "arguments", "L$MAP_CLASS;", "L$MAP_CLASS;", - evaluateL0, - evaluateL1, - evaluateArgumentsVar + invokeL0, + invokeL1, + invokeArgumentsVar ) visitMaxs(maxStack + 1, 2) @@ -140,7 +155,7 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra(classOfT: Class<*>, private val algebra: Algebra(classOfT: Class<*>, private val algebra: Algebra + .newInstance(algebra, constants) as AsmCompiledExpression + + generatedInstance = new + return new } - fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS) + internal fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS) - fun visitLoadAnyFromConstants(value: Any, type: String) { + private fun visitLoadAnyFromConstants(value: Any, type: String) { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex maxStack++ - evaluateMethodVisitor.run { + invokeMethodVisitor.run { visitLoadThis() visitFieldInsn(GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;") visitLdcOrIConstInsn(idx) visitMethodInsn(INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true) - evaluateMethodVisitor.visitTypeInsn(CHECKCAST, type) + invokeMethodVisitor.visitTypeInsn(CHECKCAST, type) } } - private fun visitLoadThis(): Unit = evaluateMethodVisitor.visitVarInsn(ALOAD, evaluateThisVar) + private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(ALOAD, invokeThisVar) - fun visitNumberConstant(value: Number) { + internal fun visitNumberConstant(value: Number) { maxStack++ val clazz = value.javaClass val c = clazz.name.replace('.', '/') @@ -204,22 +222,22 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra evaluateMethodVisitor.visitLdcOrIConstInsn(value) - is Double -> evaluateMethodVisitor.visitLdcOrDConstInsn(value) - is Float -> evaluateMethodVisitor.visitLdcOrFConstInsn(value) - else -> evaluateMethodVisitor.visitLdcInsn(value) + is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value) + is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value) + is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value) + else -> invokeMethodVisitor.visitLdcInsn(value) } - evaluateMethodVisitor.visitMethodInsn(INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) + invokeMethodVisitor.visitMethodInsn(INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) return } visitLoadAnyFromConstants(value, c) } - fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run { + internal fun visitLoadFromVariables(name: String, defaultValue: T? = null) = invokeMethodVisitor.run { maxStack += 2 - visitVarInsn(ALOAD, evaluateArgumentsVar) + visitVarInsn(ALOAD, invokeArgumentsVar) if (defaultValue != null) { visitLdcInsn(name) @@ -242,26 +260,26 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra(classOfT: Class<*>, private val algebra: Algebra( } } -open class AsmFunctionalExpressionSpace( - val space: Space, - one: T -) : Space>, +open class AsmExpressionSpace(space: Space) : Space>, ExpressionSpace> { - override val zero: AsmExpression = - AsmConstantExpression(space.zero) - - override fun const(value: T): AsmExpression = - AsmConstantExpression(value) - - override fun variable(name: String, default: T?): AsmExpression = - AsmVariableExpression( - name, - default - ) - - override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmSumExpression(a, b) - - override fun multiply(a: AsmExpression, k: Number): AsmExpression = - AsmConstProductExpression(a, k) - - + override val zero: AsmExpression = AsmConstantExpression(space.zero) + override fun const(value: T): AsmExpression = AsmConstantExpression(value) + override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmSumExpression(a, b) + override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(a, k) operator fun AsmExpression.plus(arg: T) = this + const(arg) operator fun AsmExpression.minus(arg: T) = this - const(arg) - operator fun T.plus(arg: AsmExpression) = arg + this operator fun T.minus(arg: AsmExpression) = arg - this } -class AsmFunctionalExpressionField(val field: Field) : ExpressionField>, - AsmFunctionalExpressionSpace(field, field.one) { +class AsmExpressionField(private val field: Field) : ExpressionField>, + AsmExpressionSpace(field) { override val one: AsmExpression get() = const(this.field.one) - override fun const(value: Double): AsmExpression = const(field.run { one * value }) + override fun number(value: Number): AsmExpression = const(field.run { one * value }) override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmProductExpression(a, b) @@ -412,7 +412,6 @@ class AsmFunctionalExpressionField(val field: Field) : ExpressionField.times(arg: T) = this * const(arg) operator fun AsmExpression.div(arg: T) = this / const(arg) - operator fun T.times(arg: AsmExpression) = arg * this operator fun T.div(arg: AsmExpression) = arg / this } diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index 9ab6f26c9..875080616 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -12,15 +12,12 @@ class AsmTest { arguments: Map, algebra: Algebra, clazz: Class<*> - ) { - assertEquals( - expectedValue, AsmGenerationContext( - clazz, - algebra, - "TestAsmCompiled" - ).also(expr::invoke).generate().evaluate(arguments) - ) - } + ): Unit = assertEquals( + expectedValue, AsmGenerationContext(clazz, algebra, "TestAsmCompiled") + .also(expr::invoke) + .generate() + .invoke(arguments) + ) @Suppress("UNCHECKED_CAST") private fun testDoubleExpressionValue( @@ -29,7 +26,7 @@ class AsmTest { arguments: Map, algebra: Algebra = RealField, clazz: Class = java.lang.Double::class.java as Class - ) = testExpressionValue(expectedValue, expr, arguments, algebra, clazz) + ): Unit = testExpressionValue(expectedValue, expr, arguments, algebra, clazz) @Test fun testSum() = testDoubleExpressionValue( @@ -39,28 +36,28 @@ class AsmTest { ) @Test - fun testConst() = testDoubleExpressionValue( + fun testConst(): Unit = testDoubleExpressionValue( 123.0, AsmConstantExpression(123.0), mapOf() ) @Test - fun testDiv() = testDoubleExpressionValue( + fun testDiv(): Unit = testDoubleExpressionValue( 0.5, AsmDivExpression(AsmConstantExpression(1.0), AsmConstantExpression(2.0)), mapOf() ) @Test - fun testProduct() = testDoubleExpressionValue( + fun testProduct(): Unit = testDoubleExpressionValue( 25.0, AsmProductExpression(AsmVariableExpression("x"), AsmVariableExpression("x")), mapOf("x" to 5.0) ) @Test - fun testCProduct() = testDoubleExpressionValue( + fun testCProduct(): Unit = testDoubleExpressionValue( 25.0, AsmConstProductExpression(AsmVariableExpression("x"), 5.0), mapOf("x" to 5.0) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index 7b781e4e1..dae069879 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -76,7 +76,7 @@ interface ExpressionSpace : Space, ExpressionContext { } interface ExpressionField : Field, ExpressionSpace { - fun const(value: Double): E = one.times(value) + fun number(value: Number): E = one * value override fun produce(node: SyntaxTreeNode): E { if (node is BinaryNode) { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index ad0acbc4a..dbc2eac1e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -27,55 +27,41 @@ internal class ProductExpression(val context: Ring, val first: Expression< context.multiply(first.invoke(arguments), second.invoke(arguments)) } -internal class ConstProductExpession(val context: Space, val expr: Expression, val const: Number) : +internal class ConstProductExpression(val context: Space, val expr: Expression, val const: Number) : Expression { override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) } -internal class DivExpession(val context: Field, val expr: Expression, val second: Expression) : +internal class DivExpression(val context: Field, val expr: Expression, val second: Expression) : Expression { override fun invoke(arguments: Map): T = context.divide(expr.invoke(arguments), second.invoke(arguments)) } -open class FunctionalExpressionSpace( - val space: Space, - one: T -) : Space>, ExpressionSpace> { - +open class FunctionalExpressionSpace(val space: Space) : Space>, ExpressionSpace> { override val zero: Expression = ConstantExpression(space.zero) override fun const(value: T): Expression = ConstantExpression(value) - override fun variable(name: String, default: T?): Expression = VariableExpression(name, default) - override fun add(a: Expression, b: Expression): Expression = SumExpression(space, a, b) - - override fun multiply(a: Expression, k: Number): Expression = ConstProductExpession(space, a, k) - - - operator fun Expression.plus(arg: T) = this + const(arg) - operator fun Expression.minus(arg: T) = this - const(arg) - - operator fun T.plus(arg: Expression) = arg + this - operator fun T.minus(arg: Expression) = arg - this + override fun multiply(a: Expression, k: Number): Expression = ConstProductExpression(space, a, k) + operator fun Expression.plus(arg: T): Expression = this + const(arg) + operator fun Expression.minus(arg: T): Expression = this - const(arg) + operator fun T.plus(arg: Expression): Expression = arg + this + operator fun T.minus(arg: Expression): Expression = arg - this } open class FunctionalExpressionField( val field: Field -) : ExpressionField>, FunctionalExpressionSpace(field, field.one) { +) : ExpressionField>, FunctionalExpressionSpace(field) { override val one: Expression get() = const(this.field.one) - override fun const(value: Double): Expression = const(field.run { one*value}) - + override fun number(value: Number): Expression = const(field.run { one * value }) override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) - - override fun divide(a: Expression, b: Expression): Expression = DivExpession(field, a, b) - - operator fun Expression.times(arg: T) = this * const(arg) - operator fun Expression.div(arg: T) = this / const(arg) - - operator fun T.times(arg: Expression) = arg * this - operator fun T.div(arg: Expression) = arg / this -} \ No newline at end of file + override fun divide(a: Expression, b: Expression): Expression = DivExpression(field, a, b) + operator fun Expression.times(arg: T): Expression = this * const(arg) + operator fun Expression.div(arg: T): Expression = this / const(arg) + operator fun T.times(arg: Expression): Expression = arg * this + operator fun T.div(arg: Expression): Expression = arg / this +}