diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt index 36607d2b5..9a7ad69f2 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt @@ -7,6 +7,7 @@ import org.objectweb.asm.Opcodes.* import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Field import scientifik.kmath.operations.Space +import java.io.File abstract class AsmCompiled(@JvmField val algebra: Algebra, @JvmField val constants: MutableList) { abstract fun evaluate(arguments: Map): T @@ -184,10 +185,13 @@ class AsmGenerationContext(classOfT: Class<*>, private val algebra: Algebra(classOfT: Class<*>, private val algebra: Algebra(classOfT: Class<*>, private val algebra: Algebra( second.invoke(gen) gen.visitAlgebraOperation( - owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, - method = "times", + owner = AsmGenerationContext.RING_OPERATIONS_CLASS, + method = "multiply", descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" ) } @@ -298,13 +302,13 @@ internal class AsmConstProductExpression( ) : AsmExpression { override fun invoke(gen: AsmGenerationContext) { gen.visitLoadAlgebra() - expr.invoke(gen) gen.visitNumberConstant(const) + expr.invoke(gen) gen.visitAlgebraOperation( - owner = AsmGenerationContext.SPACE_CLASS, + owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, method = "multiply", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" ) } } @@ -326,7 +330,6 @@ internal class AsmDivExpression( } } - open class AsmFunctionalExpressionSpace( val space: Space, one: T diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt index fdbc1062e..5b9a8afe8 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt @@ -1,17 +1,34 @@ package scientifik.kmath.expressions import org.objectweb.asm.MethodVisitor -import org.objectweb.asm.Opcodes +import org.objectweb.asm.Opcodes.* fun MethodVisitor.visitLdcOrIConstInsn(value: Int) { when (value) { - -1 -> visitInsn(Opcodes.ICONST_M1) - 0 -> visitInsn(Opcodes.ICONST_0) - 1 -> visitInsn(Opcodes.ICONST_1) - 2 -> visitInsn(Opcodes.ICONST_2) - 3 -> visitInsn(Opcodes.ICONST_3) - 4 -> visitInsn(Opcodes.ICONST_4) - 5 -> visitInsn(Opcodes.ICONST_5) + -1 -> visitInsn(ICONST_M1) + 0 -> visitInsn(ICONST_0) + 1 -> visitInsn(ICONST_1) + 2 -> visitInsn(ICONST_2) + 3 -> visitInsn(ICONST_3) + 4 -> visitInsn(ICONST_4) + 5 -> visitInsn(ICONST_5) else -> visitLdcInsn(value) } } + +private val signatureLetters = mapOf( + java.lang.Byte::class.java to "B", + java.lang.Short::class.java to "S", + java.lang.Integer::class.java to "I", + java.lang.Long::class.java to "J", + java.lang.Float::class.java to "F", + java.lang.Double::class.java to "D", + java.lang.Short::class.java to "S" +) + +fun MethodVisitor.visitBoxedNumberConstant(number: Number) { + val clazz = number.javaClass + val c = clazz.name.replace('.', '/') + visitLdcInsn(number) + visitMethodInsn(INVOKESTATIC, c, "valueOf", "(${signatureLetters[clazz]})L${c};", false) +} diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index e95f8df76..402307050 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -1,23 +1,68 @@ package scientifik.kmath.expressions +import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals class AsmTest { - @Test - fun test() { - val expr = AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x")) - - val gen = AsmGenerationContext( - java.lang.Double::class.java, - RealField, - "MyAsmCompiled" + private fun testExpressionValue( + expectedValue: T, + expr: AsmExpression, + arguments: Map, + algebra: Algebra, + clazz: Class<*> + ) { + assertEquals( + expectedValue, AsmGenerationContext( + clazz, + algebra, + "TestAsmCompiled" + ).also(expr::invoke).generate().evaluate(arguments) ) - - expr.invoke(gen) - val compiled = gen.generate() - val value = compiled.evaluate(mapOf("x" to 25.0)) - assertEquals(26.0, value) } + + @Suppress("UNCHECKED_CAST") + private fun testDoubleExpressionValue( + expectedValue: Double, + expr: AsmExpression, + arguments: Map, + algebra: Algebra = RealField, + clazz: Class = java.lang.Double::class.java as Class + ) = testExpressionValue(expectedValue, expr, arguments, algebra, clazz) + + @Test + fun testSum() = testDoubleExpressionValue( + 25.0, + AsmSumExpression(AsmConstantExpression(1.0), AsmVariableExpression("x")), + mapOf("x" to 24.0) + ) + + @Test + fun testConst() = testDoubleExpressionValue( + 123.0, + AsmConstantExpression(123.0), + mapOf() + ) + + @Test + fun testDiv() = testDoubleExpressionValue( + 0.5, + AsmDivExpression(AsmConstantExpression(1.0), AsmConstantExpression(2.0)), + mapOf() + ) + + @Test + fun testProduct() = testDoubleExpressionValue( + 25.0, + AsmProductExpression(AsmVariableExpression("x"), AsmVariableExpression("x")), + mapOf("x" to 5.0) + ) + + @Test + fun testCProduct() = testDoubleExpressionValue( + 25.0, + AsmConstProductExpression(AsmVariableExpression("x"), 5.0), + mapOf("x" to 5.0) + ) }