Remove duplicated code, optimize constants field generation, add overloads for getOrFail in mapIntrinsics

This commit is contained in:
Iaroslav 2020-06-28 02:08:26 +07:00
parent d7f5d9f53f
commit 63c001648e
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
6 changed files with 106 additions and 91 deletions

View File

@ -59,15 +59,13 @@ import scientifik.kmath.operations.RealField;
public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> {
private final RealField algebra;
private final Object[] constants;
public AsmCompiledExpression_1073786867_0(RealField algebra, Object[] constants) {
this.algebra = algebra;
this.constants = constants;
}
public final Double invoke(Map<String, ? extends Double> arguments) {
return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D);
return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue(), 2.0D);
}
public AsmCompiledExpression_1073786867_0(RealField algebra) {
this.algebra = algebra;
}
}

View File

@ -38,17 +38,17 @@ internal class AsmBuilder<T> internal constructor(
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
/**
* ASM Type for [algebra]
* ASM Type for [algebra].
*/
private val tAlgebraType: Type = algebra::class.asm
/**
* ASM type for [T]
* ASM type for [T].
*/
internal val tType: Type = classOfT.asm
/**
* ASM type for new class
* ASM type for new class.
*/
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
@ -72,6 +72,11 @@ internal class AsmBuilder<T> internal constructor(
*/
private lateinit var invokeMethodVisitor: InstructionAdapter
/**
* State if this [AsmBuilder] needs to generate constants field.
*/
private var hasConstants = true
/**
* State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
*/
@ -108,7 +113,7 @@ internal class AsmBuilder<T> internal constructor(
* The built instance is cached.
*/
@Suppress("UNCHECKED_CAST")
fun getInstance(): Expression<T> {
internal fun getInstance(): Expression<T> {
generatedInstance?.let { return it }
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
@ -127,64 +132,6 @@ internal class AsmBuilder<T> internal constructor(
arrayOf(EXPRESSION_TYPE.internalName)
)
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "algebra",
descriptor = tAlgebraType.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd
)
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "constants",
descriptor = OBJECT_ARRAY_TYPE.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd
)
visitMethod(
ACC_PUBLIC,
"<init>",
Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE),
null,
null
).instructionAdapter {
val thisVar = 0
val algebraVar = 1
val constantsVar = 2
val l0 = label()
load(thisVar, classType)
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
label()
load(thisVar, classType)
load(algebraVar, tAlgebraType)
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
label()
load(thisVar, classType)
load(constantsVar, OBJECT_ARRAY_TYPE)
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
label()
visitInsn(RETURN)
val l4 = label()
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
visitLocalVariable(
"algebra",
tAlgebraType.descriptor,
null,
l0,
l4,
algebraVar
)
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
visitMaxs(0, 3)
visitEnd()
}
visitMethod(
ACC_PUBLIC or ACC_FINAL,
"invoke",
@ -251,6 +198,78 @@ internal class AsmBuilder<T> internal constructor(
visitEnd()
}
hasConstants = constants.isNotEmpty()
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "algebra",
descriptor = tAlgebraType.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd
)
if (hasConstants)
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "constants",
descriptor = OBJECT_ARRAY_TYPE.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd
)
visitMethod(
ACC_PUBLIC,
"<init>",
Type.getMethodDescriptor(
Type.VOID_TYPE,
tAlgebraType,
*OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
null,
null
).instructionAdapter {
val thisVar = 0
val algebraVar = 1
val constantsVar = 2
val l0 = label()
load(thisVar, classType)
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
label()
load(thisVar, classType)
load(algebraVar, tAlgebraType)
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
if (hasConstants) {
label()
load(thisVar, classType)
load(constantsVar, OBJECT_ARRAY_TYPE)
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
}
label()
visitInsn(RETURN)
val l4 = label()
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
visitLocalVariable(
"algebra",
tAlgebraType.descriptor,
null,
l0,
l4,
algebraVar
)
if (hasConstants)
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
visitMaxs(0, 3)
visitEnd()
}
visitEnd()
}
@ -258,7 +277,7 @@ internal class AsmBuilder<T> internal constructor(
.defineClass(className, classWriter.toByteArray())
.constructors
.first()
.newInstance(algebra, constants.toTypedArray()) as Expression<T>
.newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression<T>
generatedInstance = new
return new
@ -359,18 +378,20 @@ internal class AsmBuilder<T> internal constructor(
if (defaultValue != null)
loadTConstant(defaultValue)
else
aconst(null)
invokestatic(
MAP_INTRINSICS_TYPE.internalName,
"getOrFail",
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE),
Type.getMethodDescriptor(
OBJECT_TYPE,
MAP_TYPE,
OBJECT_TYPE,
*OBJECT_TYPE.wrapToArrayIf { defaultValue != null }),
false
)
checkcast(tType)
val expectedType = expectationStack.pop()
if (expectedType.sort == Type.OBJECT)

View File

@ -22,6 +22,14 @@ private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
internal val KClass<*>.asm: Type
get() = Type.getType(java)
/**
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
*/
internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> = if (predicate(this))
arrayOf(this)
else
emptyArray()
/**
* Creates an [InstructionAdapter] from this [MethodVisitor].
*/
@ -121,7 +129,7 @@ private fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Strin
/**
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
*/
internal fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
context: Algebra<T>,
name: String,
fallbackMethodName: String,
@ -145,4 +153,3 @@ internal fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
expectedArity = arity
)
}

View File

@ -1,10 +0,0 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.Label
import org.objectweb.asm.commons.InstructionAdapter
internal fun InstructionAdapter.label(): Label {
val l = Label()
visitLabel(l)
return l
}

View File

@ -2,6 +2,6 @@
package scientifik.kmath.asm.internal
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V?): V {
return this[key] ?: default ?: error("Parameter not found: $key")
}
@JvmOverloads
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
this[key] ?: default ?: error("Parameter not found: $key")

View File

@ -13,8 +13,7 @@ import kotlin.test.assertEquals
internal class AsmTest {
@Test
fun `compile MST`() {
val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.expression(mst)()
val res = ComplexField.expression("2+2*(2+2)".parseMath())()
assertEquals(Complex(10.0, 0.0), res)
}