forked from kscience/kmath
Remove duplicated code, optimize constants field generation, add overloads for getOrFail in mapIntrinsics
This commit is contained in:
parent
d7f5d9f53f
commit
63c001648e
@ -59,15 +59,13 @@ import scientifik.kmath.operations.RealField;
|
|||||||
|
|
||||||
public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> {
|
public final class AsmCompiledExpression_1073786867_0 implements Expression<Double> {
|
||||||
private final RealField algebra;
|
private final RealField algebra;
|
||||||
private final Object[] constants;
|
|
||||||
|
|
||||||
public AsmCompiledExpression_1073786867_0(RealField algebra, Object[] constants) {
|
|
||||||
this.algebra = algebra;
|
|
||||||
this.constants = constants;
|
|
||||||
}
|
|
||||||
|
|
||||||
public final Double invoke(Map<String, ? extends Double> arguments) {
|
public final Double invoke(Map<String, ? extends Double> arguments) {
|
||||||
return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D);
|
return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue(), 2.0D);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AsmCompiledExpression_1073786867_0(RealField algebra) {
|
||||||
|
this.algebra = algebra;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,17 +38,17 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM Type for [algebra]
|
* ASM Type for [algebra].
|
||||||
*/
|
*/
|
||||||
private val tAlgebraType: Type = algebra::class.asm
|
private val tAlgebraType: Type = algebra::class.asm
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [T]
|
* ASM type for [T].
|
||||||
*/
|
*/
|
||||||
internal val tType: Type = classOfT.asm
|
internal val tType: Type = classOfT.asm
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for new class
|
* ASM type for new class.
|
||||||
*/
|
*/
|
||||||
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!!
|
||||||
|
|
||||||
@ -72,6 +72,11 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
*/
|
*/
|
||||||
private lateinit var invokeMethodVisitor: InstructionAdapter
|
private lateinit var invokeMethodVisitor: InstructionAdapter
|
||||||
|
|
||||||
|
/**
|
||||||
|
* State if this [AsmBuilder] needs to generate constants field.
|
||||||
|
*/
|
||||||
|
private var hasConstants = true
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
* State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls.
|
||||||
*/
|
*/
|
||||||
@ -108,7 +113,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
* The built instance is cached.
|
* The built instance is cached.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
fun getInstance(): Expression<T> {
|
internal fun getInstance(): Expression<T> {
|
||||||
generatedInstance?.let { return it }
|
generatedInstance?.let { return it }
|
||||||
|
|
||||||
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
|
if (SIGNATURE_LETTERS.containsKey(classOfT)) {
|
||||||
@ -127,64 +132,6 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
arrayOf(EXPRESSION_TYPE.internalName)
|
arrayOf(EXPRESSION_TYPE.internalName)
|
||||||
)
|
)
|
||||||
|
|
||||||
visitField(
|
|
||||||
access = ACC_PRIVATE or ACC_FINAL,
|
|
||||||
name = "algebra",
|
|
||||||
descriptor = tAlgebraType.descriptor,
|
|
||||||
signature = null,
|
|
||||||
value = null,
|
|
||||||
block = FieldVisitor::visitEnd
|
|
||||||
)
|
|
||||||
|
|
||||||
visitField(
|
|
||||||
access = ACC_PRIVATE or ACC_FINAL,
|
|
||||||
name = "constants",
|
|
||||||
descriptor = OBJECT_ARRAY_TYPE.descriptor,
|
|
||||||
signature = null,
|
|
||||||
value = null,
|
|
||||||
block = FieldVisitor::visitEnd
|
|
||||||
)
|
|
||||||
|
|
||||||
visitMethod(
|
|
||||||
ACC_PUBLIC,
|
|
||||||
"<init>",
|
|
||||||
Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE),
|
|
||||||
null,
|
|
||||||
null
|
|
||||||
).instructionAdapter {
|
|
||||||
val thisVar = 0
|
|
||||||
val algebraVar = 1
|
|
||||||
val constantsVar = 2
|
|
||||||
val l0 = label()
|
|
||||||
load(thisVar, classType)
|
|
||||||
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
|
|
||||||
label()
|
|
||||||
load(thisVar, classType)
|
|
||||||
load(algebraVar, tAlgebraType)
|
|
||||||
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
|
|
||||||
label()
|
|
||||||
load(thisVar, classType)
|
|
||||||
load(constantsVar, OBJECT_ARRAY_TYPE)
|
|
||||||
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
|
||||||
label()
|
|
||||||
visitInsn(RETURN)
|
|
||||||
val l4 = label()
|
|
||||||
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
|
|
||||||
|
|
||||||
visitLocalVariable(
|
|
||||||
"algebra",
|
|
||||||
tAlgebraType.descriptor,
|
|
||||||
null,
|
|
||||||
l0,
|
|
||||||
l4,
|
|
||||||
algebraVar
|
|
||||||
)
|
|
||||||
|
|
||||||
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
|
|
||||||
visitMaxs(0, 3)
|
|
||||||
visitEnd()
|
|
||||||
}
|
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
ACC_PUBLIC or ACC_FINAL,
|
ACC_PUBLIC or ACC_FINAL,
|
||||||
"invoke",
|
"invoke",
|
||||||
@ -251,6 +198,78 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
visitEnd()
|
visitEnd()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hasConstants = constants.isNotEmpty()
|
||||||
|
|
||||||
|
visitField(
|
||||||
|
access = ACC_PRIVATE or ACC_FINAL,
|
||||||
|
name = "algebra",
|
||||||
|
descriptor = tAlgebraType.descriptor,
|
||||||
|
signature = null,
|
||||||
|
value = null,
|
||||||
|
block = FieldVisitor::visitEnd
|
||||||
|
)
|
||||||
|
|
||||||
|
if (hasConstants)
|
||||||
|
visitField(
|
||||||
|
access = ACC_PRIVATE or ACC_FINAL,
|
||||||
|
name = "constants",
|
||||||
|
descriptor = OBJECT_ARRAY_TYPE.descriptor,
|
||||||
|
signature = null,
|
||||||
|
value = null,
|
||||||
|
block = FieldVisitor::visitEnd
|
||||||
|
)
|
||||||
|
|
||||||
|
visitMethod(
|
||||||
|
ACC_PUBLIC,
|
||||||
|
"<init>",
|
||||||
|
|
||||||
|
Type.getMethodDescriptor(
|
||||||
|
Type.VOID_TYPE,
|
||||||
|
tAlgebraType,
|
||||||
|
*OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
|
||||||
|
|
||||||
|
null,
|
||||||
|
null
|
||||||
|
).instructionAdapter {
|
||||||
|
val thisVar = 0
|
||||||
|
val algebraVar = 1
|
||||||
|
val constantsVar = 2
|
||||||
|
val l0 = label()
|
||||||
|
load(thisVar, classType)
|
||||||
|
invokespecial(OBJECT_TYPE.internalName, "<init>", Type.getMethodDescriptor(Type.VOID_TYPE), false)
|
||||||
|
label()
|
||||||
|
load(thisVar, classType)
|
||||||
|
load(algebraVar, tAlgebraType)
|
||||||
|
putfield(classType.internalName, "algebra", tAlgebraType.descriptor)
|
||||||
|
|
||||||
|
if (hasConstants) {
|
||||||
|
label()
|
||||||
|
load(thisVar, classType)
|
||||||
|
load(constantsVar, OBJECT_ARRAY_TYPE)
|
||||||
|
putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||||
|
}
|
||||||
|
|
||||||
|
label()
|
||||||
|
visitInsn(RETURN)
|
||||||
|
val l4 = label()
|
||||||
|
visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar)
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"algebra",
|
||||||
|
tAlgebraType.descriptor,
|
||||||
|
null,
|
||||||
|
l0,
|
||||||
|
l4,
|
||||||
|
algebraVar
|
||||||
|
)
|
||||||
|
|
||||||
|
if (hasConstants)
|
||||||
|
visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar)
|
||||||
|
|
||||||
|
visitMaxs(0, 3)
|
||||||
|
visitEnd()
|
||||||
|
}
|
||||||
|
|
||||||
visitEnd()
|
visitEnd()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -258,7 +277,7 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
.defineClass(className, classWriter.toByteArray())
|
.defineClass(className, classWriter.toByteArray())
|
||||||
.constructors
|
.constructors
|
||||||
.first()
|
.first()
|
||||||
.newInstance(algebra, constants.toTypedArray()) as Expression<T>
|
.newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression<T>
|
||||||
|
|
||||||
generatedInstance = new
|
generatedInstance = new
|
||||||
return new
|
return new
|
||||||
@ -359,18 +378,20 @@ internal class AsmBuilder<T> internal constructor(
|
|||||||
|
|
||||||
if (defaultValue != null)
|
if (defaultValue != null)
|
||||||
loadTConstant(defaultValue)
|
loadTConstant(defaultValue)
|
||||||
else
|
|
||||||
aconst(null)
|
|
||||||
|
|
||||||
invokestatic(
|
invokestatic(
|
||||||
MAP_INTRINSICS_TYPE.internalName,
|
MAP_INTRINSICS_TYPE.internalName,
|
||||||
"getOrFail",
|
"getOrFail",
|
||||||
Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE),
|
|
||||||
|
Type.getMethodDescriptor(
|
||||||
|
OBJECT_TYPE,
|
||||||
|
MAP_TYPE,
|
||||||
|
OBJECT_TYPE,
|
||||||
|
*OBJECT_TYPE.wrapToArrayIf { defaultValue != null }),
|
||||||
false
|
false
|
||||||
)
|
)
|
||||||
|
|
||||||
checkcast(tType)
|
checkcast(tType)
|
||||||
|
|
||||||
val expectedType = expectationStack.pop()
|
val expectedType = expectationStack.pop()
|
||||||
|
|
||||||
if (expectedType.sort == Type.OBJECT)
|
if (expectedType.sort == Type.OBJECT)
|
||||||
|
@ -22,6 +22,14 @@ private val methodNameAdapters: Map<Pair<String, Int>, String> by lazy {
|
|||||||
internal val KClass<*>.asm: Type
|
internal val KClass<*>.asm: Type
|
||||||
get() = Type.getType(java)
|
get() = Type.getType(java)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns singleton array with this value if the [predicate] is true, returns empty array otherwise.
|
||||||
|
*/
|
||||||
|
internal inline fun <reified T> T.wrapToArrayIf(predicate: (T) -> Boolean): Array<T> = if (predicate(this))
|
||||||
|
arrayOf(this)
|
||||||
|
else
|
||||||
|
emptyArray()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates an [InstructionAdapter] from this [MethodVisitor].
|
* Creates an [InstructionAdapter] from this [MethodVisitor].
|
||||||
*/
|
*/
|
||||||
@ -121,7 +129,7 @@ private fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Strin
|
|||||||
/**
|
/**
|
||||||
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
|
* Builds specialized algebra call with option to fallback to generic algebra operation accepting String.
|
||||||
*/
|
*/
|
||||||
internal fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
internal inline fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
||||||
context: Algebra<T>,
|
context: Algebra<T>,
|
||||||
name: String,
|
name: String,
|
||||||
fallbackMethodName: String,
|
fallbackMethodName: String,
|
||||||
@ -145,4 +153,3 @@ internal fun <T> AsmBuilder<T>.buildAlgebraOperationCall(
|
|||||||
expectedArity = arity
|
expectedArity = arity
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +0,0 @@
|
|||||||
package scientifik.kmath.asm.internal
|
|
||||||
|
|
||||||
import org.objectweb.asm.Label
|
|
||||||
import org.objectweb.asm.commons.InstructionAdapter
|
|
||||||
|
|
||||||
internal fun InstructionAdapter.label(): Label {
|
|
||||||
val l = Label()
|
|
||||||
visitLabel(l)
|
|
||||||
return l
|
|
||||||
}
|
|
@ -2,6 +2,6 @@
|
|||||||
|
|
||||||
package scientifik.kmath.asm.internal
|
package scientifik.kmath.asm.internal
|
||||||
|
|
||||||
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V?): V {
|
@JvmOverloads
|
||||||
return this[key] ?: default ?: error("Parameter not found: $key")
|
internal fun <K, V> Map<K, V>.getOrFail(key: K, default: V? = null): V =
|
||||||
}
|
this[key] ?: default ?: error("Parameter not found: $key")
|
||||||
|
@ -13,8 +13,7 @@ import kotlin.test.assertEquals
|
|||||||
internal class AsmTest {
|
internal class AsmTest {
|
||||||
@Test
|
@Test
|
||||||
fun `compile MST`() {
|
fun `compile MST`() {
|
||||||
val mst = "2+2*(2+2)".parseMath()
|
val res = ComplexField.expression("2+2*(2+2)".parseMath())()
|
||||||
val res = ComplexField.expression(mst)()
|
|
||||||
assertEquals(Complex(10.0, 0.0), res)
|
assertEquals(Complex(10.0, 0.0), res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user