Add advanced specialization for primitive non-bridge methods

This commit is contained in:
Iaroslav 2020-06-18 11:35:20 +07:00
parent 8264806958
commit d6e7eb8143
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
4 changed files with 98 additions and 13 deletions

View File

@ -19,11 +19,10 @@ fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<
when (node) {
is MST.Symbolic -> loadVariable(node.value)
is MST.Numeric -> {
val constant = if (algebra is NumericAlgebra<T>) {
algebra.number(node.value)
} else {
val constant = if (algebra is NumericAlgebra<T>)
algebra.number(node.value) else
error("Number literals are not supported in $algebra")
}
loadTConstant(constant)
}
is MST.Unary -> {

View File

@ -6,6 +6,7 @@ import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader
import scientifik.kmath.operations.Algebra
import java.io.File
/**
* ASM Builder is a structure that abstracts building a class that unwraps [AsmExpression] to plain Java expression.
@ -39,12 +40,23 @@ internal class AsmBuilder<T> internal constructor(
private val invokeArgumentsVar: Int = 1
private val constants: MutableList<Any> = mutableListOf()
private lateinit var invokeMethodVisitor: MethodVisitor
private var primitiveMode = false
internal var primitiveType = OBJECT_CLASS
internal var primitiveTypeSig = "L$OBJECT_CLASS;"
internal var primitiveTypeReturnSig = "L$OBJECT_CLASS;"
private var generatedInstance: AsmCompiledExpression<T>? = null
@Suppress("UNCHECKED_CAST")
fun getInstance(): AsmCompiledExpression<T> {
generatedInstance?.let { return it }
if (classOfT in SIGNATURE_LETTERS) {
primitiveMode = true
primitiveTypeSig = SIGNATURE_LETTERS.getValue(classOfT)
primitiveType = SIGNATURE_LETTERS.getValue(classOfT)
primitiveTypeReturnSig = "L$T_CLASS;"
}
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit(
Opcodes.V1_8,
@ -110,6 +122,10 @@ internal class AsmBuilder<T> internal constructor(
val l0 = Label()
visitLabel(l0)
invokeLabel0Visitor()
if (primitiveMode)
boxPrimitiveToNumber()
visitReturnObject()
val l1 = Label()
visitLabel(l1)
@ -171,6 +187,8 @@ internal class AsmBuilder<T> internal constructor(
visitEnd()
}
File("dump.class").writeBytes(classWriter.toByteArray())
val new = classLoader
.defineClass(className, classWriter.toByteArray())
.constructors
@ -182,6 +200,11 @@ internal class AsmBuilder<T> internal constructor(
}
internal fun loadTConstant(value: T) {
if (primitiveMode) {
loadNumberConstant(value as Number)
return
}
if (classOfT in INLINABLE_NUMBERS) {
loadNumberConstant(value as Number)
invokeMethodVisitor.visitCheckCast(T_CLASS)
@ -191,6 +214,25 @@ internal class AsmBuilder<T> internal constructor(
loadConstant(value as Any, T_CLASS)
}
private fun unboxInPrimitiveMode() {
if (!primitiveMode)
return
unboxNumberToPrimitive()
}
private fun boxPrimitiveToNumber() {
invokeMethodVisitor.visitInvokeStatic(T_CLASS, "valueOf", "($primitiveTypeSig)L$T_CLASS;")
}
private fun unboxNumberToPrimitive() {
invokeMethodVisitor.visitInvokeVirtual(
owner = NUMBER_CLASS,
name = NUMBER_CONVERTER_METHODS.getValue(primitiveType),
descriptor = "()${primitiveTypeSig}"
)
}
private fun loadConstant(value: Any, type: String) {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
@ -199,13 +241,14 @@ internal class AsmBuilder<T> internal constructor(
visitGetField(owner = slashesClassName, name = "constants", descriptor = "[L$OBJECT_CLASS;")
visitLdcOrIntConstant(idx)
visitGetObjectArrayElement()
invokeMethodVisitor.visitCheckCast(type)
visitCheckCast(type)
unboxInPrimitiveMode()
}
}
private fun loadThis(): Unit = invokeMethodVisitor.visitLoadObjectVar(invokeThisVar)
internal fun loadNumberConstant(value: Number) {
private fun loadNumberConstant(value: Number) {
val clazz = value.javaClass
val c = clazz.name.replace('.', '/')
val sigLetter = SIGNATURE_LETTERS[clazz]
@ -218,7 +261,9 @@ internal class AsmBuilder<T> internal constructor(
else -> invokeMethodVisitor.visitLdcInsn(value)
}
invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};")
if (!primitiveMode)
invokeMethodVisitor.visitInvokeStatic(c, "valueOf", "($sigLetter)L${c};")
return
}
@ -251,6 +296,7 @@ internal class AsmBuilder<T> internal constructor(
)
invokeMethodVisitor.visitCheckCast(T_CLASS)
unboxInPrimitiveMode()
}
internal fun loadAlgebra() {
@ -274,6 +320,7 @@ internal class AsmBuilder<T> internal constructor(
) {
invokeMethodVisitor.visitMethodInsn(opcode, owner, method, descriptor, isInterface)
invokeMethodVisitor.visitCheckCast(T_CLASS)
unboxInPrimitiveMode()
}
internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.visitLdcInsn(string)
@ -290,11 +337,23 @@ internal class AsmBuilder<T> internal constructor(
)
}
private val NUMBER_CONVERTER_METHODS by lazy {
mapOf(
"B" to "byteValue",
"S" to "shortValue",
"I" to "intValue",
"J" to "longValue",
"F" to "floatValue",
"D" to "doubleValue"
)
}
private val INLINABLE_NUMBERS: Set<Class<out Any>> by lazy { SIGNATURE_LETTERS.keys }
internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS =
"scientifik/kmath/asm/internal/AsmCompiledExpression"
internal const val NUMBER_CLASS = "java/lang/Number"
internal const val MAP_CLASS = "java/util/Map"
internal const val OBJECT_CLASS = "java/lang/Object"
internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra"

View File

@ -24,9 +24,9 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
val sig = buildString {
append('(')
repeat(arity) { append("L${AsmBuilder.OBJECT_CLASS};") }
repeat(arity) { append(primitiveTypeSig) }
append(')')
append("L${AsmBuilder.OBJECT_CLASS};")
append(primitiveTypeReturnSig)
}
invokeAlgebraOperation(

View File

@ -1,10 +1,37 @@
package scietifik.kmath.asm
//
//class TestAsmAlgebras {
import scientifik.kmath.asm.compile
import scientifik.kmath.ast.mstInSpace
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.ByteRing
import kotlin.test.Test
import kotlin.test.assertEquals
class TestAsmAlgebras {
@Test
fun space() {
val res = ByteRing.mstInSpace {
binaryOperation(
"+",
unaryOperation(
"+",
number(3.toByte()) - (number(2.toByte()) + (multiply(
add(number(1), number(1)),
2
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + symbol("x") + zero
}.compile()("x" to 2.toByte())
assertEquals(16, res)
}
// @Test
// fun space() {
// val res = ByteRing.asmInRing {
// val res = ByteRing.asm {
// binaryOperation(
// "+",
//
@ -63,4 +90,4 @@ package scietifik.kmath.asm
//
// assertEquals(3.0, res)
// }
//}
}