Remove constant number allocation hack and support uncommon Number implementations to be available in constants

This commit is contained in:
Iaroslav 2020-06-06 22:05:25 +07:00
parent cdb24ea8e2
commit fb74e74b01
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
2 changed files with 32 additions and 27 deletions

View File

@ -8,7 +8,7 @@ import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
abstract class AsmCompiled<T>(@JvmField val algebra: Algebra<T>, @JvmField val constants: MutableList<T>) { abstract class AsmCompiled<T>(@JvmField val algebra: Algebra<T>, @JvmField val constants: MutableList<out Any>) {
abstract fun evaluate(arguments: Map<String, T>): T abstract fun evaluate(arguments: Map<String, T>): T
} }
@ -29,7 +29,7 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
private val evaluateThisVar: Int = 0 private val evaluateThisVar: Int = 0
private val evaluateArgumentsVar: Int = 1 private val evaluateArgumentsVar: Int = 1
private var maxStack: Int = 0 private var maxStack: Int = 0
private lateinit var constants: MutableList<T> private lateinit var constants: MutableList<Any>
private lateinit var asmCompiledClassWriter: ClassWriter private lateinit var asmCompiledClassWriter: ClassWriter
private lateinit var evaluateMethodVisitor: MethodVisitor private lateinit var evaluateMethodVisitor: MethodVisitor
private lateinit var evaluateL0: Label private lateinit var evaluateL0: Label
@ -178,7 +178,9 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
.newInstance(algebra, constants) as AsmCompiled<T> .newInstance(algebra, constants) as AsmCompiled<T>
} }
fun visitLoadFromConstants(value: T) { fun visitLoadFromConstants(value: T) = visitLoadAnyFromConstants(value as Any)
fun visitLoadAnyFromConstants(value: Any) {
val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex
maxStack++ maxStack++
@ -195,7 +197,24 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
fun visitNumberConstant(value: Number) { fun visitNumberConstant(value: Number) {
maxStack++ maxStack++
evaluateMethodVisitor.visitBoxedNumberConstant(value) val clazz = value.javaClass
val c = clazz.name.replace('.', '/')
val sigLetter = SIGNATURE_LETTERS[clazz]
if (sigLetter != null) {
when (value) {
is Int -> evaluateMethodVisitor.visitLdcOrIConstInsn(value)
is Double -> evaluateMethodVisitor.visitLdcOrDConstInsn(value)
is Float -> evaluateMethodVisitor.visitLdcOrFConstInsn(value)
else -> evaluateMethodVisitor.visitLdcInsn(value)
}
evaluateMethodVisitor.visitMethodInsn(INVOKESTATIC, c, "valueOf", "($sigLetter)L${c};", false)
return
}
visitLoadAnyFromConstants(value)
} }
fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run { fun visitLoadFromVariables(name: String, defaultValue: T? = null) = evaluateMethodVisitor.run {
@ -243,6 +262,15 @@ class AsmGenerationContext<T>(classOfT: Class<*>, private val algebra: Algebra<T
private fun visitCastToT(): Unit = evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS) private fun visitCastToT(): Unit = evaluateMethodVisitor.visitTypeInsn(CHECKCAST, T_CLASS)
companion object { companion object {
private val SIGNATURE_LETTERS = mapOf(
java.lang.Byte::class.java to "B",
java.lang.Short::class.java to "S",
java.lang.Integer::class.java to "I",
java.lang.Long::class.java to "J",
java.lang.Float::class.java to "F",
java.lang.Double::class.java to "D"
)
const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled" const val ASM_COMPILED_CLASS = "scientifik/kmath/expressions/AsmCompiled"
const val LIST_CLASS = "java/util/List" const val LIST_CLASS = "java/util/List"
const val MAP_CLASS = "java/util/Map" const val MAP_CLASS = "java/util/Map"

View File

@ -26,26 +26,3 @@ fun MethodVisitor.visitLdcOrFConstInsn(value: Float) = when (value) {
2f -> visitInsn(FCONST_2) 2f -> visitInsn(FCONST_2)
else -> visitLdcInsn(value) else -> visitLdcInsn(value)
} }
private val signatureLetters = mapOf(
java.lang.Byte::class.java to "B",
java.lang.Short::class.java to "S",
java.lang.Integer::class.java to "I",
java.lang.Long::class.java to "J",
java.lang.Float::class.java to "F",
java.lang.Double::class.java to "D"
)
fun MethodVisitor.visitBoxedNumberConstant(number: Number) {
val clazz = number.javaClass
val c = clazz.name.replace('.', '/')
when (number) {
is Int -> visitLdcOrIConstInsn(number)
is Double -> visitLdcOrDConstInsn(number)
is Float -> visitLdcOrFConstInsn(number)
else -> visitLdcInsn(number)
}
visitMethodInsn(INVOKESTATIC, c, "valueOf", "(${signatureLetters[clazz]})L${c};", false)
}