Add advanced specialization for primitive non-bridge methods
This commit is contained in:
parent
8264806958
commit
d6e7eb8143
@ -19,11 +19,10 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
|
||||
when (node) {
|
||||
is MST.Symbolic -> loadVariable(node.value)
|
||||
is MST.Numeric -> {
|
||||
val constant = if (algebra is NumericAlgebra<T>) {
|
||||
algebra.number(node.value)
|
||||
} else {
|
||||
val constant = if (algebra is NumericAlgebra<T>)
|
||||
algebra.number(node.value) else
|
||||
error("Number literals are not supported in $algebra")
|
||||
}
|
||||
|
||||
loadTConstant(constant)
|
||||
}
|
||||
is MST.Unary -> {
|
||||
|
@ -6,6 +6,7 @@ import org.objectweb.asm.MethodVisitor
|
||||
import org.objectweb.asm.Opcodes
|
||||
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import java.io.File
|
||||
|
||||
/**
|
||||
* ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression.
|
||||
@ -39,12 +40,23 @@ internal class AsmBuilder<T> internal constructor(
|
||||
private val invokeArgumentsVar: Int = 1
|
||||
private val constants: MutableList<Any> = mutableListOf()
|
||||
private lateinit var invokeMethodVisitor: MethodVisitor
|
||||
private var primitiveMode = false
|
||||
internal var primitiveType = OBJECT_CLASS
|
||||
internal var primitiveTypeSig = "L$OBJECT_CLASS;"
|
||||
internal var primitiveTypeReturnSig = "L$OBJECT_CLASS;"
|
||||
private var generatedInstance: AsmCompiledExpression<T>? = null
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun getInstance(): AsmCompiledExpression<T> {
|
||||
generatedInstance?.let { return it }
|
||||
|
||||
if (classOfT in SIGNATURE_LETTERS) {
|
||||
primitiveMode = true
|
||||
primitiveTypeSig = SIGNATURE_LETTERS.getValue(classOfT)
|
||||
primitiveType = SIGNATURE_LETTERS.getValue(classOfT)
|
||||
primitiveTypeReturnSig = "L$T_CLASS;"
|
||||
}
|
||||
|
||||
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
||||
visit(
|
||||
Opcodes.V1_8,
|
||||
@ -110,6 +122,10 @@ internal class AsmBuilder<T> internal constructor(
|
||||
val l0 = Label()
|
||||
visitLabel(l0)
|
||||
invokeLabel0Visitor()
|
||||
|
||||
if (primitiveMode)
|
||||
boxPrimitiveToNumber()
|
||||
|
||||
visitReturnObject()
|
||||
val l1 = Label()
|
||||
visitLabel(l1)
|
||||
@ -171,6 +187,8 @@ internal class AsmBuilder<T> internal constructor(
|
||||
visitEnd()
|
||||
}
|
||||
|
||||
File("dump.class").writeBytes(classWriter.toByteArray())
|
||||
|
||||
val new = classLoader
|
||||
.defineClass(className, classWriter.toByteArray())
|
||||
.constructors
|
||||
@ -182,6 +200,11 @@ internal class AsmBuilder<T> internal constructor(
|
||||
}
|
||||
|
||||
internal fun loadTConstant(value: T) {
|
||||
if (primitiveMode) {
|
||||
loadNumberConstant(value as Number)
|
||||
return
|
||||
}
|
||||
|
||||
if (classOfT in INLINABLE_NUMBERS) {
|
||||
loadNumberConstant(value as Number)
|
||||
invokeMethodVisitor.visitCheckCast(T_CLASS)
|
||||
@ -191,6 +214,25 @@ internal class AsmBuilder<T> internal constructor(
|
||||
loadConstant(value as Any, T_CLASS)
|
||||
}
|
||||
|
||||
private fun unboxInPrimitiveMode() {
|
||||
if (!primitiveMode)
|
||||
return
|
||||
|
||||
unboxNumberToPrimitive()
|
||||
}
|
||||
|
||||
private fun boxPrimitiveToNumber() {
|
||||
invokeMethodVisitor.visitInvokeStatic(T_CLASS, "valueOf", "($primitiveTypeSig)L$T_CLASS;")
|
||||
}
|
||||
|
||||
private fun unboxNumberToPrimitive() {
|
||||
invokeMethodVisitor.visitInvokeVirtual(
|
||||
owner = NUMBER_CLASS,
|
||||
name = NUMBER_CONVERTER_METHODS.getValue(primitiveType),
|
||||
descriptor = "()${primitiveTypeSig}"
|
||||
)
|
||||
}
|
||||
|
||||
private fun loadConstant(value: Any, type: String) {
|
||||
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
|
||||
|
||||
@ -199,13 +241,14 @@ internal class AsmBuilder<T> internal constructor(
|
||||
visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;")
|
||||
visitLdcOrIntConstant(idx)
|
||||
visitGetObjectArrayElement()
|
||||
invokeMethodVisitor.visitCheckCast(type)
|
||||
visitCheckCast(type)
|
||||
unboxInPrimitiveMode()
|
||||
}
|
||||
}
|
||||
|
||||
private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar)
|
||||
|
||||
internal fun loadNumberConstant(value: Number) {
|
||||
private fun loadNumberConstant(value: Number) {
|
||||
val clazz = value.javaClass
|
||||
val c = clazz.name.replace('.', '/')
|
||||
val sigLetter = SIGNATURE_LETTERS[clazz]
|
||||
@ -218,7 +261,9 @@ internal class AsmBuilder<T> internal constructor(
|
||||
else -> invokeMethodVisitor.visitLdcInsn(value)
|
||||
}
|
||||
|
||||
invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};")
|
||||
if (!primitiveMode)
|
||||
invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -251,6 +296,7 @@ internal class AsmBuilder<T> internal constructor(
|
||||
)
|
||||
|
||||
invokeMethodVisitor.visitCheckCast(T_CLASS)
|
||||
unboxInPrimitiveMode()
|
||||
}
|
||||
|
||||
internal fun loadAlgebra() {
|
||||
@ -274,6 +320,7 @@ internal class AsmBuilder<T> internal constructor(
|
||||
) {
|
||||
invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface)
|
||||
invokeMethodVisitor.visitCheckCast(T_CLASS)
|
||||
unboxInPrimitiveMode()
|
||||
}
|
||||
|
||||
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string)
|
||||
@ -290,11 +337,23 @@ internal class AsmBuilder<T> internal constructor(
|
||||
)
|
||||
}
|
||||
|
||||
private val NUMBER_CONVERTER_METHODS by lazy {
|
||||
mapOf(
|
||||
"B" to "byteValue",
|
||||
"S" to "shortValue",
|
||||
"I" to "intValue",
|
||||
"J" to "longValue",
|
||||
"F" to "floatValue",
|
||||
"D" to "doubleValue"
|
||||
)
|
||||
}
|
||||
|
||||
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
|
||||
|
||||
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =
|
||||
"scientifik/kmath/asm/internal/AsmCompiledExpression"
|
||||
|
||||
internal const val NUMBER_CLASS = "java/lang/Number"
|
||||
internal const val MAP_CLASS = "java/util/Map"
|
||||
internal const val OBJECT_CLASS = "java/lang/Object"
|
||||
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"
|
||||
|
@ -24,9 +24,9 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
|
||||
|
||||
val sig = buildString {
|
||||
append('(')
|
||||
repeat(arity) { append("L${AsmBuilder.OBJECT_CLASS};") }
|
||||
repeat(arity) { append(primitiveTypeSig) }
|
||||
append(')')
|
||||
append("L${AsmBuilder.OBJECT_CLASS};")
|
||||
append(primitiveTypeReturnSig)
|
||||
}
|
||||
|
||||
invokeAlgebraOperation(
|
@ -1,10 +1,37 @@
|
||||
package scietifik.kmath.asm
|
||||
|
||||
//
|
||||
//class TestAsmAlgebras {
|
||||
import scientifik.kmath.asm.compile
|
||||
import scientifik.kmath.ast.mstInSpace
|
||||
import scientifik.kmath.expressions.invoke
|
||||
import scientifik.kmath.operations.ByteRing
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class TestAsmAlgebras {
|
||||
@Test
|
||||
fun space() {
|
||||
val res = ByteRing.mstInSpace {
|
||||
binaryOperation(
|
||||
"+",
|
||||
|
||||
unaryOperation(
|
||||
"+",
|
||||
number(3.toByte()) - (number(2.toByte()) + (multiply(
|
||||
add(number(1), number(1)),
|
||||
2
|
||||
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
|
||||
),
|
||||
|
||||
number(1)
|
||||
) + symbol("x") + zero
|
||||
}.compile()("x" to 2.toByte())
|
||||
|
||||
assertEquals(16, res)
|
||||
}
|
||||
|
||||
// @Test
|
||||
// fun space() {
|
||||
// val res = ByteRing.asmInRing {
|
||||
// val res = ByteRing.asm {
|
||||
// binaryOperation(
|
||||
// "+",
|
||||
//
|
||||
@ -63,4 +90,4 @@ package scietifik.kmath.asm
|
||||
//
|
||||
// assertEquals(3.0, res)
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user