diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt new file mode 100644 index 000000000..5c93fd729 --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt @@ -0,0 +1,31 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Space + +open class AsmExpressionSpace(space: Space) : Space>, + ExpressionSpace> { + override val zero: AsmExpression = AsmConstantExpression(space.zero) + override fun const(value: T): AsmExpression = AsmConstantExpression(value) + override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmSumExpression(a, b) + override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(a, k) + operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) + operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) + operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this + operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this +} + +class AsmExpressionField(private val field: Field) : ExpressionField>, + AsmExpressionSpace(field) { + override val one: AsmExpression + get() = const(this.field.one) + + override fun number(value: Number): AsmExpression = const(field.run { one * value }) + override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = AsmProductExpression(a, b) + override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = AsmDivExpression(a, b) + operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) + operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) + operator fun T.times(arg: AsmExpression): AsmExpression = arg * this + operator fun T.div(arg: AsmExpression): AsmExpression = arg / this +} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index f196efffe..8708a33e1 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -1,307 +1,14 @@ package scientifik.kmath.expressions -import org.objectweb.asm.ClassWriter -import org.objectweb.asm.Label -import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Opcodes.* import scientifik.kmath.operations.Algebra -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Space abstract class AsmCompiledExpression internal constructor( - @JvmField val algebra: Algebra, - @JvmField val constants: MutableList + @JvmField private val algebra: Algebra, + @JvmField private val constants: MutableList ) : Expression { abstract override fun invoke(arguments: Map): T } -internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { - val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" - - try { - Class.forName(name) - } catch (ignored: ClassNotFoundException) { - return name - } - - return buildName(expression, collision + 1) -} - -class AsmGenerationContext( - classOfT: Class<*>, - private val algebra: Algebra, - private val className: String -) { - private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { - internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) - } - - private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) - - @Suppress("PrivatePropertyName") - private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') - - @Suppress("PrivatePropertyName") - private val T_CLASS: String = classOfT.name.replace('.', '/') - - private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') - private val invokeThisVar: Int = 0 - private val invokeArgumentsVar: Int = 1 - private var maxStack: Int = 0 - private val constants: MutableList = mutableListOf() - private val asmCompiledClassWriter: ClassWriter = ClassWriter(0) - private val invokeMethodVisitor: MethodVisitor - private val invokeL0: Label - private lateinit var invokeL1: Label - private var generatedInstance: AsmCompiledExpression? = null - - init { - asmCompiledClassWriter.visit( - V1_8, - ACC_PUBLIC or ACC_FINAL or ACC_SUPER, - slashesClassName, - "L$ASM_COMPILED_EXPRESSION_CLASS;", - ASM_COMPILED_EXPRESSION_CLASS, - arrayOf() - ) - - asmCompiledClassWriter.run { - visitMethod(ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run { - val thisVar = 0 - val algebraVar = 1 - val constantsVar = 2 - val l0 = Label() - visitLabel(l0) - visitVarInsn(ALOAD, thisVar) - visitVarInsn(ALOAD, algebraVar) - visitVarInsn(ALOAD, constantsVar) - - visitMethodInsn( - INVOKESPECIAL, - ASM_COMPILED_EXPRESSION_CLASS, - "", - "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", - false - ) - - val l1 = Label() - visitLabel(l1) - visitInsn(RETURN) - val l2 = Label() - visitLabel(l2) - visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) - - visitLocalVariable( - "algebra", - "L$ALGEBRA_CLASS;", - "L$ALGEBRA_CLASS;", - l0, - l2, - algebraVar - ) - - visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS;", l0, l2, constantsVar) - visitMaxs(3, 3) - visitEnd() - } - - invokeMethodVisitor = visitMethod( - ACC_PUBLIC or ACC_FINAL, - "invoke", - "(L$MAP_CLASS;)L$T_CLASS;", - "(L$MAP_CLASS;)L$T_CLASS;", - null - ) - - invokeMethodVisitor.run { - visitCode() - invokeL0 = Label() - visitLabel(invokeL0) - } - } - } - - @PublishedApi - @Suppress("UNCHECKED_CAST") - internal fun generate(): AsmCompiledExpression { - generatedInstance?.let { return it } - - invokeMethodVisitor.run { - visitInsn(ARETURN) - invokeL1 = Label() - visitLabel(invokeL1) - - visitLocalVariable( - "this", - "L$slashesClassName;", - T_CLASS, - invokeL0, - invokeL1, - invokeThisVar - ) - - visitLocalVariable( - "arguments", - "L$MAP_CLASS;", - "L$MAP_CLASS;", - invokeL0, - invokeL1, - invokeArgumentsVar - ) - - visitMaxs(maxStack + 1, 2) - visitEnd() - } - - asmCompiledClassWriter.visitMethod( - ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, - "invoke", - "(L$MAP_CLASS;)L$OBJECT_CLASS;", - null, - null - ).run { - val thisVar = 0 - visitCode() - val l0 = Label() - visitLabel(l0) - visitVarInsn(ALOAD, 0) - visitVarInsn(ALOAD, 1) - visitMethodInsn(INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false) - visitInsn(ARETURN) - val l1 = Label() - visitLabel(l1) - - visitLocalVariable( - "this", - "L$slashesClassName;", - T_CLASS, - l0, - l1, - thisVar - ) - - visitMaxs(2, 2) - visitEnd() - } - - asmCompiledClassWriter.visitEnd() - - val new = classLoader - .defineClass(className, asmCompiledClassWriter.toByteArray()) - .constructors - .first() - .newInstance(algebra, constants) as AsmCompiledExpression - - generatedInstance = new - return new - } - - internal fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS) - - private fun visitLoadAnyFromConstants(value: Any, type: String) { - val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex - maxStack++ - - invokeMethodVisitor.run { - visitLoadThis() - visitFieldInsn(GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;") - visitLdcOrIConstInsn(idx) - visitMethodInsn(INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true) - invokeMethodVisitor.visitTypeInsn(CHECKCAST, type) - } - } - - private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(ALOAD, invokeThisVar) - - internal fun visitNumberConstant(value: Number) { - maxStack++ - val clazz = value.javaClass - val c = clazz.name.replace('.', '/') - - val sigLetter = SIGNATURE_LETTERS[clazz] - - if (sigLetter != null) { - when (value) { - is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value) - is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value) - is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value) - else -> invokeMethodVisitor.visitLdcInsn(value) - } - - invokeMethodVisitor.visitMethodInsn(INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) - return - } - - visitLoadAnyFromConstants(value, c) - } - - internal fun visitLoadFromVariables(name: String, defaultValue: T? = null) = invokeMethodVisitor.run { - maxStack += 2 - visitVarInsn(ALOAD, invokeArgumentsVar) - - if (defaultValue != null) { - visitLdcInsn(name) - visitLoadFromConstants(defaultValue) - - visitMethodInsn( - INVOKEINTERFACE, - MAP_CLASS, - "getOrDefault", - "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;", - true - ) - - visitCastToT() - return - } - - visitLdcInsn(name) - visitMethodInsn(INVOKEINTERFACE, MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true) - visitCastToT() - } - - internal fun visitLoadAlgebra() { - invokeMethodVisitor.visitVarInsn(ALOAD, invokeThisVar) - - invokeMethodVisitor.visitFieldInsn( - GETFIELD, - ASM_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;" - ) - - invokeMethodVisitor.visitTypeInsn(CHECKCAST, T_ALGEBRA_CLASS) - } - - internal fun visitAlgebraOperation(owner: String, method: String, descriptor: String) { - maxStack++ - invokeMethodVisitor.visitMethodInsn(INVOKEINTERFACE, owner, method, descriptor, true) - visitCastToT() - } - - private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS) - - internal companion object { - private val SIGNATURE_LETTERS = mapOf( - java.lang.Byte::class.java to "B", - java.lang.Short::class.java to "S", - java.lang.Integer::class.java to "I", - java.lang.Long::class.java to "J", - java.lang.Float::class.java to "F", - java.lang.Double::class.java to "D" - ) - - internal const val ASM_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/AsmCompiledExpression" - internal const val LIST_CLASS = "java/util/List" - internal const val MAP_CLASS = "java/util/Map" - internal const val OBJECT_CLASS = "java/lang/Object" - internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" - internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" - internal const val STRING_CLASS = "java/lang/String" - internal const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations" - internal const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations" - internal const val NUMBER_CLASS = "java/lang/Number" - } -} - interface AsmExpression { fun invoke(gen: AsmGenerationContext) } @@ -383,35 +90,3 @@ internal class AsmDivExpression( ) } } - -open class AsmExpressionSpace(space: Space) : Space>, - ExpressionSpace> { - override val zero: AsmExpression = AsmConstantExpression(space.zero) - override fun const(value: T): AsmExpression = AsmConstantExpression(value) - override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) - override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmSumExpression(a, b) - override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(a, k) - operator fun AsmExpression.plus(arg: T) = this + const(arg) - operator fun AsmExpression.minus(arg: T) = this - const(arg) - operator fun T.plus(arg: AsmExpression) = arg + this - operator fun T.minus(arg: AsmExpression) = arg - this -} - -class AsmExpressionField(private val field: Field) : ExpressionField>, - AsmExpressionSpace(field) { - override val one: AsmExpression - get() = const(this.field.one) - - override fun number(value: Number): AsmExpression = const(field.run { one * value }) - - override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmProductExpression(a, b) - - override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmDivExpression(a, b) - - operator fun AsmExpression.times(arg: T) = this * const(arg) - operator fun AsmExpression.div(arg: T) = this / const(arg) - operator fun T.times(arg: AsmExpression) = arg * this - operator fun T.div(arg: AsmExpression) = arg / this -} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt new file mode 100644 index 000000000..cedd5c0fd --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt @@ -0,0 +1,283 @@ +package scientifik.kmath.expressions + +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.Label +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.Opcodes +import scientifik.kmath.operations.Algebra + +class AsmGenerationContext( + classOfT: Class<*>, + private val algebra: Algebra, + private val className: String +) { + private class ClassLoader(parent: java.lang.ClassLoader) : java.lang.ClassLoader(parent) { + internal fun defineClass(name: String?, b: ByteArray): Class<*> = defineClass(name, b, 0, b.size) + } + + private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) + + @Suppress("PrivatePropertyName") + private val T_ALGEBRA_CLASS: String = algebra.javaClass.name.replace(oldChar = '.', newChar = '/') + + @Suppress("PrivatePropertyName") + private val T_CLASS: String = classOfT.name.replace('.', '/') + + private val slashesClassName: String = className.replace(oldChar = '.', newChar = '/') + private val invokeThisVar: Int = 0 + private val invokeArgumentsVar: Int = 1 + private var maxStack: Int = 0 + private val constants: MutableList = mutableListOf() + private val asmCompiledClassWriter: ClassWriter = + ClassWriter(0) + private val invokeMethodVisitor: MethodVisitor + private val invokeL0: Label + private lateinit var invokeL1: Label + private var generatedInstance: AsmCompiledExpression? = null + + init { + asmCompiledClassWriter.visit( + Opcodes.V1_8, + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, + slashesClassName, + "L$ASM_COMPILED_EXPRESSION_CLASS;", + ASM_COMPILED_EXPRESSION_CLASS, + arrayOf() + ) + + asmCompiledClassWriter.run { + visitMethod(Opcodes.ACC_PUBLIC, "", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", null, null).run { + val thisVar = 0 + val algebraVar = 1 + val constantsVar = 2 + val l0 = Label() + visitLabel(l0) + visitVarInsn(Opcodes.ALOAD, thisVar) + visitVarInsn(Opcodes.ALOAD, algebraVar) + visitVarInsn(Opcodes.ALOAD, constantsVar) + + visitMethodInsn( + Opcodes.INVOKESPECIAL, + ASM_COMPILED_EXPRESSION_CLASS, + "", + "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", + false + ) + + val l1 = Label() + visitLabel(l1) + visitInsn(Opcodes.RETURN) + val l2 = Label() + visitLabel(l2) + visitLocalVariable("this", "L$slashesClassName;", null, l0, l2, thisVar) + + visitLocalVariable( + "algebra", + "L$ALGEBRA_CLASS;", + "L$ALGEBRA_CLASS;", + l0, + l2, + algebraVar + ) + + visitLocalVariable("constants", "L$LIST_CLASS;", "L$LIST_CLASS;", l0, l2, constantsVar) + visitMaxs(3, 3) + visitEnd() + } + + invokeMethodVisitor = visitMethod( + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, + "invoke", + "(L$MAP_CLASS;)L$T_CLASS;", + "(L$MAP_CLASS;)L$T_CLASS;", + null + ) + + invokeMethodVisitor.run { + visitCode() + invokeL0 = Label() + visitLabel(invokeL0) + } + } + } + + @PublishedApi + @Suppress("UNCHECKED_CAST") + internal fun generate(): AsmCompiledExpression { + generatedInstance?.let { return it } + + invokeMethodVisitor.run { + visitInsn(Opcodes.ARETURN) + invokeL1 = Label() + visitLabel(invokeL1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + invokeL0, + invokeL1, + invokeThisVar + ) + + visitLocalVariable( + "arguments", + "L$MAP_CLASS;", + "L$MAP_CLASS;", + invokeL0, + invokeL1, + invokeArgumentsVar + ) + + visitMaxs(maxStack + 1, 2) + visitEnd() + } + + asmCompiledClassWriter.visitMethod( + Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + "invoke", + "(L$MAP_CLASS;)L$OBJECT_CLASS;", + null, + null + ).run { + val thisVar = 0 + visitCode() + val l0 = Label() + visitLabel(l0) + visitVarInsn(Opcodes.ALOAD, 0) + visitVarInsn(Opcodes.ALOAD, 1) + visitMethodInsn(Opcodes.INVOKEVIRTUAL, slashesClassName, "invoke", "(L$MAP_CLASS;)L$T_CLASS;", false) + visitInsn(Opcodes.ARETURN) + val l1 = Label() + visitLabel(l1) + + visitLocalVariable( + "this", + "L$slashesClassName;", + T_CLASS, + l0, + l1, + thisVar + ) + + visitMaxs(2, 2) + visitEnd() + } + + asmCompiledClassWriter.visitEnd() + + val new = classLoader + .defineClass(className, asmCompiledClassWriter.toByteArray()) + .constructors + .first() + .newInstance(algebra, constants) as AsmCompiledExpression + + generatedInstance = new + return new + } + + internal fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any, T_CLASS) + + private fun visitLoadAnyFromConstants(value: Any, type: String) { + val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex + maxStack++ + + invokeMethodVisitor.run { + visitLoadThis() + visitFieldInsn(Opcodes.GETFIELD, slashesClassName, "constants", "L$LIST_CLASS;") + visitLdcOrIConstInsn(idx) + visitMethodInsn(Opcodes.INVOKEINTERFACE, LIST_CLASS, "get", "(I)L$OBJECT_CLASS;", true) + invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type) + } + } + + private fun visitLoadThis(): Unit = invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) + + internal fun visitNumberConstant(value: Number) { + maxStack++ + val clazz = value.javaClass + val c = clazz.name.replace('.', '/') + + val sigLetter = SIGNATURE_LETTERS[clazz] + + if (sigLetter != null) { + when (value) { + is Int -> invokeMethodVisitor.visitLdcOrIConstInsn(value) + is Double -> invokeMethodVisitor.visitLdcOrDConstInsn(value) + is Float -> invokeMethodVisitor.visitLdcOrFConstInsn(value) + else -> invokeMethodVisitor.visitLdcInsn(value) + } + + invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false) + return + } + + visitLoadAnyFromConstants(value, c) + } + + internal fun visitLoadFromVariables(name: String, defaultValue: T? = null) = invokeMethodVisitor.run { + maxStack += 2 + visitVarInsn(Opcodes.ALOAD, invokeArgumentsVar) + + if (defaultValue != null) { + visitLdcInsn(name) + visitLoadFromConstants(defaultValue) + + visitMethodInsn( + Opcodes.INVOKEINTERFACE, + MAP_CLASS, + "getOrDefault", + "(L$OBJECT_CLASS;L$OBJECT_CLASS;)L$OBJECT_CLASS;", + true + ) + + visitCastToT() + return + } + + visitLdcInsn(name) + visitMethodInsn(Opcodes.INVOKEINTERFACE, MAP_CLASS, "get", "(L$OBJECT_CLASS;)L$OBJECT_CLASS;", true) + visitCastToT() + } + + internal fun visitLoadAlgebra() { + invokeMethodVisitor.visitVarInsn(Opcodes.ALOAD, invokeThisVar) + + invokeMethodVisitor.visitFieldInsn( + Opcodes.GETFIELD, + ASM_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;" + ) + + invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS) + } + + internal fun visitAlgebraOperation(owner: String, method: String, descriptor: String) { + maxStack++ + invokeMethodVisitor.visitMethodInsn(Opcodes.INVOKEINTERFACE, owner, method, descriptor, true) + visitCastToT() + } + + private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS) + + internal companion object { + private val SIGNATURE_LETTERS = mapOf( + java.lang.Byte::class.java to "B", + java.lang.Short::class.java to "S", + java.lang.Integer::class.java to "I", + java.lang.Long::class.java to "J", + java.lang.Float::class.java to "F", + java.lang.Double::class.java to "D" + ) + + internal const val ASM_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/AsmCompiledExpression" + internal const val LIST_CLASS = "java/util/List" + internal const val MAP_CLASS = "java/util/Map" + internal const val OBJECT_CLASS = "java/lang/Object" + internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" + internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" + internal const val STRING_CLASS = "java/lang/String" + internal const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations" + internal const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations" + internal const val NUMBER_CLASS = "java/lang/Number" + } +} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt new file mode 100644 index 000000000..bdbab2b42 --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt @@ -0,0 +1,38 @@ +package scientifik.kmath.expressions + +import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Space + +@PublishedApi +internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { + val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" + + try { + Class.forName(name) + } catch (ignored: ClassNotFoundException) { + return name + } + + return buildName(expression, collision + 1) +} + + +inline fun asmSpace( + algebra: Space, + block: AsmExpressionSpace.() -> AsmExpression +): Expression { + val expression = AsmExpressionSpace(algebra).block() + val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression)) + expression.invoke(ctx) + return ctx.generate() +} + +inline fun asmField( + algebra: Field, + block: AsmExpressionField.() -> AsmExpression +): Expression { + val expression = AsmExpressionField(algebra).block() + val ctx = AsmGenerationContext(T::class.java, algebra, buildName(expression)) + expression.invoke(ctx) + return ctx.generate() +} diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index 875080616..4370029f6 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -64,7 +64,7 @@ class AsmTest { ) @Test - fun testCProductWithOtherTypeNumber() = testDoubleExpressionValue( + fun testCProductWithOtherTypeNumber(): Unit = testDoubleExpressionValue( 25.0, AsmConstProductExpression(AsmVariableExpression("x"), 5f), mapOf("x" to 5.0) @@ -81,21 +81,21 @@ class AsmTest { } @Test - fun testCProductWithCustomTypeNumber() = testDoubleExpressionValue( + fun testCProductWithCustomTypeNumber(): Unit = testDoubleExpressionValue( 0.0, AsmConstProductExpression(AsmVariableExpression("x"), CustomZero), mapOf("x" to 5.0) ) @Test - fun testVar() = testDoubleExpressionValue( + fun testVar(): Unit = testDoubleExpressionValue( 10000.0, AsmVariableExpression("x"), mapOf("x" to 10000.0) ) @Test - fun testVarWithDefault() = testDoubleExpressionValue( + fun testVarWithDefault(): Unit = testDoubleExpressionValue( 10000.0, AsmVariableExpression("x", 10000.0), mapOf()