Add justCalculate benchmark, some minor refactorings

This commit is contained in:
Iaroslav Postovalov 2021-05-26 20:24:29 +07:00
parent 4810f2e63e
commit 46bf66c8ee
12 changed files with 113 additions and 43 deletions

View File

@ -14,22 +14,51 @@ import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.math.sin
import kotlin.random.Random import kotlin.random.Random
@State(Scope.Benchmark) @State(Scope.Benchmark)
internal class ExpressionsInterpretersBenchmark { internal class ExpressionsInterpretersBenchmark {
/**
* Benchmark case for [Expression] created with [expressionInExtendedField].
*/
@Benchmark @Benchmark
fun functionalExpression(blackhole: Blackhole) = invokeAndSum(functional, blackhole) fun functionalExpression(blackhole: Blackhole) = invokeAndSum(functional, blackhole)
/**
* Benchmark case for [Expression] created with [toExpression].
*/
@Benchmark @Benchmark
fun mstExpression(blackhole: Blackhole) = invokeAndSum(mst, blackhole) fun mstExpression(blackhole: Blackhole) = invokeAndSum(mst, blackhole)
/**
* Benchmark case for [Expression] created with [compileToExpression].
*/
@Benchmark @Benchmark
fun asmExpression(blackhole: Blackhole) = invokeAndSum(asm, blackhole) fun asmExpression(blackhole: Blackhole) = invokeAndSum(asm, blackhole)
/**
* Benchmark case for [Expression] implemented manually with `kotlin.math` functions.
*/
@Benchmark @Benchmark
fun rawExpression(blackhole: Blackhole) = invokeAndSum(raw, blackhole) fun rawExpression(blackhole: Blackhole) = invokeAndSum(raw, blackhole)
/**
* Benchmark case for direct computation w/o [Expression].
*/
@Benchmark
fun justCalculate(blackhole: Blackhole) {
val random = Random(0)
var sum = 0.0
repeat(times) {
val x = random.nextDouble()
sum += x * 2.0 + 2.0 / x - 16.0 / sin(x)
}
blackhole.consume(sum)
}
private fun invokeAndSum(expr: Expression<Double>, blackhole: Blackhole) { private fun invokeAndSum(expr: Expression<Double>, blackhole: Blackhole) {
val random = Random(0) val random = Random(0)
var sum = 0.0 var sum = 0.0
@ -42,23 +71,24 @@ internal class ExpressionsInterpretersBenchmark {
} }
private companion object { private companion object {
private val x: Symbol by symbol private val x by symbol
private val algebra: DoubleField = DoubleField private val algebra = DoubleField
private const val times = 1_000_000 private const val times = 1_000_000
private val functional: Expression<Double> = DoubleField.expressionInExtendedField { private val functional = DoubleField.expressionInExtendedField {
bindSymbol(x) * number(2.0) + number(2.0) / bindSymbol(x) - number(16.0) / sin(bindSymbol(x)) bindSymbol(x) * number(2.0) + number(2.0) / bindSymbol(x) - number(16.0) / sin(bindSymbol(x))
} }
private val node = MstExtendedField { private val node = MstExtendedField {
bindSymbol(x) * 2.0 + number(2.0) / bindSymbol(x) - number(16.0) / sin(bindSymbol(x)) x * 2.0 + number(2.0) / x - number(16.0) / sin(x)
} }
private val mst: Expression<Double> = node.toExpression(DoubleField) private val mst = node.toExpression(DoubleField)
private val asm: Expression<Double> = node.compileToExpression(DoubleField) private val asm = node.compileToExpression(DoubleField)
private val raw: Expression<Double> = Expression { args -> private val raw = Expression<Double> { args ->
args.getValue(x) * 2.0 + 2.0 / args.getValue(x) - 16.0 / kotlin.math.sin(args.getValue(x)) val x = args[x]!!
x * 2.0 + 2.0 / x - 16.0 / sin(x)
} }
} }
} }

View File

@ -9,12 +9,10 @@ import space.kscience.kmath.expressions.MstField
import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.interpret import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
fun main() { fun main() {
val expr = MstField { val expr = MstField {
val x = bindSymbol(x)
x * 2.0 + number(2.0) / x - 16.0 x * 2.0 + number(2.0) / x - 16.0
} }

View File

@ -10,7 +10,7 @@ Performance and visualization extensions to MST API.
## Artifact: ## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-11`. The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-13`.
**Gradle:** **Gradle:**
```gradle ```gradle
@ -20,7 +20,7 @@ repositories {
} }
dependencies { dependencies {
implementation 'space.kscience:kmath-ast:0.3.0-dev-11' implementation 'space.kscience:kmath-ast:0.3.0-dev-13'
} }
``` ```
**Gradle Kotlin DSL:** **Gradle Kotlin DSL:**
@ -31,7 +31,7 @@ repositories {
} }
dependencies { dependencies {
implementation("space.kscience:kmath-ast:0.3.0-dev-11") implementation("space.kscience:kmath-ast:0.3.0-dev-13")
} }
``` ```
@ -45,11 +45,12 @@ special implementation of `Expression<T>` with implemented `invoke` function.
For example, the following builder: For example, the following builder:
```kotlin ```kotlin
import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.asm.* import space.kscience.kmath.asm.*
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField) MstField { x + 2 }.compileToExpression(DoubleField)
``` ```
... leads to generation of bytecode, which can be decompiled to the following Java class: ... leads to generation of bytecode, which can be decompiled to the following Java class:
@ -89,11 +90,12 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
A similar feature is also available on JS. A similar feature is also available on JS.
```kotlin ```kotlin
import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.estree.* import space.kscience.kmath.estree.*
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField) MstField { x + 2 }.compileToExpression(DoubleField)
``` ```
The code above returns expression implemented with such a JS function: The code above returns expression implemented with such a JS function:
@ -108,11 +110,12 @@ JS also supports very experimental expression optimization with [WebAssembly](ht
Currently, only expressions inside `DoubleField` and `IntRing` are supported. Currently, only expressions inside `DoubleField` and `IntRing` are supported.
```kotlin ```kotlin
import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.wasm.* import space.kscience.kmath.wasm.*
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField) MstField { x + 2 }.compileToExpression(DoubleField)
``` ```
An example of emitted Wasm IR in the form of WAT: An example of emitted Wasm IR in the form of WAT:

