forked from kscience/kmath
Multiple performance improvements related to ASM
1. Argument values are cached in locals 2. Optimized Expression.invoke function 3. lambda=indy is used in kmath-core
This commit is contained in:
parent
0e1e97a3ff
commit
f231d722c6
@ -72,9 +72,9 @@ benchmark {
|
||||
}
|
||||
|
||||
fun kotlinx.benchmark.gradle.BenchmarkConfiguration.commonConfiguration() {
|
||||
warmups = 1
|
||||
warmups = 2
|
||||
iterations = 5
|
||||
iterationTime = 1000
|
||||
iterationTime = 2000
|
||||
iterationTimeUnit = "ms"
|
||||
}
|
||||
|
||||
@ -143,7 +143,7 @@ kotlin.sourceSets.all {
|
||||
tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> {
|
||||
kotlinOptions {
|
||||
jvmTarget = "11"
|
||||
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all"
|
||||
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy"
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -62,9 +62,11 @@ internal class ExpressionsInterpretersBenchmark {
|
||||
private fun invokeAndSum(expr: Expression<Double>, blackhole: Blackhole) {
|
||||
val random = Random(0)
|
||||
var sum = 0.0
|
||||
val m = HashMap<Symbol, Double>()
|
||||
|
||||
repeat(times) {
|
||||
sum += expr(x to random.nextDouble())
|
||||
m[x] = random.nextDouble()
|
||||
sum += expr(m)
|
||||
}
|
||||
|
||||
blackhole.consume(sum)
|
||||
|
@ -54,7 +54,7 @@ fun Project.addBenchmarkProperties() {
|
||||
LocalDateTime.parse(it.name, ISO_DATE_TIME).atZone(ZoneId.systemDefault()).toInstant()
|
||||
}
|
||||
|
||||
if (resDirectory == null) {
|
||||
if (resDirectory == null || !(resDirectory.resolve("jvm.json")).exists()) {
|
||||
"> **Can't find appropriate benchmark data. Try generating readme files after running benchmarks**."
|
||||
} else {
|
||||
val reports =
|
||||
|
@ -64,9 +64,9 @@ kotlin.sourceSets.all {
|
||||
}
|
||||
|
||||
tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> {
|
||||
kotlinOptions{
|
||||
kotlinOptions {
|
||||
jvmTarget = "11"
|
||||
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn"
|
||||
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy"
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,18 +5,22 @@
|
||||
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.asm.compileToExpression
|
||||
import space.kscience.kmath.expressions.MstField
|
||||
import space.kscience.kmath.expressions.Symbol
|
||||
import space.kscience.kmath.expressions.Symbol.Companion.x
|
||||
import space.kscience.kmath.expressions.interpret
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
|
||||
fun main() {
|
||||
val expr = MstField {
|
||||
x * 2.0 + number(2.0) / x - 16.0
|
||||
}
|
||||
}.compileToExpression(DoubleField)
|
||||
|
||||
val m = HashMap<Symbol, Double>()
|
||||
|
||||
repeat(10000000) {
|
||||
expr.interpret(DoubleField, x to 1.0)
|
||||
m[x] = 1.0
|
||||
expr(m)
|
||||
}
|
||||
}
|
@ -26,7 +26,19 @@ import space.kscience.kmath.operations.bindSymbolOrNull
|
||||
*/
|
||||
@PublishedApi
|
||||
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
||||
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
||||
fun AsmBuilder<T>.variablesVisitor(node: MST): Unit = when (node) {
|
||||
is Symbol -> prepareVariable(node.identity)
|
||||
is Unary -> variablesVisitor(node.value)
|
||||
|
||||
is Binary -> {
|
||||
variablesVisitor(node.left)
|
||||
variablesVisitor(node.right)
|
||||
}
|
||||
|
||||
else -> Unit
|
||||
}
|
||||
|
||||
fun AsmBuilder<T>.expressionVisitor(node: MST): Unit = when (node) {
|
||||
is Symbol -> {
|
||||
val symbol = algebra.bindSymbolOrNull(node)
|
||||
|
||||
@ -40,39 +52,47 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
|
||||
|
||||
is Unary -> when {
|
||||
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
|
||||
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)))
|
||||
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)),
|
||||
)
|
||||
|
||||
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
|
||||
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { expressionVisitor(node.value) }
|
||||
}
|
||||
|
||||
is Binary -> when {
|
||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
|
||||
algebra.binaryOperationFunction(node.operation).invoke(
|
||||
algebra.number((node.left as Numeric).value),
|
||||
algebra.number((node.right as Numeric).value)
|
||||
algebra.number((node.right as Numeric).value),
|
||||
)
|
||||
)
|
||||
|
||||
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
|
||||
algebra.leftSideNumberOperationFunction(node.operation)) {
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
algebra.leftSideNumberOperationFunction(node.operation),
|
||||
) {
|
||||
expressionVisitor(node.left)
|
||||
expressionVisitor(node.right)
|
||||
}
|
||||
|
||||
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
|
||||
algebra.rightSideNumberOperationFunction(node.operation)) {
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
algebra.rightSideNumberOperationFunction(node.operation),
|
||||
) {
|
||||
expressionVisitor(node.left)
|
||||
expressionVisitor(node.right)
|
||||
}
|
||||
|
||||
else -> buildCall(algebra.binaryOperationFunction(node.operation)) {
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
expressionVisitor(node.left)
|
||||
expressionVisitor(node.right)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance
|
||||
return AsmBuilder<T>(
|
||||
type,
|
||||
buildName(this),
|
||||
{ variablesVisitor(this@compileWith) },
|
||||
{ expressionVisitor(this@compileWith) },
|
||||
).instance
|
||||
}
|
||||
|
||||
|
||||
|
@ -10,8 +10,7 @@ import org.objectweb.asm.Opcodes.*
|
||||
import org.objectweb.asm.Type.*
|
||||
import org.objectweb.asm.commons.InstructionAdapter
|
||||
import space.kscience.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||
import space.kscience.kmath.expressions.Expression
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.expressions.*
|
||||
import java.lang.invoke.MethodHandles
|
||||
import java.lang.invoke.MethodType
|
||||
import java.nio.file.Paths
|
||||
@ -26,13 +25,14 @@ import kotlin.io.path.writeBytes
|
||||
*
|
||||
* @property T the type of AsmExpression to unwrap.
|
||||
* @property className the unique class name of new loaded class.
|
||||
* @property callbackAtInvokeL0 the function to apply to this object when generating invoke method, label 0.
|
||||
* @property expressionResultCallback the function to apply to this object when generating expression value.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
internal class AsmBuilder<T>(
|
||||
classOfT: Class<*>,
|
||||
private val className: String,
|
||||
private val callbackAtInvokeL0: AsmBuilder<T>.() -> Unit,
|
||||
private val variablesPrepareCallback: AsmBuilder<T>.() -> Unit,
|
||||
private val expressionResultCallback: AsmBuilder<T>.() -> Unit,
|
||||
) {
|
||||
/**
|
||||
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
|
||||
@ -66,12 +66,17 @@ internal class AsmBuilder<T>(
|
||||
*/
|
||||
private lateinit var invokeMethodVisitor: InstructionAdapter
|
||||
|
||||
/**
|
||||
* Local variables indices are indices of symbols in this list.
|
||||
*/
|
||||
private val argumentsLocals = mutableListOf<String>()
|
||||
|
||||
/**
|
||||
* Subclasses, loads and instantiates [Expression] for given parameters.
|
||||
*
|
||||
* The built instance is cached.
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
@Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE")
|
||||
val instance: Expression<T> by lazy {
|
||||
val hasConstants: Boolean
|
||||
|
||||
@ -94,26 +99,28 @@ internal class AsmBuilder<T>(
|
||||
).instructionAdapter {
|
||||
invokeMethodVisitor = this
|
||||
visitCode()
|
||||
val l0 = label()
|
||||
callbackAtInvokeL0()
|
||||
val preparingVariables = label()
|
||||
variablesPrepareCallback()
|
||||
val expressionResult = label()
|
||||
expressionResultCallback()
|
||||
areturn(tType)
|
||||
val l1 = label()
|
||||
val end = label()
|
||||
|
||||
visitLocalVariable(
|
||||
"this",
|
||||
classType.descriptor,
|
||||
null,
|
||||
l0,
|
||||
l1,
|
||||
preparingVariables,
|
||||
end,
|
||||
0,
|
||||
)
|
||||
|
||||
visitLocalVariable(
|
||||
"arguments",
|
||||
MAP_TYPE.descriptor,
|
||||
"L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;",
|
||||
l0,
|
||||
l1,
|
||||
"L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;",
|
||||
preparingVariables,
|
||||
end,
|
||||
1,
|
||||
)
|
||||
|
||||
@ -199,7 +206,7 @@ internal class AsmBuilder<T>(
|
||||
val binary = classWriter.toByteArray()
|
||||
val cls = classLoader.defineClass(className, binary)
|
||||
|
||||
if (System.getProperty("space.kscience.communicator.prettyapi.dump.generated.classes") == "1")
|
||||
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
|
||||
Paths.get("$className.class").writeBytes(binary)
|
||||
|
||||
val l = MethodHandles.publicLookup()
|
||||
@ -256,9 +263,11 @@ internal class AsmBuilder<T>(
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke].
|
||||
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
|
||||
* [loadVariable].
|
||||
*/
|
||||
fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
|
||||
fun prepareVariable(name: String): Unit = invokeMethodVisitor.run {
|
||||
if (name in argumentsLocals) return@run
|
||||
load(1, MAP_TYPE)
|
||||
aconst(name)
|
||||
|
||||
@ -270,8 +279,22 @@ internal class AsmBuilder<T>(
|
||||
)
|
||||
|
||||
checkcast(tType)
|
||||
var idx = argumentsLocals.indexOf(name)
|
||||
|
||||
if (idx == -1) {
|
||||
argumentsLocals += name
|
||||
idx = argumentsLocals.lastIndex
|
||||
}
|
||||
|
||||
store(2 + idx, tType)
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
|
||||
* with [prepareVariable] first.
|
||||
*/
|
||||
fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
|
||||
|
||||
inline fun buildCall(function: Function<T>, parameters: AsmBuilder<T>.() -> Unit) {
|
||||
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
|
||||
val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces }
|
||||
|
@ -5,10 +5,21 @@ plugins {
|
||||
// id("com.xcporter.metaview") version "0.0.5"
|
||||
}
|
||||
|
||||
kotlin.sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-memory"))
|
||||
kotlin {
|
||||
jvm {
|
||||
compilations.all {
|
||||
kotlinOptions {
|
||||
freeCompilerArgs =
|
||||
freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sourceSets {
|
||||
commonMain {
|
||||
dependencies {
|
||||
api(project(":kmath-memory"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ public fun interface Expression<T> {
|
||||
*
|
||||
* @return a value.
|
||||
*/
|
||||
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
|
||||
public operator fun <T> Expression<T>.invoke(): T = this(emptyMap())
|
||||
|
||||
/**
|
||||
* Calls this expression from arguments.
|
||||
@ -38,7 +38,13 @@ public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
|
||||
* @return a value.
|
||||
*/
|
||||
@JvmName("callBySymbol")
|
||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = invoke(mapOf(*pairs))
|
||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = this(
|
||||
when (pairs.size) {
|
||||
0 -> emptyMap()
|
||||
1 -> mapOf(pairs[0])
|
||||
else -> hashMapOf(*pairs)
|
||||
}
|
||||
)
|
||||
|
||||
/**
|
||||
* Calls this expression from arguments.
|
||||
@ -47,8 +53,21 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T =
|
||||
* @return a value.
|
||||
*/
|
||||
@JvmName("callByString")
|
||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
||||
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = this(
|
||||
when (pairs.size) {
|
||||
0 -> emptyMap()
|
||||
|
||||
1 -> {
|
||||
val (k, v) = pairs[0]
|
||||
mapOf(StringSymbol(k) to v)
|
||||
}
|
||||
|
||||
else -> hashMapOf(*Array<Pair<Symbol, T>>(pairs.size) {
|
||||
val (k, v) = pairs[it]
|
||||
StringSymbol(k) to v
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
/**
|
||||
|
Loading…
Reference in New Issue
Block a user