From 63c001648e41770879f7a12f94622134e20cf609 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 28 Jun 2020 02:08:26 +0700 Subject: [PATCH] Remove duplicated code, optimize constants field generation, add overloads for getOrFail in mapIntrinsics --- kmath-ast/README.md | 12 +- .../kmath/asm/internal/AsmBuilder.kt | 155 ++++++++++-------- .../kmath/asm/internal/codegenUtils.kt | 11 +- .../kmath/asm/internal/instructionAdapters.kt | 10 -- .../kmath/asm/internal/mapIntrinsics.kt | 6 +- .../kotlin/scietifik/kmath/ast/AsmTest.kt | 3 +- 6 files changed, 106 insertions(+), 91 deletions(-) delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt diff --git a/kmath-ast/README.md b/kmath-ast/README.md index ec80f8169..baeb60c39 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -59,15 +59,13 @@ import scientifik.kmath.operations.RealField; public final class AsmCompiledExpression_1073786867_0 implements Expression { private final RealField algebra; - private final Object[] constants; - - public AsmCompiledExpression_1073786867_0(RealField algebra, Object[] constants) { - this.algebra = algebra; - this.constants = constants; - } public final Double invoke(Map arguments) { - return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x", (Object)null)).doubleValue(), 2.0D); + return (Double)this.algebra.add(((Double)MapIntrinsics.getOrFail(arguments, "x")).doubleValue(), 2.0D); + } + + public AsmCompiledExpression_1073786867_0(RealField algebra) { + this.algebra = algebra; } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index cea6be933..92f808be8 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -38,17 +38,17 @@ internal class AsmBuilder internal constructor( private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) /** - * ASM Type for [algebra] + * ASM Type for [algebra]. */ private val tAlgebraType: Type = algebra::class.asm /** - * ASM type for [T] + * ASM type for [T]. */ internal val tType: Type = classOfT.asm /** - * ASM type for new class + * ASM type for new class. */ private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! @@ -72,6 +72,11 @@ internal class AsmBuilder internal constructor( */ private lateinit var invokeMethodVisitor: InstructionAdapter + /** + * State if this [AsmBuilder] needs to generate constants field. + */ + private var hasConstants = true + /** * State if [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. */ @@ -108,7 +113,7 @@ internal class AsmBuilder internal constructor( * The built instance is cached. */ @Suppress("UNCHECKED_CAST") - fun getInstance(): Expression { + internal fun getInstance(): Expression { generatedInstance?.let { return it } if (SIGNATURE_LETTERS.containsKey(classOfT)) { @@ -127,64 +132,6 @@ internal class AsmBuilder internal constructor( arrayOf(EXPRESSION_TYPE.internalName) ) - visitField( - access = ACC_PRIVATE or ACC_FINAL, - name = "algebra", - descriptor = tAlgebraType.descriptor, - signature = null, - value = null, - block = FieldVisitor::visitEnd - ) - - visitField( - access = ACC_PRIVATE or ACC_FINAL, - name = "constants", - descriptor = OBJECT_ARRAY_TYPE.descriptor, - signature = null, - value = null, - block = FieldVisitor::visitEnd - ) - - visitMethod( - ACC_PUBLIC, - "", - Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE), - null, - null - ).instructionAdapter { - val thisVar = 0 - val algebraVar = 1 - val constantsVar = 2 - val l0 = label() - load(thisVar, classType) - invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) - label() - load(thisVar, classType) - load(algebraVar, tAlgebraType) - putfield(classType.internalName, "algebra", tAlgebraType.descriptor) - label() - load(thisVar, classType) - load(constantsVar, OBJECT_ARRAY_TYPE) - putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) - label() - visitInsn(RETURN) - val l4 = label() - visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) - - visitLocalVariable( - "algebra", - tAlgebraType.descriptor, - null, - l0, - l4, - algebraVar - ) - - visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) - visitMaxs(0, 3) - visitEnd() - } - visitMethod( ACC_PUBLIC or ACC_FINAL, "invoke", @@ -251,6 +198,78 @@ internal class AsmBuilder internal constructor( visitEnd() } + hasConstants = constants.isNotEmpty() + + visitField( + access = ACC_PRIVATE or ACC_FINAL, + name = "algebra", + descriptor = tAlgebraType.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd + ) + + if (hasConstants) + visitField( + access = ACC_PRIVATE or ACC_FINAL, + name = "constants", + descriptor = OBJECT_ARRAY_TYPE.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd + ) + + visitMethod( + ACC_PUBLIC, + "", + + Type.getMethodDescriptor( + Type.VOID_TYPE, + tAlgebraType, + *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), + + null, + null + ).instructionAdapter { + val thisVar = 0 + val algebraVar = 1 + val constantsVar = 2 + val l0 = label() + load(thisVar, classType) + invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) + label() + load(thisVar, classType) + load(algebraVar, tAlgebraType) + putfield(classType.internalName, "algebra", tAlgebraType.descriptor) + + if (hasConstants) { + label() + load(thisVar, classType) + load(constantsVar, OBJECT_ARRAY_TYPE) + putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + } + + label() + visitInsn(RETURN) + val l4 = label() + visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) + + visitLocalVariable( + "algebra", + tAlgebraType.descriptor, + null, + l0, + l4, + algebraVar + ) + + if (hasConstants) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) + + visitMaxs(0, 3) + visitEnd() + } + visitEnd() } @@ -258,7 +277,7 @@ internal class AsmBuilder internal constructor( .defineClass(className, classWriter.toByteArray()) .constructors .first() - .newInstance(algebra, constants.toTypedArray()) as Expression + .newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression generatedInstance = new return new @@ -359,18 +378,20 @@ internal class AsmBuilder internal constructor( if (defaultValue != null) loadTConstant(defaultValue) - else - aconst(null) invokestatic( MAP_INTRINSICS_TYPE.internalName, "getOrFail", - Type.getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, OBJECT_TYPE, OBJECT_TYPE), + + Type.getMethodDescriptor( + OBJECT_TYPE, + MAP_TYPE, + OBJECT_TYPE, + *OBJECT_TYPE.wrapToArrayIf { defaultValue != null }), false ) checkcast(tType) - val expectedType = expectationStack.pop() if (expectedType.sort == Type.OBJECT) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt index 46d07976d..63681ad96 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt @@ -22,6 +22,14 @@ private val methodNameAdapters: Map, String> by lazy { internal val KClass<*>.asm: Type get() = Type.getType(java) +/** + * Returns singleton array with this value if the [predicate] is true, returns empty array otherwise. + */ +internal inline fun T.wrapToArrayIf(predicate: (T) -> Boolean): Array = if (predicate(this)) + arrayOf(this) +else + emptyArray() + /** * Creates an [InstructionAdapter] from this [MethodVisitor]. */ @@ -121,7 +129,7 @@ private fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: Strin /** * Builds specialized algebra call with option to fallback to generic algebra operation accepting String. */ -internal fun AsmBuilder.buildAlgebraOperationCall( +internal inline fun AsmBuilder.buildAlgebraOperationCall( context: Algebra, name: String, fallbackMethodName: String, @@ -145,4 +153,3 @@ internal fun AsmBuilder.buildAlgebraOperationCall( expectedArity = arity ) } - diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt deleted file mode 100644 index f47293687..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt +++ /dev/null @@ -1,10 +0,0 @@ -package scientifik.kmath.asm.internal - -import org.objectweb.asm.Label -import org.objectweb.asm.commons.InstructionAdapter - -internal fun InstructionAdapter.label(): Label { - val l = Label() - visitLabel(l) - return l -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt index 7f7126b55..80e83c1bf 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/mapIntrinsics.kt @@ -2,6 +2,6 @@ package scientifik.kmath.asm.internal -internal fun Map.getOrFail(key: K, default: V?): V { - return this[key] ?: default ?: error("Parameter not found: $key") -} +@JvmOverloads +internal fun Map.getOrFail(key: K, default: V? = null): V = + this[key] ?: default ?: error("Parameter not found: $key") diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt index 23203172e..75659cc35 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/ast/AsmTest.kt @@ -13,8 +13,7 @@ import kotlin.test.assertEquals internal class AsmTest { @Test fun `compile MST`() { - val mst = "2+2*(2+2)".parseMath() - val res = ComplexField.expression(mst)() + val res = ComplexField.expression("2+2*(2+2)".parseMath())() assertEquals(Complex(10.0, 0.0), res) }