View File

@ -16,11 +16,12 @@ special implementation of `Expression<T>` with implemented `invoke` function.
For example, the following builder: For example, the following builder:
```kotlin ```kotlin
import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.asm.* import space.kscience.kmath.asm.*
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField) MstField { x + 2 }.compileToExpression(DoubleField)
``` ```
... leads to generation of bytecode, which can be decompiled to the following Java class: ... leads to generation of bytecode, which can be decompiled to the following Java class:
@ -60,11 +61,12 @@ public final class AsmCompiledExpression_45045_0 implements Expression<Double> {
A similar feature is also available on JS. A similar feature is also available on JS.
```kotlin ```kotlin
import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.estree.* import space.kscience.kmath.estree.*
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField) MstField { x + 2 }.compileToExpression(DoubleField)
``` ```
The code above returns expression implemented with such a JS function: The code above returns expression implemented with such a JS function:
@ -79,11 +81,12 @@ JS also supports very experimental expression optimization with [WebAssembly](ht
Currently, only expressions inside `DoubleField` and `IntRing` are supported. Currently, only expressions inside `DoubleField` and `IntRing` are supported.
```kotlin ```kotlin
import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.wasm.* import space.kscience.kmath.wasm.*
MstField { bindSymbol("x") + 2 }.compileToExpression(DoubleField) MstField { x + 2 }.compileToExpression(DoubleField)
``` ```
An example of emitted Wasm IR in the form of WAT: An example of emitted Wasm IR in the form of WAT:

View File

@ -22,7 +22,7 @@ internal class TestCompilerConsistencyWithInterpreter {
val mst = MstRing { val mst = MstRing {
binaryOperationFunction("+")( binaryOperationFunction("+")(
unaryOperationFunction("+")( unaryOperationFunction("+")(
(bindSymbol(x) - (2.toByte() + (scale( (x - (2.toByte() + (scale(
add(number(1), number(1)), add(number(1), number(1)),
2.0, 2.0,
) + 1.toByte()))) * 3.0 - 1.toByte() ) + 1.toByte()))) * 3.0 - 1.toByte()
@ -42,7 +42,7 @@ internal class TestCompilerConsistencyWithInterpreter {
fun doubleField() = runCompilerTest { fun doubleField() = runCompilerTest {
val mst = MstField { val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 (3.0 - (x + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1), + number(1),
number(1) / 2 + number(2.0) * one, number(1) / 2 + number(2.0) * one,
) + zero ) + zero

View File

@ -47,19 +47,19 @@ internal class TestCompilerOperations {
@Test @Test
fun testSubtract() = runCompilerTest { fun testSubtract() = runCompilerTest {
val expr = MstExtendedField { bindSymbol(x) - bindSymbol(x) }.compileToExpression(DoubleField) val expr = MstExtendedField { x - x }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 2.0)) assertEquals(0.0, expr(x to 2.0))
} }
@Test @Test
fun testDivide() = runCompilerTest { fun testDivide() = runCompilerTest {
val expr = MstExtendedField { bindSymbol(x) / bindSymbol(x) }.compileToExpression(DoubleField) val expr = MstExtendedField { x / x }.compileToExpression(DoubleField)
assertEquals(1.0, expr(x to 2.0)) assertEquals(1.0, expr(x to 2.0))
} }
@Test @Test
fun testPower() = runCompilerTest { fun testPower() = runCompilerTest {
val expr = MstExtendedField { bindSymbol(x) pow 2 }.compileToExpression(DoubleField) val expr = MstExtendedField { x pow 2 }.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0)) assertEquals(4.0, expr(x to 2.0))
} }
} }

View File

@ -18,13 +18,13 @@ import kotlin.test.assertFailsWith
internal class TestCompilerVariables { internal class TestCompilerVariables {
@Test @Test
fun testVariable() = runCompilerTest { fun testVariable() = runCompilerTest {
val expr = MstRing { bindSymbol(x) }.compileToExpression(IntRing) val expr = MstRing { x }.compileToExpression(IntRing)
assertEquals(1, expr(x to 1)) assertEquals(1, expr(x to 1))
} }
@Test @Test
fun testUndefinedVariableFails() = runCompilerTest { fun testUndefinedVariableFails() = runCompilerTest {
val expr = MstRing { bindSymbol(x) }.compileToExpression(IntRing) val expr = MstRing { x }.compileToExpression(IntRing)
assertFailsWith<NoSuchElementException> { expr() } assertFailsWith<NoSuchElementException> { expr() }
} }
} }

View File

@ -27,11 +27,7 @@ internal sealed class WasmBuilder<T>(
lateinit var ctx: BinaryenModule lateinit var ctx: BinaryenModule
open fun visitSymbolic(mst: Symbol): ExpressionRef { open fun visitSymbolic(mst: Symbol): ExpressionRef {
try { algebra.bindSymbolOrNull(mst)?.let { return visitNumeric(Numeric(it)) }
algebra.bindSymbol(mst)
} catch (ignored: Throwable) {
null
}?.let { return visitNumeric(Numeric(it)) }
var idx = keys.indexOf(mst) var idx = keys.indexOf(mst)

View File

@ -7,7 +7,6 @@ package space.kscience.kmath.ast
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.math.sin import kotlin.math.sin
@ -17,18 +16,19 @@ import kotlin.time.measureTime
import space.kscience.kmath.estree.compileToExpression as estreeCompileToExpression import space.kscience.kmath.estree.compileToExpression as estreeCompileToExpression
import space.kscience.kmath.wasm.compileToExpression as wasmCompileToExpression import space.kscience.kmath.wasm.compileToExpression as wasmCompileToExpression
// TODO move to benchmarks when https://github.com/Kotlin/kotlinx-benchmark/pull/38 or similar feature is merged
internal class TestExecutionTime { internal class TestExecutionTime {
private companion object { private companion object {
private const val times = 1_000_000 private const val times = 1_000_000
private val x by symbol private val x by symbol
private val algebra: ExtendedField<Double> = DoubleField private val algebra = DoubleField
private val functional = DoubleField.expressionInExtendedField { private val functional = DoubleField.expressionInExtendedField {
bindSymbol(x) * const(2.0) + const(2.0) / bindSymbol(x) - const(16.0) / sin(bindSymbol(x)) bindSymbol(x) * const(2.0) + const(2.0) / bindSymbol(x) - const(16.0) / sin(bindSymbol(x))
} }
private val node = MstExtendedField { private val node = MstExtendedField {
bindSymbol(x) * number(2.0) + number(2.0) / bindSymbol(x) - number(16.0) / sin(bindSymbol(x)) x * number(2.0) + number(2.0) / x - number(16.0) / sin(x)
} }
private val mst = node.toExpression(DoubleField) private val mst = node.toExpression(DoubleField)
@ -43,7 +43,13 @@ internal class TestExecutionTime {
// }; // };
private val raw = Expression<Double> { args -> private val raw = Expression<Double> { args ->
args.getValue(x) * 2.0 + 2.0 / args.getValue(x) - 16.0 / sin(args.getValue(x)) val x = args[x]!!
x * 2.0 + 2.0 / x - 16.0 / sin(x)
}
private val justCalculate = { args: dynamic ->
val x = args[x].unsafeCast<Double>()
x * 2.0 + 2.0 / x - 16.0 / sin(x)
} }
} }
@ -51,21 +57,56 @@ internal class TestExecutionTime {
println(name) println(name)
val rng = Random(0) val rng = Random(0)
var sum = 0.0 var sum = 0.0
measureTime { repeat(times) { sum += expr(x to rng.nextDouble()) } }.also(::println) measureTime {
repeat(times) { sum += expr(x to rng.nextDouble()) }
}.also(::println)
} }
/**
* [Expression] created with [expressionInExtendedField].
*/
@Test @Test
fun functionalExpression() = invokeAndSum("functional", functional) fun functionalExpression() = invokeAndSum("functional", functional)
/**
* [Expression] created with [mstExpression].
*/
@Test @Test
fun mstExpression() = invokeAndSum("mst", mst) fun mstExpression() = invokeAndSum("mst", mst)
/**
* [Expression] created with [wasmCompileToExpression].
*/
@Test @Test
fun wasmExpression() = invokeAndSum("wasm", wasm) fun wasmExpression() = invokeAndSum("wasm", wasm)
/**
* [Expression] created with [estreeCompileToExpression].
*/
@Test @Test
fun estreeExpression() = invokeAndSum("estree", wasm) fun estreeExpression() = invokeAndSum("estree", wasm)
/**
* [Expression] implemented manually with `kotlin.math`.
*/
@Test @Test
fun rawExpression() = invokeAndSum("raw", raw) fun rawExpression() = invokeAndSum("raw", raw)
/**
* Direct computation w/o [Expression].
*/
@Test
fun justCalculateExpression() {
println("justCalculate")
val rng = Random(0)
var sum = 0.0
measureTime {
repeat(times) {
val arg = rng.nextDouble()
val o = js("{}")
o["x"] = arg
sum += justCalculate(o)
}
}.also(::println)
}
} }

View File

@ -31,7 +31,7 @@ internal class TestWasmSpecific {
@Test @Test
fun argsPassing() { fun argsPassing() {
val res = MstExtendedField { bindSymbol(y) + bindSymbol(x).pow(10) }.compile( val res = MstExtendedField { y + x.pow(10) }.compile(
DoubleField, DoubleField,
x to 2.0, x to 2.0,
y to 100000000.0, y to 100000000.0,
@ -42,7 +42,7 @@ internal class TestWasmSpecific {
@Test @Test
fun powFunction() { fun powFunction() {
val expr = MstExtendedField { bindSymbol(x).pow(1.0 / 6.0) }.compileToExpression(DoubleField) val expr = MstExtendedField { x.pow(1.0 / 6.0) }.compileToExpression(DoubleField)
assertEquals(0.9730585187140817, expr(x to 0.8488554755054833)) assertEquals(0.9730585187140817, expr(x to 0.8488554755054833))
} }

View File

@ -63,7 +63,7 @@ internal fun MethodVisitor.label(): Label = Label().also(::visitLabel)
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
internal tailrec fun buildName(mst: MST, collision: Int = 0): String { internal tailrec fun buildName(mst: MST, collision: Int = 0): String {
val name = "kscience.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision" val name = "space.kscience.kmath.asm.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
try { try {
Class.forName(name) Class.forName(name)

View File

@ -20,8 +20,7 @@ import kotlin.test.fail
internal class AdaptingTests { internal class AdaptingTests {
@Test @Test
fun symbol() { fun symbol() {
val c1 = MstNumericAlgebra.bindSymbol(x.identity) assertEquals(x.identity, x.toSVar<KMathNumber<Double, DoubleField>>().name)
assertEquals(x.identity, c1.toSVar<KMathNumber<Double, DoubleField>>().name)
val c2 = "kitten".parseMath().toSFun<KMathNumber<Double, DoubleField>>() val c2 = "kitten".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail() if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
} }
@ -46,7 +45,7 @@ internal class AdaptingTests {
@Test @Test
fun simpleFunctionDerivative() { fun simpleFunctionDerivative() {
val xSVar = MstNumericAlgebra.bindSymbol(x.identity).toSVar<KMathNumber<Double, DoubleField>>() val xSVar = x.toSVar<KMathNumber<Double, DoubleField>>()
val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, DoubleField>>() val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
val actualDerivative = quadratic.d(xSVar).toMst().compileToExpression(DoubleField) val actualDerivative = quadratic.d(xSVar).toMst().compileToExpression(DoubleField)
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField) val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
@ -55,7 +54,7 @@ internal class AdaptingTests {
@Test @Test
fun moreComplexDerivative() { fun moreComplexDerivative() {
val xSVar = MstNumericAlgebra.bindSymbol(x.identity).toSVar<KMathNumber<Double, DoubleField>>() val xSVar = x.toSVar<KMathNumber<Double, DoubleField>>()
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, DoubleField>>() val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
val actualDerivative = composition.d(xSVar).toMst().compileToExpression(DoubleField) val actualDerivative = composition.d(xSVar).toMst().compileToExpression(DoubleField)