diff --git a/benchmarks/build.gradle.kts b/benchmarks/build.gradle.kts index cca3d312d..130cc84a3 100644 --- a/benchmarks/build.gradle.kts +++ b/benchmarks/build.gradle.kts @@ -72,9 +72,9 @@ benchmark { } fun kotlinx.benchmark.gradle.BenchmarkConfiguration.commonConfiguration() { - warmups = 1 + warmups = 2 iterations = 5 - iterationTime = 1000 + iterationTime = 2000 iterationTimeUnit = "ms" } @@ -143,7 +143,7 @@ kotlin.sourceSets.all { tasks.withType { kotlinOptions { jvmTarget = "11" - freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy" } } diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt index 63e1511bd..b824a0d69 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt @@ -62,9 +62,11 @@ internal class ExpressionsInterpretersBenchmark { private fun invokeAndSum(expr: Expression, blackhole: Blackhole) { val random = Random(0) var sum = 0.0 + val m = HashMap() repeat(times) { - sum += expr(x to random.nextDouble()) + m[x] = random.nextDouble() + sum += expr(m) } blackhole.consume(sum) diff --git a/buildSrc/src/main/kotlin/space/kscience/kmath/benchmarks/addBenchmarkProperties.kt b/buildSrc/src/main/kotlin/space/kscience/kmath/benchmarks/addBenchmarkProperties.kt index 72c9ff0ad..568bfe535 100644 --- a/buildSrc/src/main/kotlin/space/kscience/kmath/benchmarks/addBenchmarkProperties.kt +++ b/buildSrc/src/main/kotlin/space/kscience/kmath/benchmarks/addBenchmarkProperties.kt @@ -54,7 +54,7 @@ fun Project.addBenchmarkProperties() { LocalDateTime.parse(it.name, ISO_DATE_TIME).atZone(ZoneId.systemDefault()).toInstant() } - if (resDirectory == null) { + if (resDirectory == null || !(resDirectory.resolve("jvm.json")).exists()) { "> **Can't find appropriate benchmark data. Try generating readme files after running benchmarks**." } else { val reports = diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index 7b1bce26a..05e562143 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -64,9 +64,9 @@ kotlin.sourceSets.all { } tasks.withType { - kotlinOptions{ + kotlinOptions { jvmTarget = "11" - freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy" } } diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt index 887d76c42..b37d120fb 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt @@ -5,18 +5,22 @@ package space.kscience.kmath.ast +import space.kscience.kmath.asm.compileToExpression import space.kscience.kmath.expressions.MstField +import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol.Companion.x -import space.kscience.kmath.expressions.interpret import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke fun main() { val expr = MstField { x * 2.0 + number(2.0) / x - 16.0 - } + }.compileToExpression(DoubleField) + + val m = HashMap() repeat(10000000) { - expr.interpret(DoubleField, x to 1.0) + m[x] = 1.0 + expr(m) } -} \ No newline at end of file +} diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt index 2426d6ee4..9a7c4b023 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt @@ -26,7 +26,19 @@ import space.kscience.kmath.operations.bindSymbolOrNull */ @PublishedApi internal fun MST.compileWith(type: Class, algebra: Algebra): Expression { - fun AsmBuilder.visit(node: MST): Unit = when (node) { + fun AsmBuilder.variablesVisitor(node: MST): Unit = when (node) { + is Symbol -> prepareVariable(node.identity) + is Unary -> variablesVisitor(node.value) + + is Binary -> { + variablesVisitor(node.left) + variablesVisitor(node.right) + } + + else -> Unit + } + + fun AsmBuilder.expressionVisitor(node: MST): Unit = when (node) { is Symbol -> { val symbol = algebra.bindSymbolOrNull(node) @@ -40,39 +52,47 @@ internal fun MST.compileWith(type: Class, algebra: Algebra): Exp is Unary -> when { algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant( - algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value))) + algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)), + ) - else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) } + else -> buildCall(algebra.unaryOperationFunction(node.operation)) { expressionVisitor(node.value) } } is Binary -> when { algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant( algebra.binaryOperationFunction(node.operation).invoke( algebra.number((node.left as Numeric).value), - algebra.number((node.right as Numeric).value) + algebra.number((node.right as Numeric).value), ) ) algebra is NumericAlgebra && node.left is Numeric -> buildCall( - algebra.leftSideNumberOperationFunction(node.operation)) { - visit(node.left) - visit(node.right) + algebra.leftSideNumberOperationFunction(node.operation), + ) { + expressionVisitor(node.left) + expressionVisitor(node.right) } algebra is NumericAlgebra && node.right is Numeric -> buildCall( - algebra.rightSideNumberOperationFunction(node.operation)) { - visit(node.left) - visit(node.right) + algebra.rightSideNumberOperationFunction(node.operation), + ) { + expressionVisitor(node.left) + expressionVisitor(node.right) } else -> buildCall(algebra.binaryOperationFunction(node.operation)) { - visit(node.left) - visit(node.right) + expressionVisitor(node.left) + expressionVisitor(node.right) } } } - return AsmBuilder(type, buildName(this)) { visit(this@compileWith) }.instance + return AsmBuilder( + type, + buildName(this), + { variablesVisitor(this@compileWith) }, + { expressionVisitor(this@compileWith) }, + ).instance } diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt index 418d6141b..4e29c77e0 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt @@ -10,8 +10,7 @@ import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Type.* import org.objectweb.asm.commons.InstructionAdapter import space.kscience.kmath.asm.internal.AsmBuilder.ClassLoader -import space.kscience.kmath.expressions.Expression -import space.kscience.kmath.expressions.MST +import space.kscience.kmath.expressions.* import java.lang.invoke.MethodHandles import java.lang.invoke.MethodType import java.nio.file.Paths @@ -26,13 +25,14 @@ import kotlin.io.path.writeBytes * * @property T the type of AsmExpression to unwrap. * @property className the unique class name of new loaded class. - * @property callbackAtInvokeL0 the function to apply to this object when generating invoke method, label 0. + * @property expressionResultCallback the function to apply to this object when generating expression value. * @author Iaroslav Postovalov */ internal class AsmBuilder( classOfT: Class<*>, private val className: String, - private val callbackAtInvokeL0: AsmBuilder.() -> Unit, + private val variablesPrepareCallback: AsmBuilder.() -> Unit, + private val expressionResultCallback: AsmBuilder.() -> Unit, ) { /** * Internal classloader of [AsmBuilder] with alias to define class from byte array. @@ -66,12 +66,17 @@ internal class AsmBuilder( */ private lateinit var invokeMethodVisitor: InstructionAdapter + /** + * Local variables indices are indices of symbols in this list. + */ + private val argumentsLocals = mutableListOf() + /** * Subclasses, loads and instantiates [Expression] for given parameters. * * The built instance is cached. */ - @Suppress("UNCHECKED_CAST") + @Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE") val instance: Expression by lazy { val hasConstants: Boolean @@ -94,26 +99,28 @@ internal class AsmBuilder( ).instructionAdapter { invokeMethodVisitor = this visitCode() - val l0 = label() - callbackAtInvokeL0() + val preparingVariables = label() + variablesPrepareCallback() + val expressionResult = label() + expressionResultCallback() areturn(tType) - val l1 = label() + val end = label() visitLocalVariable( "this", classType.descriptor, null, - l0, - l1, + preparingVariables, + end, 0, ) visitLocalVariable( "arguments", MAP_TYPE.descriptor, - "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", - l0, - l1, + "L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;", + preparingVariables, + end, 1, ) @@ -199,7 +206,7 @@ internal class AsmBuilder( val binary = classWriter.toByteArray() val cls = classLoader.defineClass(className, binary) - if (System.getProperty("space.kscience.communicator.prettyapi.dump.generated.classes") == "1") + if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1") Paths.get("$className.class").writeBytes(binary) val l = MethodHandles.publicLookup() @@ -256,9 +263,11 @@ internal class AsmBuilder( } /** - * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. + * Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using + * [loadVariable]. */ - fun loadVariable(name: String): Unit = invokeMethodVisitor.run { + fun prepareVariable(name: String): Unit = invokeMethodVisitor.run { + if (name in argumentsLocals) return@run load(1, MAP_TYPE) aconst(name) @@ -270,8 +279,22 @@ internal class AsmBuilder( ) checkcast(tType) + var idx = argumentsLocals.indexOf(name) + + if (idx == -1) { + argumentsLocals += name + idx = argumentsLocals.lastIndex + } + + store(2 + idx, tType) } + /** + * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored + * with [prepareVariable] first. + */ + fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType) + inline fun buildCall(function: Function, parameters: AsmBuilder.() -> Unit) { contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces } diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index e4436c1df..143e576f7 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -5,10 +5,21 @@ plugins { // id("com.xcporter.metaview") version "0.0.5" } -kotlin.sourceSets { - commonMain { - dependencies { - api(project(":kmath-memory")) +kotlin { + jvm { + compilations.all { + kotlinOptions { + freeCompilerArgs = + freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy" + } + } + } + + sourceSets { + commonMain { + dependencies { + api(project(":kmath-memory")) + } } } } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt index edd020c9a..bd3e6d0ba 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt @@ -29,7 +29,7 @@ public fun interface Expression { * * @return a value. */ -public operator fun Expression.invoke(): T = invoke(emptyMap()) +public operator fun Expression.invoke(): T = this(emptyMap()) /** * Calls this expression from arguments. @@ -38,7 +38,13 @@ public operator fun Expression.invoke(): T = invoke(emptyMap()) * @return a value. */ @JvmName("callBySymbol") -public operator fun Expression.invoke(vararg pairs: Pair): T = invoke(mapOf(*pairs)) +public operator fun Expression.invoke(vararg pairs: Pair): T = this( + when (pairs.size) { + 0 -> emptyMap() + 1 -> mapOf(pairs[0]) + else -> hashMapOf(*pairs) + } +) /** * Calls this expression from arguments. @@ -47,8 +53,21 @@ public operator fun Expression.invoke(vararg pairs: Pair): T = * @return a value. */ @JvmName("callByString") -public operator fun Expression.invoke(vararg pairs: Pair): T = - invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) }) +public operator fun Expression.invoke(vararg pairs: Pair): T = this( + when (pairs.size) { + 0 -> emptyMap() + + 1 -> { + val (k, v) = pairs[0] + mapOf(StringSymbol(k) to v) + } + + else -> hashMapOf(*Array>(pairs.size) { + val (k, v) = pairs[it] + StringSymbol(k) to v + }) + } +) /**