Multiple performance improvements related to ASM

1. Argument values are cached in locals
2. Optimized Expression.invoke function
3. lambda=indy is used in kmath-core
This commit is contained in:
Iaroslav Postovalov 2021-11-15 22:42:00 +07:00 committed by Iaroslav Postovalov
parent 0e1e97a3ff
commit f231d722c6
9 changed files with 127 additions and 48 deletions

View File

@ -72,9 +72,9 @@ benchmark {
} }
fun kotlinx.benchmark.gradle.BenchmarkConfiguration.commonConfiguration() { fun kotlinx.benchmark.gradle.BenchmarkConfiguration.commonConfiguration() {
warmups = 1 warmups = 2
iterations = 5 iterations = 5
iterationTime = 1000 iterationTime = 2000
iterationTimeUnit = "ms" iterationTimeUnit = "ms"
} }
@ -143,7 +143,7 @@ kotlin.sourceSets.all {
tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> { tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> {
kotlinOptions { kotlinOptions {
jvmTarget = "11" jvmTarget = "11"
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy"
} }
} }

View File

@ -62,9 +62,11 @@ internal class ExpressionsInterpretersBenchmark {
private fun invokeAndSum(expr: Expression<Double>, blackhole: Blackhole) { private fun invokeAndSum(expr: Expression<Double>, blackhole: Blackhole) {
val random = Random(0) val random = Random(0)
var sum = 0.0 var sum = 0.0
val m = HashMap<Symbol, Double>()
repeat(times) { repeat(times) {
sum += expr(x to random.nextDouble()) m[x] = random.nextDouble()
sum += expr(m)
} }
blackhole.consume(sum) blackhole.consume(sum)

View File

@ -54,7 +54,7 @@ fun Project.addBenchmarkProperties() {
LocalDateTime.parse(it.name, ISO_DATE_TIME).atZone(ZoneId.systemDefault()).toInstant() LocalDateTime.parse(it.name, ISO_DATE_TIME).atZone(ZoneId.systemDefault()).toInstant()
} }
if (resDirectory == null) { if (resDirectory == null || !(resDirectory.resolve("jvm.json")).exists()) {
"> **Can't find appropriate benchmark data. Try generating readme files after running benchmarks**." "> **Can't find appropriate benchmark data. Try generating readme files after running benchmarks**."
} else { } else {
val reports = val reports =

View File

@ -64,9 +64,9 @@ kotlin.sourceSets.all {
} }
tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> { tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> {
kotlinOptions{ kotlinOptions {
jvmTarget = "11" jvmTarget = "11"
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy"
} }
} }

View File

@ -5,18 +5,22 @@
package space.kscience.kmath.ast package space.kscience.kmath.ast
import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.MstField import space.kscience.kmath.expressions.MstField
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
fun main() { fun main() {
val expr = MstField { val expr = MstField {
x * 2.0 + number(2.0) / x - 16.0 x * 2.0 + number(2.0) / x - 16.0
} }.compileToExpression(DoubleField)
val m = HashMap<Symbol, Double>()
repeat(10000000) { repeat(10000000) {
expr.interpret(DoubleField, x to 1.0) m[x] = 1.0
expr(m)
} }
} }

View File

@ -26,7 +26,19 @@ import space.kscience.kmath.operations.bindSymbolOrNull
*/ */
@PublishedApi @PublishedApi
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> { internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) { fun AsmBuilder<T>.variablesVisitor(node: MST): Unit = when (node) {
is Symbol -> prepareVariable(node.identity)
is Unary -> variablesVisitor(node.value)
is Binary -> {
variablesVisitor(node.left)
variablesVisitor(node.right)
}
else -> Unit
}
fun AsmBuilder<T>.expressionVisitor(node: MST): Unit = when (node) {
is Symbol -> { is Symbol -> {
val symbol = algebra.bindSymbolOrNull(node) val symbol = algebra.bindSymbolOrNull(node)
@ -40,39 +52,47 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
is Unary -> when { is Unary -> when {
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant( algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value))) algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)),
)
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) } else -> buildCall(algebra.unaryOperationFunction(node.operation)) { expressionVisitor(node.value) }
} }
is Binary -> when { is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant( algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
algebra.binaryOperationFunction(node.operation).invoke( algebra.binaryOperationFunction(node.operation).invoke(
algebra.number((node.left as Numeric).value), algebra.number((node.left as Numeric).value),
algebra.number((node.right as Numeric).value) algebra.number((node.right as Numeric).value),
) )
) )
algebra is NumericAlgebra && node.left is Numeric -> buildCall( algebra is NumericAlgebra && node.left is Numeric -> buildCall(
algebra.leftSideNumberOperationFunction(node.operation)) { algebra.leftSideNumberOperationFunction(node.operation),
visit(node.left) ) {
visit(node.right) expressionVisitor(node.left)
expressionVisitor(node.right)
} }
algebra is NumericAlgebra && node.right is Numeric -> buildCall( algebra is NumericAlgebra && node.right is Numeric -> buildCall(
algebra.rightSideNumberOperationFunction(node.operation)) { algebra.rightSideNumberOperationFunction(node.operation),
visit(node.left) ) {
visit(node.right) expressionVisitor(node.left)
expressionVisitor(node.right)
} }
else -> buildCall(algebra.binaryOperationFunction(node.operation)) { else -> buildCall(algebra.binaryOperationFunction(node.operation)) {
visit(node.left) expressionVisitor(node.left)
visit(node.right) expressionVisitor(node.right)
} }
} }
} }
return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance return AsmBuilder<T>(
type,
buildName(this),
{ variablesVisitor(this@compileWith) },
{ expressionVisitor(this@compileWith) },
).instance
} }

