Add advanced specialization for primitive non-bridge methods

This commit is contained in:
Iaroslav 2020-06-18 11:35:20 +07:00
parent 8264806958
commit d6e7eb8143
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
4 changed files with 98 additions and 13 deletions

View File

@ -19,11 +19,10 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
when (node) { when (node) {
is MST.Symbolic -> loadVariable(node.value) is MST.Symbolic -> loadVariable(node.value)
is MST.Numeric -> { is MST.Numeric -> {
val constant = if (algebra is NumericAlgebra<T>) { val constant = if (algebra is NumericAlgebra<T>)
algebra.number(node.value) algebra.number(node.value) else
} else {
error("Number literals are not supported in $algebra") error("Number literals are not supported in $algebra")
}
loadTConstant(constant) loadTConstant(constant)
} }
is MST.Unary -> { is MST.Unary -> {

View File

@ -6,6 +6,7 @@ import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import java.io.File
/** /**
* ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression. * ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression.
@ -39,12 +40,23 @@ internal class AsmBuilder<T> internal constructor(
private val invokeArgumentsVar: Int = 1 private val invokeArgumentsVar: Int = 1
private val constants: MutableList<Any> = mutableListOf() private val constants: MutableList<Any> = mutableListOf()
private lateinit var invokeMethodVisitor: MethodVisitor private lateinit var invokeMethodVisitor: MethodVisitor
private var primitiveMode = false
internal var primitiveType = OBJECT_CLASS
internal var primitiveTypeSig = "L$OBJECT_CLASS;"
internal var primitiveTypeReturnSig = "L$OBJECT_CLASS;"
private var generatedInstance: AsmCompiledExpression<T>? = null private var generatedInstance: AsmCompiledExpression<T>? = null
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
fun getInstance(): AsmCompiledExpression<T> { fun getInstance(): AsmCompiledExpression<T> {
generatedInstance?.let { return it } generatedInstance?.let { return it }
if (classOfT in SIGNATURE_LETTERS) {
primitiveMode = true
primitiveTypeSig = SIGNATURE_LETTERS.getValue(classOfT)
primitiveType = SIGNATURE_LETTERS.getValue(classOfT)
primitiveTypeReturnSig = "L$T_CLASS;"
}
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit( visit(
Opcodes.V1_8, Opcodes.V1_8,
@ -110,6 +122,10 @@ internal class AsmBuilder<T> internal constructor(
val l0 = Label() val l0 = Label()
visitLabel(l0) visitLabel(l0)
invokeLabel0Visitor() invokeLabel0Visitor()
if (primitiveMode)
boxPrimitiveToNumber()
visitReturnObject() visitReturnObject()
val l1 = Label() val l1 = Label()
visitLabel(l1) visitLabel(l1)
@ -171,6 +187,8 @@ internal class AsmBuilder<T> internal constructor(
visitEnd() visitEnd()
} }
File("dump.class").writeBytes(classWriter.toByteArray())
val new = classLoader val new = classLoader
.defineClass(className, classWriter.toByteArray()) .defineClass(className, classWriter.toByteArray())
.constructors .constructors
@ -182,6 +200,11 @@ internal class AsmBuilder<T> internal constructor(
} }
internal fun loadTConstant(value: T) { internal fun loadTConstant(value: T) {
if (primitiveMode) {
loadNumberConstant(value as Number)
return
}
if (classOfT in INLINABLE_NUMBERS) { if (classOfT in INLINABLE_NUMBERS) {
loadNumberConstant(value as Number) loadNumberConstant(value as Number)
invokeMethodVisitor.visitCheckCast(T_CLASS) invokeMethodVisitor.visitCheckCast(T_CLASS)
@ -191,6 +214,25 @@ internal class AsmBuilder<T> internal constructor(
loadConstant(value as Any, T_CLASS) loadConstant(value as Any, T_CLASS)
} }
private fun unboxInPrimitiveMode() {
if (!primitiveMode)
return
unboxNumberToPrimitive()
}
private fun boxPrimitiveToNumber() {
invokeMethodVisitor.visitInvokeStatic(T_CLASS, "valueOf", "($primitiveTypeSig)L$T_CLASS;")
}
private fun unboxNumberToPrimitive() {
invokeMethodVisitor.visitInvokeVirtual(
owner = NUMBER_CLASS,
name = NUMBER_CONVERTER_METHODS.getValue(primitiveType),
descriptor = "()${primitiveTypeSig}"
)
}
private fun loadConstant(value: Any, type: String) { private fun loadConstant(value: Any, type: String) {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
@ -199,13 +241,14 @@ internal class AsmBuilder<T> internal constructor(
visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;") visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;")
visitLdcOrIntConstant(idx) visitLdcOrIntConstant(idx)
visitGetObjectArrayElement() visitGetObjectArrayElement()
invokeMethodVisitor.visitCheckCast(type) visitCheckCast(type)
unboxInPrimitiveMode()
} }
} }
private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar) private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar)
internal fun loadNumberConstant(value: Number) { private fun loadNumberConstant(value: Number) {
val clazz = value.javaClass val clazz = value.javaClass
val c = clazz.name.replace('.', '/') val c = clazz.name.replace('.', '/')
val sigLetter = SIGNATURE_LETTERS[clazz] val sigLetter = SIGNATURE_LETTERS[clazz]
@ -218,7 +261,9 @@ internal class AsmBuilder<T> internal constructor(
else -> invokeMethodVisitor.visitLdcInsn(value) else -> invokeMethodVisitor.visitLdcInsn(value)
} }
if (!primitiveMode)
invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};") invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};")
return return
} }
@ -251,6 +296,7 @@ internal class AsmBuilder<T> internal constructor(
) )
invokeMethodVisitor.visitCheckCast(T_CLASS) invokeMethodVisitor.visitCheckCast(T_CLASS)
unboxInPrimitiveMode()
} }
internal fun loadAlgebra() { internal fun loadAlgebra() {
@ -274,6 +320,7 @@ internal class AsmBuilder<T> internal constructor(
) { ) {
invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface) invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface)
invokeMethodVisitor.visitCheckCast(T_CLASS) invokeMethodVisitor.visitCheckCast(T_CLASS)
unboxInPrimitiveMode()
} }
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string) internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string)
@ -290,11 +337,23 @@ internal class AsmBuilder<T> internal constructor(
) )
} }
private val NUMBER_CONVERTER_METHODS by lazy {
mapOf(
"B" to "byteValue",
"S" to "shortValue",
"I" to "intValue",
"J" to "longValue",
"F" to "floatValue",
"D" to "doubleValue"
)
}
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys } private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =
"scientifik/kmath/asm/internal/AsmCompiledExpression" "scientifik/kmath/asm/internal/AsmCompiledExpression"
internal const val NUMBER_CLASS = "java/lang/Number"
internal const val MAP_CLASS = "java/util/Map" internal const val MAP_CLASS = "java/util/Map"
internal const val OBJECT_CLASS = "java/lang/Object" internal const val OBJECT_CLASS = "java/lang/Object"
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"

View File

@ -24,9 +24,9 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
val sig = buildString { val sig = buildString {
append('(') append('(')
repeat(arity) { append("L${AsmBuilder.OBJECT_CLASS};") } repeat(arity) { append(primitiveTypeSig) }
append(')') append(')')
append("L${AsmBuilder.OBJECT_CLASS};") append(primitiveTypeReturnSig)
} }
invokeAlgebraOperation( invokeAlgebraOperation(

View File

@ -1,10 +1,37 @@
package scietifik.kmath.asm package scietifik.kmath.asm
// import scientifik.kmath.asm.compile
//class TestAsmAlgebras { import scientifik.kmath.ast.mstInSpace
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.ByteRing
import kotlin.test.Test
import kotlin.test.assertEquals
class TestAsmAlgebras {
@Test
fun space() {
val res = ByteRing.mstInSpace {
binaryOperation(
"+",
unaryOperation(
"+",
number(3.toByte()) - (number(2.toByte()) + (multiply(
add(number(1), number(1)),
2
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + symbol("x") + zero
}.compile()("x" to 2.toByte())
assertEquals(16, res)
}
// @Test // @Test
// fun space() { // fun space() {
// val res = ByteRing.asmInRing { // val res = ByteRing.asm {
// binaryOperation( // binaryOperation(
// "+", // "+",
// //
@ -63,4 +90,4 @@ package scietifik.kmath.asm
// //
// assertEquals(3.0, res) // assertEquals(3.0, res)
// } // }
//} }