From e62cf4fc6586643f6a5adddd2e1f25deb1b756cc Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Tue, 8 Dec 2020 14:42:42 +0700 Subject: [PATCH] Rewrite ASM codegen to use curried operators, fix bugs, update benchmarks --- .../ast/ExpressionsInterpretersBenchmark.kt | 49 +-- .../kmath/structures/ArrayBenchmark.kt | 8 +- .../kmath/structures/ViktorBenchmark.kt | 2 +- .../kotlin/kscience/kmath/ast/MST.kt | 3 +- .../kotlin/kscience/kmath/ast/MstAlgebra.kt | 15 +- .../kscience/kmath/ast/MstExpression.kt | 7 +- .../jvmMain/kotlin/kscience/kmath/asm/asm.kt | 49 +-- .../kscience/kmath/asm/internal/AsmBuilder.kt | 300 ++++-------------- .../kscience/kmath/asm/internal/MstType.kt | 20 -- .../kmath/asm/internal/codegenUtils.kt | 120 ------- .../kscience/kmath/asm/TestAsmAlgebras.kt | 36 +-- .../kscience/kmath/asm/TestAsmExpressions.kt | 12 +- .../kmath/asm/TestAsmSpecialization.kt | 2 +- .../kscience/kmath/asm/TestAsmVariables.kt | 2 +- .../kscience/kmath/operations/Algebra.kt | 10 +- .../kmath/operations/NumberAlgebra.kt | 2 +- 16 files changed, 158 insertions(+), 479 deletions(-) rename examples/src/{main => benchmarks}/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt (66%) delete mode 100644 kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/MstType.kt diff --git a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt similarity index 66% rename from examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt rename to examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt index a4806ed68..6acaca84d 100644 --- a/examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -4,13 +4,19 @@ import kscience.kmath.asm.compile import kscience.kmath.expressions.Expression import kscience.kmath.expressions.expressionInField import kscience.kmath.expressions.invoke +import kscience.kmath.expressions.symbol import kscience.kmath.operations.Field import kscience.kmath.operations.RealField +import org.openjdk.jmh.annotations.Benchmark +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State import kotlin.random.Random -import kotlin.system.measureTimeMillis +@State(Scope.Benchmark) internal class ExpressionsInterpretersBenchmark { private val algebra: Field = RealField + + @Benchmark fun functionalExpression() { val expr = algebra.expressionInField { symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0) @@ -19,6 +25,7 @@ internal class ExpressionsInterpretersBenchmark { invokeAndSum(expr) } + @Benchmark fun mstExpression() { val expr = algebra.mstInField { symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) @@ -27,6 +34,7 @@ internal class ExpressionsInterpretersBenchmark { invokeAndSum(expr) } + @Benchmark fun asmExpression() { val expr = algebra.mstInField { symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) @@ -35,6 +43,13 @@ internal class ExpressionsInterpretersBenchmark { invokeAndSum(expr) } + @Benchmark + fun rawExpression() { + val x by symbol + val expr = Expression { args -> args.getValue(x) * 2.0 + 2.0 / args.getValue(x) - 16.0 } + invokeAndSum(expr) + } + private fun invokeAndSum(expr: Expression) { val random = Random(0) var sum = 0.0 @@ -46,35 +61,3 @@ internal class ExpressionsInterpretersBenchmark { println(sum) } } - -/** - * This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and - * core FunctionalExpressions API. - * - * The expected rating is: - * - * 1. ASM. - * 2. MST. - * 3. FE. - */ -fun main() { - val benchmark = ExpressionsInterpretersBenchmark() - - val fe = measureTimeMillis { - benchmark.functionalExpression() - } - - println("fe=$fe") - - val mst = measureTimeMillis { - benchmark.mstExpression() - } - - println("mst=$mst") - - val asm = measureTimeMillis { - benchmark.asmExpression() - } - - println("asm=$asm") -} diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/structures/ArrayBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/structures/ArrayBenchmark.kt index a91d02253..f7e0d05c6 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/structures/ArrayBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/structures/ArrayBenchmark.kt @@ -16,17 +16,13 @@ internal class ArrayBenchmark { @Benchmark fun benchmarkBufferRead() { var res = 0 - for (i in 1..size) res += arrayBuffer.get( - size - i - ) + for (i in 1..size) res += arrayBuffer[size - i] } @Benchmark fun nativeBufferRead() { var res = 0 - for (i in 1..size) res += nativeBuffer.get( - size - i - ) + for (i in 1..size) res += nativeBuffer[size - i] } companion object { diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/structures/ViktorBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/structures/ViktorBenchmark.kt index 464925ca0..df35b0b78 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/structures/ViktorBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/structures/ViktorBenchmark.kt @@ -36,7 +36,7 @@ internal class ViktorBenchmark { @Benchmark fun rawViktor() { - val one = F64Array.full(init = 1.0, shape = *intArrayOf(dim, dim)) + val one = F64Array.full(init = 1.0, shape = intArrayOf(dim, dim)) var res = one repeat(n) { res = res + one } } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt index cc3d38f9c..4ff4d7354 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt @@ -64,7 +64,8 @@ public fun Algebra.evaluate(node: MST): T = when (node) { node.left is MST.Numeric && node.right is MST.Numeric -> { val number = RealField - .binaryOperation(node.operation)(node.left.value.toDouble(), node.right.value.toDouble()) + .binaryOperation(node.operation) + .invoke(node.left.value.toDouble(), node.right.value.toDouble()) number(number) } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt index 44a7f20b9..96703ffb8 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt @@ -24,13 +24,17 @@ public object MstSpace : Space, NumericAlgebra { public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) - override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b) - override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k)) + public override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION)(a, b) + public override fun MST.unaryMinus(): MST = unaryOperation(SpaceOperations.MINUS_OPERATION)(this) - override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = + public override fun multiply(a: MST, k: Number): MST.Binary = + binaryOperation(RingOperations.TIMES_OPERATION)(a, number(k)) + + public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstAlgebra.binaryOperation(operation) - override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = MstAlgebra.unaryOperation(operation) + public override fun unaryOperation(operation: String): (arg: MST) -> MST.Unary = + MstAlgebra.unaryOperation(operation) } /** @@ -47,6 +51,7 @@ public object MstRing : Ring, NumericAlgebra { public override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) public override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) public override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION)(a, b) + public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this) public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstSpace.binaryOperation(operation) @@ -71,6 +76,7 @@ public object MstField : Field { public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION)(a, b) + public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this) public override fun binaryOperation(operation: String): (left: MST, right: MST) -> MST.Binary = MstRing.binaryOperation(operation) @@ -105,6 +111,7 @@ public object MstExtendedField : ExtendedField { public override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k) public override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) public override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) + public override fun MST.unaryMinus(): MST = MstSpace.unaryOperation(SpaceOperations.MINUS_OPERATION)(this) public override fun power(arg: MST, pow: Number): MST.Binary = binaryOperation(PowerOperations.POW_OPERATION)(arg, number(pow)) diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt index 8a99fd2f4..4d1b65621 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstExpression.kt @@ -15,7 +15,12 @@ import kotlin.contracts.contract */ public class MstExpression>(public val algebra: A, public val mst: MST) : Expression { private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { - override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value) + override fun symbol(value: String): T = try { + algebra.symbol(value) + } catch (ignored: IllegalStateException) { + null + } ?: arguments.getValue(StringSymbol(value)) + override fun unaryOperation(operation: String): (arg: T) -> T = algebra.unaryOperation(operation) override fun binaryOperation(operation: String): (left: T, right: T) -> T = algebra.binaryOperation(operation) diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt index 9ccfa464c..932813182 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt @@ -1,13 +1,13 @@ package kscience.kmath.asm import kscience.kmath.asm.internal.AsmBuilder -import kscience.kmath.asm.internal.MstType -import kscience.kmath.asm.internal.buildAlgebraOperationCall import kscience.kmath.asm.internal.buildName import kscience.kmath.ast.MST import kscience.kmath.ast.MstExpression import kscience.kmath.expressions.Expression import kscience.kmath.operations.Algebra +import kscience.kmath.operations.NumericAlgebra +import kscience.kmath.operations.RealField /** * Compiles given MST to an Expression using AST compiler. @@ -23,37 +23,46 @@ internal fun MST.compileWith(type: Class, algebra: Algebra): Exp is MST.Symbolic -> { val symbol = try { algebra.symbol(node.value) - } catch (ignored: Throwable) { + } catch (ignored: IllegalStateException) { null } if (symbol != null) - loadTConstant(symbol) + loadObjectConstant(symbol as Any) else loadVariable(node.value) } - is MST.Numeric -> loadNumeric(node.value) + is MST.Numeric -> loadNumberConstant(node.value) + is MST.Unary -> buildCall(algebra.unaryOperation(node.operation)) { visit(node.value) } - is MST.Unary -> buildAlgebraOperationCall( - context = algebra, - name = node.operation, - fallbackMethodName = "unaryOperation", - parameterTypes = arrayOf(MstType.fromMst(node.value)) - ) { visit(node.value) } + is MST.Binary -> when { + algebra is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant( + algebra.number( + RealField + .binaryOperation(node.operation) + .invoke(node.left.value.toDouble(), node.right.value.toDouble()) + ) + ) - is MST.Binary -> buildAlgebraOperationCall( - context = algebra, - name = node.operation, - fallbackMethodName = "binaryOperation", - parameterTypes = arrayOf(MstType.fromMst(node.left), MstType.fromMst(node.right)) - ) { - visit(node.left) - visit(node.right) + algebra is NumericAlgebra && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperation(node.operation)) { + visit(node.left) + visit(node.right) + } + + algebra is NumericAlgebra && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperation(node.operation)) { + visit(node.left) + visit(node.right) + } + + else -> buildCall(algebra.binaryOperation(node.operation)) { + visit(node.left) + visit(node.right) + } } } - return AsmBuilder(type, algebra, buildName(this)) { visit(this@compileWith) }.getInstance() + return AsmBuilder(type, buildName(this)) { visit(this@compileWith) }.instance } /** diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt index a1e482103..792034232 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt @@ -4,26 +4,24 @@ import kscience.kmath.asm.internal.AsmBuilder.ClassLoader import kscience.kmath.ast.MST import kscience.kmath.expressions.Expression import kscience.kmath.operations.Algebra -import kscience.kmath.operations.NumericAlgebra import org.objectweb.asm.* import org.objectweb.asm.Opcodes.* import org.objectweb.asm.commons.InstructionAdapter -import java.util.* -import java.util.stream.Collectors +import java.util.stream.Collectors.toMap +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /** * ASM Builder is a structure that abstracts building a class designated to unwrap [MST] to plain Java expression. * This class uses [ClassLoader] for loading the generated class, then it is able to instantiate the new class. * * @property T the type of AsmExpression to unwrap. - * @property algebra the algebra the applied AsmExpressions use. * @property className the unique class name of new loaded class. * @property invokeLabel0Visitor the function to apply to this object when generating invoke method, label 0. * @author Iaroslav Postovalov */ internal class AsmBuilder internal constructor( - private val classOfT: Class<*>, - private val algebra: Algebra, + classOfT: Class<*>, private val className: String, private val invokeLabel0Visitor: AsmBuilder.() -> Unit, ) { @@ -39,15 +37,10 @@ internal class AsmBuilder internal constructor( */ private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) - /** - * ASM Type for [algebra]. - */ - private val tAlgebraType: Type = algebra.javaClass.asm - /** * ASM type for [T]. */ - internal val tType: Type = classOfT.asm + private val tType: Type = classOfT.asm /** * ASM type for new class. @@ -69,51 +62,13 @@ internal class AsmBuilder internal constructor( */ private var hasConstants: Boolean = true - /** - * States whether [T] a primitive type, so [AsmBuilder] may generate direct primitive calls. - */ - internal var primitiveMode: Boolean = false - - /** - * Primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. - */ - internal var primitiveMask: Type = OBJECT_TYPE - - /** - * Boxed primitive type to apply for specific primitive calls. Use [OBJECT_TYPE], if not in [primitiveMode]. - */ - internal var primitiveMaskBoxed: Type = OBJECT_TYPE - - /** - * Stack of useful objects types on stack to verify types. - */ - private val typeStack: ArrayDeque = ArrayDeque() - - /** - * Stack of useful objects types on stack expected by algebra calls. - */ - internal val expectationStack: ArrayDeque = ArrayDeque(1).also { it.push(tType) } - - /** - * The cache for instance built by this builder. - */ - private var generatedInstance: Expression? = null - /** * Subclasses, loads and instantiates [Expression] for given parameters. * * The built instance is cached. */ @Suppress("UNCHECKED_CAST") - internal fun getInstance(): Expression { - generatedInstance?.let { return it } - - if (SIGNATURE_LETTERS.containsKey(classOfT)) { - primitiveMode = true - primitiveMask = SIGNATURE_LETTERS.getValue(classOfT) - primitiveMaskBoxed = tType - } - + val instance: Expression by lazy { val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( V1_8, @@ -192,15 +147,6 @@ internal class AsmBuilder internal constructor( hasConstants = constants.isNotEmpty() - visitField( - access = ACC_PRIVATE or ACC_FINAL, - name = "algebra", - descriptor = tAlgebraType.descriptor, - signature = null, - value = null, - block = FieldVisitor::visitEnd - ) - if (hasConstants) visitField( access = ACC_PRIVATE or ACC_FINAL, @@ -214,25 +160,17 @@ internal class AsmBuilder internal constructor( visitMethod( ACC_PUBLIC, "", - - Type.getMethodDescriptor( - Type.VOID_TYPE, - tAlgebraType, - *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), - + Type.getMethodDescriptor(Type.VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), null, null ).instructionAdapter { val thisVar = 0 - val algebraVar = 1 - val constantsVar = 2 + val constantsVar = 1 val l0 = label() load(thisVar, classType) invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) label() load(thisVar, classType) - load(algebraVar, tAlgebraType) - putfield(classType.internalName, "algebra", tAlgebraType.descriptor) if (hasConstants) { label() @@ -246,15 +184,6 @@ internal class AsmBuilder internal constructor( val l4 = label() visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) - visitLocalVariable( - "algebra", - tAlgebraType.descriptor, - null, - l0, - l4, - algebraVar - ) - if (hasConstants) visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) @@ -265,33 +194,55 @@ internal class AsmBuilder internal constructor( visitEnd() } - val new = classLoader +// java.io.File("dump.class").writeBytes(classWriter.toByteArray()) + + classLoader .defineClass(className, classWriter.toByteArray()) .constructors .first() - .newInstance(algebra, *(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression - - generatedInstance = new - return new + .newInstance(*(constants.toTypedArray().wrapToArrayIf { hasConstants })) as Expression } /** - * Loads a [T] constant from [constants]. + * Loads [java.lang.Object] constant from constants. */ - internal fun loadTConstant(value: T) { - if (classOfT in INLINABLE_NUMBERS) { - val expectedType = expectationStack.pop() - val mustBeBoxed = expectedType.sort == Type.OBJECT - loadNumberConstant(value as Number, mustBeBoxed) + fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run { + val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex + loadThis() + getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + iconst(idx) + visitInsn(AALOAD) + checkcast(type) + } - if (mustBeBoxed) - invokeMethodVisitor.checkcast(tType) + /** + * Loads `this` variable. + */ + private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) - if (mustBeBoxed) typeStack.push(tType) else typeStack.push(primitiveMask) + /** + * Either loads a numeric constant [value] from the class's constants field or boxes a primitive + * constant from the constant pool. + */ + fun loadNumberConstant(value: Number) { + val boxed = value.javaClass.asm + val primitive = BOXED_TO_PRIMITIVES[boxed] + + if (primitive != null) { + when (primitive) { + Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) + Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) + Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) + Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) + } + + box(primitive) return } - loadObjectConstant(value as Any, tType) + loadObjectConstant(value, boxed) } /** @@ -309,77 +260,9 @@ internal class AsmBuilder internal constructor( } /** - * Unboxes the current boxed value and pushes it. + * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. */ - private fun unboxTo(primitive: Type) = invokeMethodVisitor.invokevirtual( - NUMBER_TYPE.internalName, - NUMBER_CONVERTER_METHODS.getValue(primitive), - Type.getMethodDescriptor(primitive), - false - ) - - /** - * Loads [java.lang.Object] constant from constants. - */ - private fun loadObjectConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { - val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex - loadThis() - getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) - iconst(idx) - visitInsn(AALOAD) - checkcast(type) - } - - internal fun loadNumeric(value: Number) { - if (expectationStack.peek() == NUMBER_TYPE) { - loadNumberConstant(value, true) - expectationStack.pop() - typeStack.push(NUMBER_TYPE) - } else (algebra as? NumericAlgebra)?.number(value)?.let { loadTConstant(it) } - ?: error("Cannot resolve numeric $value since target algebra is not numeric, and the current operation doesn't accept numbers.") - } - - /** - * Loads this variable. - */ - private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) - - /** - * Either loads a numeric constant [value] from the class's constants field or boxes a primitive - * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded - * from it). - */ - private fun loadNumberConstant(value: Number, mustBeBoxed: Boolean) { - val boxed = value.javaClass.asm - val primitive = BOXED_TO_PRIMITIVES[boxed] - - if (primitive != null) { - when (primitive) { - Type.BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - Type.DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) - Type.FLOAT_TYPE -> invokeMethodVisitor.fconst(value.toFloat()) - Type.LONG_TYPE -> invokeMethodVisitor.lconst(value.toLong()) - Type.INT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - Type.SHORT_TYPE -> invokeMethodVisitor.iconst(value.toInt()) - } - - if (mustBeBoxed) - box(primitive) - - return - } - - loadObjectConstant(value, boxed) - - if (!mustBeBoxed) - unboxTo(primitiveMask) - } - - /** - * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be - * provided. - */ - internal fun loadVariable(name: String): Unit = invokeMethodVisitor.run { + fun loadVariable(name: String): Unit = invokeMethodVisitor.run { load(invokeArgumentsVar, MAP_TYPE) aconst(name) @@ -391,70 +274,28 @@ internal class AsmBuilder internal constructor( ) checkcast(tType) - val expectedType = expectationStack.pop() - - if (expectedType.sort == Type.OBJECT) - typeStack.push(tType) - else { - unboxTo(primitiveMask) - typeStack.push(primitiveMask) - } } - /** - * Loads algebra from according field of the class and casts it to class of [algebra] provided. - */ - internal fun loadAlgebra() { - loadThis() - invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor) - } + inline fun buildCall(function: Function, parameters: AsmBuilder.() -> Unit) { + contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } + val `interface` = function.javaClass.interfaces.first { it.interfaces.contains(Function::class.java) } - /** - * Writes a method instruction of opcode with its [owner], [method] and its [descriptor]. The default opcode is - * [Opcodes.INVOKEINTERFACE], since most Algebra functions are declared in interfaces. [loadAlgebra] should be - * called before the arguments and this operation. - * - * The result is casted to [T] automatically. - */ - internal fun invokeAlgebraOperation( - owner: String, - method: String, - descriptor: String, - expectedArity: Int, - opcode: Int = INVOKEINTERFACE, - ) { - run loop@{ - repeat(expectedArity) { - if (typeStack.isEmpty()) return@loop - typeStack.pop() - } - } + val arity = `interface`.methods.find { it.name == "invoke" }?.parameterCount + ?: error("Provided function object doesn't contain invoke method") - invokeMethodVisitor.visitMethodInsn( - opcode, - owner, - method, - descriptor, - opcode == INVOKEINTERFACE + val type = Type.getType(`interface`) + loadObjectConstant(function, type) + parameters(this) + + invokeMethodVisitor.invokeinterface( + type.internalName, + "invoke", + Type.getMethodDescriptor(OBJECT_TYPE, *Array(arity) { OBJECT_TYPE}), ) invokeMethodVisitor.checkcast(tType) - val isLastExpr = expectationStack.size == 1 - val expectedType = expectationStack.pop() - - if (expectedType.sort == Type.OBJECT || isLastExpr) - typeStack.push(tType) - else { - unboxTo(primitiveMask) - typeStack.push(primitiveMask) - } } - /** - * Writes a LDC Instruction with string constant provided. - */ - internal fun loadStringConstant(string: String): Unit = invokeMethodVisitor.aconst(string) - internal companion object { /** * Index of `this` variable in invoke method of the built subclass. @@ -490,32 +331,13 @@ internal class AsmBuilder internal constructor( */ private val PRIMITIVES_TO_BOXED: Map by lazy { BOXED_TO_PRIMITIVES.entries.stream().collect( - Collectors.toMap( + toMap( Map.Entry::value, Map.Entry::key ) ) } - /** - * Maps primitive ASM types to [Number] functions unboxing them. - */ - private val NUMBER_CONVERTER_METHODS: Map by lazy { - hashMapOf( - Type.BYTE_TYPE to "byteValue", - Type.SHORT_TYPE to "shortValue", - Type.INT_TYPE to "intValue", - Type.LONG_TYPE to "longValue", - Type.FLOAT_TYPE to "floatValue", - Type.DOUBLE_TYPE to "doubleValue" - ) - } - - /** - * Provides boxed number types values of which can be stored in JVM bytecode constant pool. - */ - private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } - /** * ASM type for [Expression]. */ diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/MstType.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/MstType.kt deleted file mode 100644 index 526c27305..000000000 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/MstType.kt +++ /dev/null @@ -1,20 +0,0 @@ -package kscience.kmath.asm.internal - -import kscience.kmath.ast.MST - -/** - * Represents types known in [MST], numbers and general values. - */ -internal enum class MstType { - GENERAL, - NUMBER; - - companion object { - fun fromMst(mst: MST): MstType { - if (mst is MST.Numeric) - return NUMBER - - return GENERAL - } - } -} diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt index ef9751502..d301ddf15 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/codegenUtils.kt @@ -2,29 +2,11 @@ package kscience.kmath.asm.internal import kscience.kmath.ast.MST import kscience.kmath.expressions.Expression -import kscience.kmath.operations.Algebra -import kscience.kmath.operations.FieldOperations -import kscience.kmath.operations.RingOperations -import kscience.kmath.operations.SpaceOperations import org.objectweb.asm.* -import org.objectweb.asm.Opcodes.INVOKEVIRTUAL import org.objectweb.asm.commons.InstructionAdapter -import java.lang.reflect.Method -import java.util.* import kotlin.contracts.InvocationKind import kotlin.contracts.contract -private val methodNameAdapters: Map, String> by lazy { - hashMapOf( - SpaceOperations.PLUS_OPERATION to 2 to "add", - RingOperations.TIMES_OPERATION to 2 to "multiply", - FieldOperations.DIV_OPERATION to 2 to "divide", - SpaceOperations.PLUS_OPERATION to 1 to "unaryPlus", - SpaceOperations.MINUS_OPERATION to 1 to "unaryMinus", - SpaceOperations.MINUS_OPERATION to 2 to "minus" - ) -} - /** * Returns ASM [Type] for given [Class]. * @@ -110,106 +92,4 @@ internal inline fun ClassWriter.visitField( return visitField(access, name, descriptor, signature, value).apply(block) } -private fun AsmBuilder.findSpecific(context: Algebra, name: String, parameterTypes: Array): Method? = - context.javaClass.methods.find { method -> - val nameValid = method.name == name - val arityValid = method.parameters.size == parameterTypes.size - val notBridgeInPrimitive = !(primitiveMode && method.isBridge) - val paramsValid = method.parameterTypes.zip(parameterTypes).all { (type, mstType) -> - !(mstType != MstType.NUMBER && type == java.lang.Number::class.java) - } - - nameValid && arityValid && notBridgeInPrimitive && paramsValid - } - -/** - * Checks if the target [context] for code generation contains a method with needed [name] and arity, also builds - * type expectation stack for needed arity. - * - * @author Iaroslav Postovalov - */ -private fun AsmBuilder.buildExpectationStack( - context: Algebra, - name: String, - parameterTypes: Array -): Boolean { - val arity = parameterTypes.size - val specific = findSpecific(context, methodNameAdapters[name to arity] ?: name, parameterTypes) - - if (specific != null) - mapTypes(specific, parameterTypes).reversed().forEach { expectationStack.push(it) } - else - expectationStack.addAll(Collections.nCopies(arity, tType)) - - return specific != null -} - -private fun AsmBuilder.mapTypes(method: Method, parameterTypes: Array): List = method - .parameterTypes - .zip(parameterTypes) - .map { (type, mstType) -> - when { - type == java.lang.Number::class.java && mstType == MstType.NUMBER -> AsmBuilder.NUMBER_TYPE - else -> if (primitiveMode) primitiveMask else primitiveMaskBoxed - } - } - -/** - * Checks if the target [context] for code generation contains a method with needed [name] and arity and inserts - * [AsmBuilder.invokeAlgebraOperation] of this method. - * - * @author Iaroslav Postovalov - */ -private fun AsmBuilder.tryInvokeSpecific( - context: Algebra, - name: String, - parameterTypes: Array -): Boolean { - val arity = parameterTypes.size - val theName = methodNameAdapters[name to arity] ?: name - val spec = findSpecific(context, theName, parameterTypes) ?: return false - val owner = context.javaClass.asm - - invokeAlgebraOperation( - owner = owner.internalName, - method = theName, - descriptor = Type.getMethodDescriptor(primitiveMaskBoxed, *mapTypes(spec, parameterTypes).toTypedArray()), - expectedArity = arity, - opcode = INVOKEVIRTUAL - ) - - return true -} - -/** - * Builds specialized [context] call with option to fallback to generic algebra operation accepting [String]. - * - * @author Iaroslav Postovalov - */ -internal inline fun AsmBuilder.buildAlgebraOperationCall( - context: Algebra, - name: String, - fallbackMethodName: String, - parameterTypes: Array, - parameters: AsmBuilder.() -> Unit -) { - contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } - val arity = parameterTypes.size - loadAlgebra() - if (!buildExpectationStack(context, name, parameterTypes)) loadStringConstant(name) - parameters() - - if (!tryInvokeSpecific(context, name, parameterTypes)) invokeAlgebraOperation( - owner = AsmBuilder.ALGEBRA_TYPE.internalName, - method = fallbackMethodName, - - descriptor = Type.getMethodDescriptor( - AsmBuilder.OBJECT_TYPE, - AsmBuilder.STRING_TYPE, - *Array(arity) { AsmBuilder.OBJECT_TYPE } - ), - - expectedArity = arity - ) -} diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt index 5eebfe43d..a1687c1c7 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmAlgebras.kt @@ -10,15 +10,11 @@ import kotlin.test.Test import kotlin.test.assertEquals internal class TestAsmAlgebras { - @Test fun space() { val res1 = ByteRing.mstInSpace { - binaryOperation( - "+", - - unaryOperation( - "+", + binaryOperation("+")( + unaryOperation("+")( number(3.toByte()) - (number(2.toByte()) + (multiply( add(number(1), number(1)), 2 @@ -30,11 +26,8 @@ internal class TestAsmAlgebras { }("x" to 2.toByte()) val res2 = ByteRing.mstInSpace { - binaryOperation( - "+", - - unaryOperation( - "+", + binaryOperation("+")( + unaryOperation("+")( number(3.toByte()) - (number(2.toByte()) + (multiply( add(number(1), number(1)), 2 @@ -51,11 +44,8 @@ internal class TestAsmAlgebras { @Test fun ring() { val res1 = ByteRing.mstInRing { - binaryOperation( - "+", - - unaryOperation( - "+", + binaryOperation("+")( + unaryOperation("+")( (symbol("x") - (2.toByte() + (multiply( add(number(1), number(1)), 2 @@ -67,17 +57,13 @@ internal class TestAsmAlgebras { }("x" to 3.toByte()) val res2 = ByteRing.mstInRing { - binaryOperation( - "+", - - unaryOperation( - "+", + binaryOperation("+")( + unaryOperation("+")( (symbol("x") - (2.toByte() + (multiply( add(number(1), number(1)), 2 ) + 1.toByte()))) * 3.0 - 1.toByte() ), - number(1) ) * number(2) }.compile()("x" to 3.toByte()) @@ -88,8 +74,7 @@ internal class TestAsmAlgebras { @Test fun field() { val res1 = RealField.mstInField { - +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( - "+", + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation("+")( (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one @@ -97,8 +82,7 @@ internal class TestAsmAlgebras { }("x" to 2.0) val res2 = RealField.mstInField { - +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation( - "+", + +(3 - 2 + 2 * number(1) + 1.0) + binaryOperation("+")( (3.0 - (symbol("x") + (multiply(add(number(1.0), number(1.0)), 2) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmExpressions.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmExpressions.kt index 7f4d453f7..acd9e21ea 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmExpressions.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmExpressions.kt @@ -1,10 +1,11 @@ package kscience.kmath.asm -import kscience.kmath.asm.compile +import kscience.kmath.ast.mstInExtendedField import kscience.kmath.ast.mstInField import kscience.kmath.ast.mstInSpace import kscience.kmath.expressions.invoke import kscience.kmath.operations.RealField +import kotlin.random.Random import kotlin.test.Test import kotlin.test.assertEquals @@ -28,4 +29,13 @@ internal class TestAsmExpressions { val res = RealField.mstInField { symbol("x") * 2 }("x" to 2.0) assertEquals(4.0, res) } + + @Test + fun testMultipleCalls() { + val e = RealField.mstInExtendedField { sin(symbol("x")).pow(4) - 6 * symbol("x") / tanh(symbol("x")) }.compile() + val r = Random(0) + var s = 0.0 + repeat(1000000) { s += e("x" to r.nextDouble()) } + println(s) + } } diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt index eb5a1c8f3..ce9e602d3 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmSpecialization.kt @@ -46,7 +46,7 @@ internal class TestAsmSpecialization { @Test fun testPower() { val expr = RealField - .mstInField { binaryOperation("power")(symbol("x"), number(2)) } + .mstInField { binaryOperation("pow")(symbol("x"), number(2)) } .compile() assertEquals(4.0, expr("x" to 2.0)) diff --git a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt index 7b89c74fa..1011515c8 100644 --- a/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt +++ b/kmath-ast/src/jvmTest/kotlin/kscience/kmath/asm/TestAsmVariables.kt @@ -17,6 +17,6 @@ internal class TestAsmVariables { @Test fun testVariableWithoutDefaultFails() { val expr = ByteRing.mstInRing { symbol("x") } - assertFailsWith { expr() } + assertFailsWith { expr() } } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt index 19ca1a4cc..2661525fb 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt @@ -20,12 +20,14 @@ public interface Algebra { /** * Dynamically dispatches an unary operation with name [operation]. */ - public fun unaryOperation(operation: String): (arg: T) -> T + public fun unaryOperation(operation: String): (arg: T) -> T = + error("Unary operation $operation not defined in $this") /** * Dynamically dispatches a binary operation with name [operation]. */ - public fun binaryOperation(operation: String): (left: T, right: T) -> T + public fun binaryOperation(operation: String): (left: T, right: T) -> T = + error("Binary operation $operation not defined in $this") } /** @@ -161,13 +163,13 @@ public interface SpaceOperations : Algebra { override fun unaryOperation(operation: String): (arg: T) -> T = when (operation) { PLUS_OPERATION -> { arg -> arg } MINUS_OPERATION -> { arg -> -arg } - else -> error("Unary operation $operation not defined in $this") + else -> super.unaryOperation(operation) } override fun binaryOperation(operation: String): (left: T, right: T) -> T = when (operation) { PLUS_OPERATION -> ::add MINUS_OPERATION -> { left, right -> left - right } - else -> error("Binary operation $operation not defined in $this") + else -> super.binaryOperation(operation) } public companion object { diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt index 611222d34..cd2bd0417 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt @@ -31,7 +31,7 @@ public interface ExtendedFieldOperations : PowerOperations.SQRT_OPERATION -> ::sqrt ExponentialOperations.EXP_OPERATION -> ::exp ExponentialOperations.LN_OPERATION -> ::ln - else -> super.unaryOperation(operation) + else -> super.unaryOperation(operation) } }