From 2a34110f1d7ac8deab47f65818641537ca58eff9 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 28 Jun 2020 17:16:15 +0700 Subject: [PATCH] Implement advanced specialization for numeric functions --- .../kotlin/scientifik/kmath/asm/asm.kt | 15 +--- .../kmath/asm/internal/AsmBuilder.kt | 75 +++++++++++----- .../scientifik/kmath/asm/internal/MstType.kt | 17 ++++ .../kmath/asm/internal/codegenUtils.kt | 87 ++++++++++++------- .../kmath/asm/TestAsmSpecialization.kt | 9 ++ 5 files changed, 137 insertions(+), 66 deletions(-) create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MstType.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt index ef2330533..f0022a43a 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt @@ -1,6 +1,7 @@ package scientifik.kmath.asm import scientifik.kmath.asm.internal.AsmBuilder +import scientifik.kmath.asm.internal.MstType import scientifik.kmath.asm.internal.buildAlgebraOperationCall import scientifik.kmath.asm.internal.buildName import scientifik.kmath.ast.MST @@ -17,28 +18,20 @@ fun MST.compileWith(type: KClass, algebra: Algebra): Expression< fun AsmBuilder.visit(node: MST) { when (node) { is MST.Symbolic -> loadVariable(node.value) - - is MST.Numeric -> { - val constant = if (algebra is NumericAlgebra) - algebra.number(node.value) - else - error("Number literals are not supported in $algebra") - - loadTConstant(constant) - } + is MST.Numeric -> loadNumeric(node.value) is MST.Unary -> buildAlgebraOperationCall( context = algebra, name = node.operation, fallbackMethodName = "unaryOperation", - arity = 1 + parameterTypes = arrayOf(MstType.fromMst(node.value)) ) { visit(node.value) } is MST.Binary -> buildAlgebraOperationCall( context = algebra, name = node.operation, fallbackMethodName = "binaryOperation", - arity = 2 + parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right)) ) { visit(node.left) visit(node.right) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index d51d17992..e550bc563 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -7,7 +7,9 @@ import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.ast.MST import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.NumericAlgebra import java.util.* +import java.util.stream.Collectors import kotlin.reflect.KClass /** @@ -20,7 +22,7 @@ import kotlin.reflect.KClass * @param invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. */ internal class AsmBuilder internal constructor( - private val classOfT: KClass<*>, + internal val classOfT: KClass<*>, private val algebra: Algebra, private val className: String, private val invokeLabel0Visitor: AsmBuilder.() -> Unit @@ -100,7 +102,7 @@ internal class AsmBuilder internal constructor( /** * Stack of useful objects types on stack expected by algebra calls. */ - internal val expectationStack: ArrayDeque = ArrayDeque().apply { push(tType) } + internal val expectationStack: ArrayDeque = ArrayDeque(listOf(tType)) /** * The cache for instance built by this builder. @@ -286,42 +288,50 @@ internal class AsmBuilder internal constructor( /** * Loads a [T] constant from [constants]. */ - internal fun loadTConstant(value: T) { + private fun loadTConstant(value: T) { if (classOfT in INLINABLE_NUMBERS) { val expectedType = expectationStack.pop() val mustBeBoxed = expectedType.sort == Type.OBJECT loadNumberConstant(value as Number, mustBeBoxed) + + if (mustBeBoxed) + invokeMethodVisitor.checkcast(tType) + if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) return } - loadConstant(value as Any, tType) + loadObjectConstant(value as Any, tType) } /** * Boxes the current value and pushes it. */ - private fun box(): Unit = invokeMethodVisitor.invokestatic( - tType.internalName, - "valueOf", - Type.getMethodDescriptor(tType, primitiveMask), - false - ) + private fun box(primitive: Type) { + val r = PRIMITIVES_TO_BOXED.getValue(primitive) + + invokeMethodVisitor.invokestatic( + r.internalName, + "valueOf", + Type.getMethodDescriptor(r, primitive), + false + ) + } /** * Unboxes the current boxed value and pushes it. */ - private fun unbox(): Unit = invokeMethodVisitor.invokevirtual( + private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual( NUMBER_TYPE.internalName, - NUMBER_CONVERTER_METHODS.getValue(primitiveMask), - Type.getMethodDescriptor(primitiveMask), + NUMBER_CONVERTER_METHODS.getValue(primitive), + Type.getMethodDescriptor(primitive), false ) /** * Loads [java.lang.Object] constant from constants. */ - private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { + private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex loadThis() getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) @@ -330,6 +340,15 @@ internal class AsmBuilder internal constructor( checkcast(type) } + fun loadNumeric(value: Number) { + if (expectationStack.peek() == NUMBER_TYPE) { + loadNumberConstant(value, true) + expectationStack.pop() + typeStack.push(NUMBER_TYPE) + } else (algebra as? NumericAlgebra)?.number(value)?.let { loadTConstant(it) } + ?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.") + } + /** * Loads this variable. */ @@ -354,18 +373,16 @@ internal class AsmBuilder internal constructor( Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) } - if (mustBeBoxed) { - box() - invokeMethodVisitor.checkcast(tType) - } + if (mustBeBoxed) + box(primitive) return } - loadConstant(value, boxed) + loadObjectConstant(value, boxed) - if (!mustBeBoxed) unbox() - else invokeMethodVisitor.checkcast(tType) + if (!mustBeBoxed) + unboxTo(primitiveMask) } /** @@ -397,7 +414,7 @@ internal class AsmBuilder internal constructor( if (expectedType.sort == Type.OBJECT) typeStack.push(tType) else { - unbox() + unboxTo(primitiveMask) typeStack.push(primitiveMask) } } @@ -446,7 +463,7 @@ internal class AsmBuilder internal constructor( if (expectedType.sort == Type.OBJECT || isLastExpr) typeStack.push(tType) else { - unbox() + unboxTo(primitiveMask) typeStack.push(primitiveMask) } } @@ -476,6 +493,18 @@ internal class AsmBuilder internal constructor( */ private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } + /** + * Maps JVM primitive numbers boxed ASM types to their primitive ASM types. + */ + private val PRIMITIVES_TO_BOXED: Map by lazy { + BOXED_TO_PRIMITIVES.entries.stream().collect( + Collectors.toMap( + Map.Entry::value, + Map.Entry::key + ) + ) + } + /** * Maps primitive ASM types to [Number] functions unboxing them. */ diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MstType.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MstType.kt new file mode 100644 index 000000000..bf73d304b --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MstType.kt @@ -0,0 +1,17 @@ +package scientifik.kmath.asm.internal + +import scientifik.kmath.ast.MST + +internal enum class MstType { + GENERAL, + NUMBER; + + companion object { + fun fromMst(mst: MST): MstType { + if (mst is MST.Numeric) + return NUMBER + + return GENERAL + } + } +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt index 63681ad96..a637289b8 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/codegenUtils.kt @@ -6,6 +6,7 @@ import org.objectweb.asm.commons.InstructionAdapter import scientifik.kmath.ast.MST import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra +import java.lang.reflect.Method import kotlin.reflect.KClass private val methodNameAdapters: Map, String> by lazy { @@ -25,10 +26,8 @@ internal val KClass<*>.asm: Type /** * Returns singleton array with this value if the [predicate] is true, returns empty array otherwise. */ -internal inline fun T.wrapToArrayIf(predicate: (T) -> Boolean): Array = if (predicate(this)) - arrayOf(this) -else - emptyArray() +internal inline fun T.wrapToArrayIf(predicate: (T) -> Boolean): Array = + if (predicate(this)) arrayOf(this) else emptyArray() /** * Creates an [InstructionAdapter] from this [MethodVisitor]. @@ -44,11 +43,7 @@ internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Un /** * Constructs a [Label], then applies it to this visitor. */ -internal fun MethodVisitor.label(): Label { - val l = Label() - visitLabel(l) - return l -} +internal fun MethodVisitor.label(): Label = Label().also { visitLabel(it) } /** * Creates a class name for [Expression] subclassed to implement [mst] provided. @@ -81,44 +76,71 @@ internal inline fun ClassWriter.visitField( block: FieldVisitor.() -> Unit ): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) +private fun AsmBuilder.findSpecific(context: Algebra, name: String, parameterTypes: Array): Method? = + context.javaClass.methods.find { method -> + val nameValid = method.name == name + val arityValid = method.parameters.size == parameterTypes.size + val notBridgeInPrimitive = !(primitiveMode && method.isBridge) + + val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) -> + !(mstType != MstType.NUMBER && type == java.lang.Number::class.java) + } + + nameValid && arityValid && notBridgeInPrimitive && paramsValid + } + /** - * Checks if the target [context] for code generation contains a method with needed [name] and [arity], also builds + * Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds * type expectation stack for needed arity. * * @return `true` if contains, else `false`. */ -private fun AsmBuilder.buildExpectationStack(context: Algebra, name: String, arity: Int): Boolean { - val theName = methodNameAdapters[name to arity] ?: name - val hasSpecific = context.javaClass.methods.find { it.name == theName && it.parameters.size == arity } != null - val t = if (primitiveMode && hasSpecific) primitiveMask else tType - repeat(arity) { expectationStack.push(t) } - return hasSpecific +private fun AsmBuilder.buildExpectationStack( + context: Algebra, + name: String, + parameterTypes: Array +): Boolean { + val arity = parameterTypes.size + val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes) + + if (specific != null) + mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) } + else + repeat(arity) { expectationStack.push(tType) } + + return specific != null } +private fun AsmBuilder.mapTypes(method: Method, parameterTypes: Array): List = method + .parameterTypes + .zip(parameterTypes) + .map { (type, mstType) -> + when { + type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE + else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed + } + } + /** - * Checks if the target [context] for code generation contains a method with needed [name] and [arity] and inserts + * Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts * [AsmBuilder.invokeAlgebraOperation] of this method. * * @return `true` if contains, else `false`. */ -private fun AsmBuilder.tryInvokeSpecific(context: Algebra, name: String, arity: Int): Boolean { +private fun AsmBuilder.tryInvokeSpecific( + context: Algebra, + name: String, + parameterTypes: Array +): Boolean { + val arity = parameterTypes.size val theName = methodNameAdapters[name to arity] ?: name - - context.javaClass.methods.find { - var suitableSignature = it.name == theName && it.parameters.size == arity - - if (primitiveMode && it.isBridge) - suitableSignature = false - - suitableSignature - } ?: return false - + val spec = findSpecific(context, theName, parameterTypes) ?: return false val owner = context::class.asm invokeAlgebraOperation( owner = owner.internalName, method = theName, - descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *Array(arity) { primitiveMask }), + descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()), expectedArity = arity, opcode = INVOKEVIRTUAL ) @@ -133,14 +155,15 @@ internal inline fun AsmBuilder.buildAlgebraOperationCall( context: Algebra, name: String, fallbackMethodName: String, - arity: Int, + parameterTypes: Array, parameters: AsmBuilder.() -> Unit ) { + val arity = parameterTypes.size loadAlgebra() - if (!buildExpectationStack(context, name, arity)) loadStringConstant(name) + if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name) parameters() - if (!tryInvokeSpecific(context, name, arity)) invokeAlgebraOperation( + if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation( owner = AsmBuilder.ALGEBRA_TYPE.internalName, method = fallbackMethodName, diff --git a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt index b571e076f..a88431e9d 100644 --- a/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt +++ b/kmath-ast/src/jvmTest/kotlin/scietifik/kmath/asm/TestAsmSpecialization.kt @@ -43,4 +43,13 @@ internal class TestAsmSpecialization { val expr = RealField.mstInField { binaryOperation("/", symbol("x"), symbol("x")) }.compile() assertEquals(1.0, expr("x" to 2.0)) } + + @Test + fun testPower() { + val expr = RealField + .mstInField { binaryOperation("power", symbol("x"), number(2)) } + .compile() + + assertEquals(4.0, expr("x" to 2.0)) + } }