diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index bb091732e..2345757f1 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -19,11 +19,10 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< when (node) { is MST.Symbolic -> loadVariable(node.value) is MST.Numeric -> { - val constant = if (algebra is NumericAlgebra) { - algebra.number(node.value) - } else { + val constant = if (algebra is NumericAlgebra) + algebra.number(node.value) else error("Number literals are not supported in $algebra") - } + loadTConstant(constant) } is MST.Unary -> { diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 337219029..351d5b754 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -6,6 +6,7 @@ import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.operations.Algebra +import java.io.File /** * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. @@ -39,12 +40,23 @@ internal class AsmBuilder internal constructor( private val invokeArgumentsVar: Int = 1 private val constants: MutableList = mutableListOf() private lateinit var invokeMethodVisitor: MethodVisitor + private var primitiveMode = false + internal var primitiveType = OBJECT_CLASS + internal var primitiveTypeSig = "L$OBJECT_CLASS;" + internal var primitiveTypeReturnSig = "L$OBJECT_CLASS;" private var generatedInstance: AsmCompiledExpression? = null @Suppress("UNCHECKED_CAST") fun getInstance(): AsmCompiledExpression { generatedInstance?.let { return it } + if (classOfT in SIGNATURE_LETTERS) { + primitiveMode = true + primitiveTypeSig = SIGNATURE_LETTERS.getValue(classOfT) + primitiveType = SIGNATURE_LETTERS.getValue(classOfT) + primitiveTypeReturnSig = "L$T_CLASS;" + } + val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( Opcodes.V1_8, @@ -110,6 +122,10 @@ internal class AsmBuilder internal constructor( val l0 = Label() visitLabel(l0) invokeLabel0Visitor() + + if (primitiveMode) + boxPrimitiveToNumber() + visitReturnObject() val l1 = Label() visitLabel(l1) @@ -171,6 +187,8 @@ internal class AsmBuilder internal constructor( visitEnd() } + File("dump.class").writeBytes(classWriter.toByteArray()) + val new = classLoader .defineClass(className, classWriter.toByteArray()) .constructors @@ -182,6 +200,11 @@ internal class AsmBuilder internal constructor( } internal fun loadTConstant(value: T) { + if (primitiveMode) { + loadNumberConstant(value as Number) + return + } + if (classOfT in INLINABLE_NUMBERS) { loadNumberConstant(value as Number) invokeMethodVisitor.visitCheckCast(T_CLASS) @@ -191,6 +214,25 @@ internal class AsmBuilder internal constructor( loadConstant(value as Any, T_CLASS) } + private fun unboxInPrimitiveMode() { + if (!primitiveMode) + return + + unboxNumberToPrimitive() + } + + private fun boxPrimitiveToNumber() { + invokeMethodVisitor.visitInvokeStatic(T_CLASS, "valueOf", "($primitiveTypeSig)L$T_CLASS;") + } + + private fun unboxNumberToPrimitive() { + invokeMethodVisitor.visitInvokeVirtual( + owner = NUMBER_CLASS, + name = NUMBER_CONVERTER_METHODS.getValue(primitiveType), + descriptor = "()${primitiveTypeSig}" + ) + } + private fun loadConstant(value: Any, type: String) { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex @@ -199,13 +241,14 @@ internal class AsmBuilder internal constructor( visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;") visitLdcOrIntConstant(idx) visitGetObjectArrayElement() - invokeMethodVisitor.visitCheckCast(type) + visitCheckCast(type) + unboxInPrimitiveMode() } } private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) - internal fun loadNumberConstant(value: Number) { + private fun loadNumberConstant(value: Number) { val clazz = value.javaClass val c = clazz.name.replace('.', '/') val sigLetter = SIGNATURE_LETTERS[clazz] @@ -218,7 +261,9 @@ internal class AsmBuilder internal constructor( else -> invokeMethodVisitor.visitLdcInsn(value) } - invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") + if (!primitiveMode) + invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") + return } @@ -251,6 +296,7 @@ internal class AsmBuilder internal constructor( ) invokeMethodVisitor.visitCheckCast(T_CLASS) + unboxInPrimitiveMode() } internal fun loadAlgebra() { @@ -274,6 +320,7 @@ internal class AsmBuilder internal constructor( ) { invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) invokeMethodVisitor.visitCheckCast(T_CLASS) + unboxInPrimitiveMode() } internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) @@ -290,11 +337,23 @@ internal class AsmBuilder internal constructor( ) } + private val NUMBER_CONVERTER_METHODS by lazy { + mapOf( + "B" to "byteValue", + "S" to "shortValue", + "I" to "intValue", + "J" to "longValue", + "F" to "floatValue", + "D" to "doubleValue" + ) + } + private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/asm/internal/AsmCompiledExpression" + internal const val NUMBER_CLASS = "java/lang/Number" internal const val MAP_CLASS = "java/util/Map" internal const val OBJECT_CLASS = "java/lang/Object" internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt similarity index 90% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index a39dba6b1..f279c94f4 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -24,9 +24,9 @@ internal fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Stri val sig = buildString { append('(') - repeat(arity) { append("L${AsmBuilder.OBJECT_CLASS};") } + repeat(arity) { append(primitiveTypeSig) } append(')') - append("L${AsmBuilder.OBJECT_CLASS};") + append(primitiveTypeReturnSig) } invokeAlgebraOperation( diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt index a6cb1e247..294da32e9 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmAlgebras.kt @@ -1,10 +1,37 @@ package scietifik.kmath.asm -// -//class TestAsmAlgebras { +import scientifik.kmath.asm.compile +import scientifik.kmath.ast.mstInSpace +import scientifik.kmath.expressions.invoke +import scientifik.kmath.operations.ByteRing +import kotlin.test.Test +import kotlin.test.assertEquals + +class TestAsmAlgebras { + @Test + fun space() { + val res = ByteRing.mstInSpace { + binaryOperation( + "+", + + unaryOperation( + "+", + number(3.toByte()) - (number(2.toByte()) + (multiply( + add(number(1), number(1)), + 2 + ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) + ), + + number(1) + ) + symbol("x") + zero + }.compile()("x" to 2.toByte()) + + assertEquals(16, res) + } + // @Test // fun space() { -// val res = ByteRing.asmInRing { +// val res = ByteRing.asm { // binaryOperation( // "+", // @@ -63,4 +90,4 @@ package scietifik.kmath.asm // // assertEquals(3.0, res) // } -//} +}