forked from kscience/kmath
Multiple performance improvements related to ASM
1. Argument values are cached in locals 2. Optimized Expression.invoke function 3. lambda=indy is used in kmath-core
This commit is contained in:
parent
0e1e97a3ff
commit
f231d722c6
@ -72,9 +72,9 @@ benchmark {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun kotlinx.benchmark.gradle.BenchmarkConfiguration.commonConfiguration() {
|
fun kotlinx.benchmark.gradle.BenchmarkConfiguration.commonConfiguration() {
|
||||||
warmups = 1
|
warmups = 2
|
||||||
iterations = 5
|
iterations = 5
|
||||||
iterationTime = 1000
|
iterationTime = 2000
|
||||||
iterationTimeUnit = "ms"
|
iterationTimeUnit = "ms"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,7 +143,7 @@ kotlin.sourceSets.all {
|
|||||||
tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> {
|
tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> {
|
||||||
kotlinOptions {
|
kotlinOptions {
|
||||||
jvmTarget = "11"
|
jvmTarget = "11"
|
||||||
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all"
|
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,9 +62,11 @@ internal class ExpressionsInterpretersBenchmark {
|
|||||||
private fun invokeAndSum(expr: Expression<Double>, blackhole: Blackhole) {
|
private fun invokeAndSum(expr: Expression<Double>, blackhole: Blackhole) {
|
||||||
val random = Random(0)
|
val random = Random(0)
|
||||||
var sum = 0.0
|
var sum = 0.0
|
||||||
|
val m = HashMap<Symbol, Double>()
|
||||||
|
|
||||||
repeat(times) {
|
repeat(times) {
|
||||||
sum += expr(x to random.nextDouble())
|
m[x] = random.nextDouble()
|
||||||
|
sum += expr(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
blackhole.consume(sum)
|
blackhole.consume(sum)
|
||||||
|
@ -54,7 +54,7 @@ fun Project.addBenchmarkProperties() {
|
|||||||
LocalDateTime.parse(it.name, ISO_DATE_TIME).atZone(ZoneId.systemDefault()).toInstant()
|
LocalDateTime.parse(it.name, ISO_DATE_TIME).atZone(ZoneId.systemDefault()).toInstant()
|
||||||
}
|
}
|
||||||
|
|
||||||
if (resDirectory == null) {
|
if (resDirectory == null || !(resDirectory.resolve("jvm.json")).exists()) {
|
||||||
"> **Can't find appropriate benchmark data. Try generating readme files after running benchmarks**."
|
"> **Can't find appropriate benchmark data. Try generating readme files after running benchmarks**."
|
||||||
} else {
|
} else {
|
||||||
val reports =
|
val reports =
|
||||||
|
@ -64,9 +64,9 @@ kotlin.sourceSets.all {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> {
|
tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> {
|
||||||
kotlinOptions{
|
kotlinOptions {
|
||||||
jvmTarget = "11"
|
jvmTarget = "11"
|
||||||
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn"
|
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,18 +5,22 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.ast
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
|
import space.kscience.kmath.asm.compileToExpression
|
||||||
import space.kscience.kmath.expressions.MstField
|
import space.kscience.kmath.expressions.MstField
|
||||||
|
import space.kscience.kmath.expressions.Symbol
|
||||||
import space.kscience.kmath.expressions.Symbol.Companion.x
|
import space.kscience.kmath.expressions.Symbol.Companion.x
|
||||||
import space.kscience.kmath.expressions.interpret
|
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
val expr = MstField {
|
val expr = MstField {
|
||||||
x * 2.0 + number(2.0) / x - 16.0
|
x * 2.0 + number(2.0) / x - 16.0
|
||||||
}
|
}.compileToExpression(DoubleField)
|
||||||
|
|
||||||
|
val m = HashMap<Symbol, Double>()
|
||||||
|
|
||||||
repeat(10000000) {
|
repeat(10000000) {
|
||||||
expr.interpret(DoubleField, x to 1.0)
|
m[x] = 1.0
|
||||||
|
expr(m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,19 @@ import space.kscience.kmath.operations.bindSymbolOrNull
|
|||||||
*/
|
*/
|
||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
fun AsmBuilder<T>.variablesVisitor(node: MST): Unit = when (node) {
|
||||||
|
is Symbol -> prepareVariable(node.identity)
|
||||||
|
is Unary -> variablesVisitor(node.value)
|
||||||
|
|
||||||
|
is Binary -> {
|
||||||
|
variablesVisitor(node.left)
|
||||||
|
variablesVisitor(node.right)
|
||||||
|
}
|
||||||
|
|
||||||
|
else -> Unit
|
||||||
|
}
|
||||||
|
|
||||||
|
fun AsmBuilder<T>.expressionVisitor(node: MST): Unit = when (node) {
|
||||||
is Symbol -> {
|
is Symbol -> {
|
||||||
val symbol = algebra.bindSymbolOrNull(node)
|
val symbol = algebra.bindSymbolOrNull(node)
|
||||||
|
|
||||||
@ -40,39 +52,47 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
|
|||||||
|
|
||||||
is Unary -> when {
|
is Unary -> when {
|
||||||
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
|
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
|
||||||
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)))
|
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)),
|
||||||
|
)
|
||||||
|
|
||||||
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
|
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { expressionVisitor(node.value) }
|
||||||
}
|
}
|
||||||
|
|
||||||
is Binary -> when {
|
is Binary -> when {
|
||||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
|
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
|
||||||
algebra.binaryOperationFunction(node.operation).invoke(
|
algebra.binaryOperationFunction(node.operation).invoke(
|
||||||
algebra.number((node.left as Numeric).value),
|
algebra.number((node.left as Numeric).value),
|
||||||
algebra.number((node.right as Numeric).value)
|
algebra.number((node.right as Numeric).value),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
|
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
|
||||||
algebra.leftSideNumberOperationFunction(node.operation)) {
|
algebra.leftSideNumberOperationFunction(node.operation),
|
||||||
visit(node.left)
|
) {
|
||||||
visit(node.right)
|
expressionVisitor(node.left)
|
||||||
|
expressionVisitor(node.right)
|
||||||
}
|
}
|
||||||
|
|
||||||
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
|
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
|
||||||
algebra.rightSideNumberOperationFunction(node.operation)) {
|
algebra.rightSideNumberOperationFunction(node.operation),
|
||||||
visit(node.left)
|
) {
|
||||||
visit(node.right)
|
expressionVisitor(node.left)
|
||||||
|
expressionVisitor(node.right)
|
||||||
}
|
}
|
||||||
|
|
||||||
else -> buildCall(algebra.binaryOperationFunction(node.operation)) {
|
else -> buildCall(algebra.binaryOperationFunction(node.operation)) {
|
||||||
visit(node.left)
|
expressionVisitor(node.left)
|
||||||
visit(node.right)
|
expressionVisitor(node.right)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance
|
return AsmBuilder<T>(
|
||||||
|
type,
|
||||||
|
buildName(this),
|
||||||
|
{ variablesVisitor(this@compileWith) },
|
||||||
|
{ expressionVisitor(this@compileWith) },
|
||||||
|
).instance
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,8 +10,7 @@ import org.objectweb.asm.Opcodes.*
|
|||||||
import org.objectweb.asm.Type.*
|
import org.objectweb.asm.Type.*
|
||||||
import org.objectweb.asm.commons.InstructionAdapter
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
import space.kscience.kmath.asm.internal.AsmBuilder.ClassLoader
|
import space.kscience.kmath.asm.internal.AsmBuilder.ClassLoader
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.expressions.MST
|
|
||||||
import java.lang.invoke.MethodHandles
|
import java.lang.invoke.MethodHandles
|
||||||
import java.lang.invoke.MethodType
|
import java.lang.invoke.MethodType
|
||||||
import java.nio.file.Paths
|
import java.nio.file.Paths
|
||||||
@ -26,13 +25,14 @@ import kotlin.io.path.writeBytes
|
|||||||
*
|
*
|
||||||
* @property T the type of AsmExpression to unwrap.
|
* @property T the type of AsmExpression to unwrap.
|
||||||
* @property className the unique class name of new loaded class.
|
* @property className the unique class name of new loaded class.
|
||||||
* @property callbackAtInvokeL0 the function to apply to this object when generating invoke method, label 0.
|
* @property expressionResultCallback the function to apply to this object when generating expression value.
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
internal class AsmBuilder<T>(
|
internal class AsmBuilder<T>(
|
||||||
classOfT: Class<*>,
|
classOfT: Class<*>,
|
||||||
private val className: String,
|
private val className: String,
|
||||||
private val callbackAtInvokeL0: AsmBuilder<T>.() -> Unit,
|
private val variablesPrepareCallback: AsmBuilder<T>.() -> Unit,
|
||||||
|
private val expressionResultCallback: AsmBuilder<T>.() -> Unit,
|
||||||
) {
|
) {
|
||||||
/**
|
/**
|
||||||
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
|
* Internal classloader of [AsmBuilder] with alias to define class from byte array.
|
||||||
@ -66,12 +66,17 @@ internal class AsmBuilder<T>(
|
|||||||
*/
|
*/
|
||||||
private lateinit var invokeMethodVisitor: InstructionAdapter
|
private lateinit var invokeMethodVisitor: InstructionAdapter
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Local variables indices are indices of symbols in this list.
|
||||||
|
*/
|
||||||
|
private val argumentsLocals = mutableListOf<String>()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subclasses, loads and instantiates [Expression] for given parameters.
|
* Subclasses, loads and instantiates [Expression] for given parameters.
|
||||||
*
|
*
|
||||||
* The built instance is cached.
|
* The built instance is cached.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE")
|
||||||
val instance: Expression<T> by lazy {
|
val instance: Expression<T> by lazy {
|
||||||
val hasConstants: Boolean
|
val hasConstants: Boolean
|
||||||
|
|
||||||
@ -94,26 +99,28 @@ internal class AsmBuilder<T>(
|
|||||||
).instructionAdapter {
|
).instructionAdapter {
|
||||||
invokeMethodVisitor = this
|
invokeMethodVisitor = this
|
||||||
visitCode()
|
visitCode()
|
||||||
val l0 = label()
|
val preparingVariables = label()
|
||||||
callbackAtInvokeL0()
|
variablesPrepareCallback()
|
||||||
|
val expressionResult = label()
|
||||||
|
expressionResultCallback()
|
||||||
areturn(tType)
|
areturn(tType)
|
||||||
val l1 = label()
|
val end = label()
|
||||||
|
|
||||||
visitLocalVariable(
|
visitLocalVariable(
|
||||||
"this",
|
"this",
|
||||||
classType.descriptor,
|
classType.descriptor,
|
||||||
null,
|
null,
|
||||||
l0,
|
preparingVariables,
|
||||||
l1,
|
end,
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
visitLocalVariable(
|
visitLocalVariable(
|
||||||
"arguments",
|
"arguments",
|
||||||
MAP_TYPE.descriptor,
|
MAP_TYPE.descriptor,
|
||||||
"L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;",
|
"L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;",
|
||||||
l0,
|
preparingVariables,
|
||||||
l1,
|
end,
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -199,7 +206,7 @@ internal class AsmBuilder<T>(
|
|||||||
val binary = classWriter.toByteArray()
|
val binary = classWriter.toByteArray()
|
||||||
val cls = classLoader.defineClass(className, binary)
|
val cls = classLoader.defineClass(className, binary)
|
||||||
|
|
||||||
if (System.getProperty("space.kscience.communicator.prettyapi.dump.generated.classes") == "1")
|
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
|
||||||
Paths.get("$className.class").writeBytes(binary)
|
Paths.get("$className.class").writeBytes(binary)
|
||||||
|
|
||||||
val l = MethodHandles.publicLookup()
|
val l = MethodHandles.publicLookup()
|
||||||
@ -256,9 +263,11 @@ internal class AsmBuilder<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke].
|
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
|
||||||
|
* [loadVariable].
|
||||||
*/
|
*/
|
||||||
fun loadVariable(name: String): Unit = invokeMethodVisitor.run {
|
fun prepareVariable(name: String): Unit = invokeMethodVisitor.run {
|
||||||
|
if (name in argumentsLocals) return@run
|
||||||
load(1, MAP_TYPE)
|
load(1, MAP_TYPE)
|
||||||
aconst(name)
|
aconst(name)
|
||||||
|
|
||||||
@ -270,8 +279,22 @@ internal class AsmBuilder<T>(
|
|||||||
)
|
)
|
||||||
|
|
||||||
checkcast(tType)
|
checkcast(tType)
|
||||||
|
var idx = argumentsLocals.indexOf(name)
|
||||||
|
|
||||||
|
if (idx == -1) {
|
||||||
|
argumentsLocals += name
|
||||||
|
idx = argumentsLocals.lastIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
store(2 + idx, tType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
|
||||||
|
* with [prepareVariable] first.
|
||||||
|
*/
|
||||||
|
fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tType)
|
||||||
|
|
||||||
inline fun buildCall(function: Function<T>, parameters: AsmBuilder<T>.() -> Unit) {
|
inline fun buildCall(function: Function<T>, parameters: AsmBuilder<T>.() -> Unit) {
|
||||||
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(parameters, InvocationKind.EXACTLY_ONCE) }
|
||||||
val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces }
|
val `interface` = function.javaClass.interfaces.first { Function::class.java in it.interfaces }
|
||||||
|
@ -5,10 +5,21 @@ plugins {
|
|||||||
// id("com.xcporter.metaview") version "0.0.5"
|
// id("com.xcporter.metaview") version "0.0.5"
|
||||||
}
|
}
|
||||||
|
|
||||||
kotlin.sourceSets {
|
kotlin {
|
||||||
commonMain {
|
jvm {
|
||||||
dependencies {
|
compilations.all {
|
||||||
api(project(":kmath-memory"))
|
kotlinOptions {
|
||||||
|
freeCompilerArgs =
|
||||||
|
freeCompilerArgs + "-Xjvm-default=all" + "-Xopt-in=kotlin.RequiresOptIn" + "-Xlambdas=indy"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceSets {
|
||||||
|
commonMain {
|
||||||
|
dependencies {
|
||||||
|
api(project(":kmath-memory"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,7 @@ public fun interface Expression<T> {
|
|||||||
*
|
*
|
||||||
* @return a value.
|
* @return a value.
|
||||||
*/
|
*/
|
||||||
public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
|
public operator fun <T> Expression<T>.invoke(): T = this(emptyMap())
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls this expression from arguments.
|
* Calls this expression from arguments.
|
||||||
@ -38,7 +38,13 @@ public operator fun <T> Expression<T>.invoke(): T = invoke(emptyMap())
|
|||||||
* @return a value.
|
* @return a value.
|
||||||
*/
|
*/
|
||||||
@JvmName("callBySymbol")
|
@JvmName("callBySymbol")
|
||||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = invoke(mapOf(*pairs))
|
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T = this(
|
||||||
|
when (pairs.size) {
|
||||||
|
0 -> emptyMap()
|
||||||
|
1 -> mapOf(pairs[0])
|
||||||
|
else -> hashMapOf(*pairs)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls this expression from arguments.
|
* Calls this expression from arguments.
|
||||||
@ -47,8 +53,21 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<Symbol, T>): T =
|
|||||||
* @return a value.
|
* @return a value.
|
||||||
*/
|
*/
|
||||||
@JvmName("callByString")
|
@JvmName("callByString")
|
||||||
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T = this(
|
||||||
invoke(mapOf(*pairs).mapKeys { StringSymbol(it.key) })
|
when (pairs.size) {
|
||||||
|
0 -> emptyMap()
|
||||||
|
|
||||||
|
1 -> {
|
||||||
|
val (k, v) = pairs[0]
|
||||||
|
mapOf(StringSymbol(k) to v)
|
||||||
|
}
|
||||||
|
|
||||||
|
else -> hashMapOf(*Array<Pair<Symbol, T>>(pairs.size) {
|
||||||
|
val (k, v) = pairs[it]
|
||||||
|
StringSymbol(k) to v
|
||||||
|
})
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
Loading…
Reference in New Issue
Block a user