View File

@ -10,8 +10,7 @@ import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type.* import org.objectweb.asm.Type.*
import org.objectweb.asm.commons.InstructionAdapter import org.objectweb.asm.commons.InstructionAdapter
import space.kscience.kmath.asm.internal.AsmBuilder.ClassLoader import space.kscience.kmath.asm.internal.AsmBuilder.ClassLoader
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.*
import space.kscience.kmath.expressions.MST
import java.lang.invoke.MethodHandles import java.lang.invoke.MethodHandles
import java.lang.invoke.MethodType import java.lang.invoke.MethodType
import java.nio.file.Paths import java.nio.file.Paths
@ -26,13 +25,14 @@ import kotlin.io.path.writeBytes
* *
* @property T the type of AsmExpression to unwrap. * @property T the type of AsmExpression to unwrap.
* @property className the unique class name of new loaded class. * @property className the unique class name of new loaded class.
* @property callbackAtInvokeL0 the function to apply to this object when generating invoke method, label 0. * @property expressionResultCallback the function to apply to this object when generating expression value.
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
internal class AsmBuilder<T>( internal class AsmBuilder<T>(
classOfT: Class<*>, classOfT: Class<*>,
private val className: String, private val className: String,
private val callbackAtInvokeL0: AsmBuilder<T>.() -> Unit, private val variablesPrepareCallback: AsmBuilder<T>.() -> Unit,
private val expressionResultCallback: AsmBuilder<T>.() -> Unit,
) { ) {
/** /**
* Internal classloader of [AsmBuilder] with alias to define class from byte array. * Internal classloader of [AsmBuilder] with alias to define class from byte array.
@ -66,12 +66,17 @@ internal class AsmBuilder<T>(
*/ */
private lateinit var invokeMethodVisitor: InstructionAdapter private lateinit var invokeMethodVisitor: InstructionAdapter
/**
* Local variables indices are indices of symbols in this list.
*/
private val argumentsLocals = mutableListOf<String>()
/** /**
* Subclasses, loads and instantiates [Expression] for given parameters. * Subclasses, loads and instantiates [Expression] for given parameters.
* *
* The built instance is cached. * The built instance is cached.
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE")
val instance: Expression<T> by lazy { val instance: Expression<T> by lazy {
val hasConstants: Boolean val hasConstants: Boolean
@ -94,26 +99,28 @@ internal class AsmBuilder<T>(
).instructionAdapter { ).instructionAdapter {
invokeMethodVisitor = this invokeMethodVisitor = this
visitCode() visitCode()
val l0 = label() val preparingVariables = label()
callbackAtInvokeL0() variablesPrepareCallback()
val expressionResult = label()
expressionResultCallback()
areturn(tType) areturn(tType)
val l1 = label() val end = label()
visitLocalVariable( visitLocalVariable(
"this", "this",
classType.descriptor, classType.descriptor,
null, null,
l0, preparingVariables,
l1, end,
0, 0,
) )
visitLocalVariable( visitLocalVariable(
"arguments", "arguments",
MAP_TYPE.descriptor, MAP_TYPE.descriptor,
"L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", "L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;",
l0, preparingVariables,
l1, end,
1, 1,
) )
@ -199,7 +206,7 @@ internal class AsmBuilder<T>(
val binary = classWriter.toByteArray() val binary = classWriter.toByteArray()
val cls = classLoader.defineClass(className, binary) val cls = classLoader.defineClass(className, binary)
if (System.getProperty("space.kscience.communicator.prettyapi.dump.generated.classes") == "1") if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
Paths.get("$className.class").writeBytes(binary) Paths.get("$className.class").writeBytes(binary)
val l = MethodHandles.publicLookup() val l = MethodHandles.publicLookup()
@ -256,9 +263,11 @@ internal class AsmBuilder<T>(
} }
/** /**
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. * Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
* [loadVariable].
*/ */
fun loadVariable(name: String): Unit = invokeMethodVisitor.run { fun prepareVariable(name: String): Unit = invokeMethodVisitor.run {
if (name in argumentsLocals) return@run
load(1, MAP_TYPE) load(1, MAP_TYPE)
aconst(name) aconst(name)
@ -270,8 +279,22 @@ internal class AsmBuilder<T>(
) )
checkcast(tType) checkcast(tType)
var idx = argumentsLocals.indexOf(name)
if (idx == -1) {
argumentsLocals += name
idx = argumentsLocals.lastIndex
}
store(2 + idx, tType)
} }
/**
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
* with [prepareVariable] first.
*/
fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
inline fun buildCall(function: Function<T>, parameters: AsmBuilder<T>.() -> Unit) { inline fun buildCall(function: Function<T>, parameters: AsmBuilder<T>.() -> Unit) {
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces } val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces }

View File

@ -5,10 +5,21 @@ plugins {
// id("com.xcporter.metaview") version "0.0.5" // id("com.xcporter.metaview") version "0.0.5"
} }
kotlin.sourceSets { kotlin {
commonMain { jvm {
dependencies { compilations.all {
api(project(":kmath-memory")) kotlinOptions {
freeCompilerArgs =
freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy"
}
}
}
sourceSets {
commonMain {
dependencies {
api(project(":kmath-memory"))
}
} }
} }
} }

View File

@ -29,7 +29,7 @@ public fun interface Expression<T> {
* *
* @return a value. * @return a value.
*/ */
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap()) public operator fun <T> Expression<T>.invoke(): T = this(emptyMap())
/** /**
* Calls this expression from arguments. * Calls this expression from arguments.
@ -38,7 +38,13 @@ public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
* @return a value. * @return a value.
*/ */
@JvmName("callBySymbol") @JvmName("callBySymbol")
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = invoke(mapOf(*pairs)) public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = this(
when (pairs.size) {
0 -> emptyMap()
1 -> mapOf(pairs[0])
else -> hashMapOf(*pairs)
}
)
/** /**
* Calls this expression from arguments. * Calls this expression from arguments.
@ -47,8 +53,21 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T =
* @return a value. * @return a value.
*/ */
@JvmName("callByString") @JvmName("callByString")
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = this(
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) }) when (pairs.size) {
0 -> emptyMap()
1 -> {
val (k, v) = pairs[0]
mapOf(StringSymbol(k) to v)
}
else -> hashMapOf(*Array<Pair<Symbol, T>>(pairs.size) {
val (k, v) = pairs[it]
StringSymbol(k) to v
})
}
)
/** /**