Add example of new AST API

This commit is contained in:
Iaroslav Postovalov 2020-10-12 22:42:34 +07:00
parent 57bdee4936
commit 54069fd37e
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
3 changed files with 103 additions and 70 deletions

View File

@ -19,7 +19,8 @@ repositories {
sourceSets.register("benchmarks")
dependencies {
// implementation(project(":kmath-ast"))
implementation(project(":kmath-ast"))
implementation(project(":kmath-ast-kotlingrad"))
implementation(project(":kmath-core"))
implementation(project(":kmath-coroutines"))
implementation(project(":kmath-commons"))

View File

@ -1,70 +1,80 @@
package kscience.kmath.ast
//
//import kscience.kmath.asm.compile
//import kscience.kmath.expressions.Expression
//import kscience.kmath.expressions.expressionInField
//import kscience.kmath.expressions.invoke
//import kscience.kmath.operations.Field
//import kscience.kmath.operations.RealField
//import kotlin.random.Random
//import kotlin.system.measureTimeMillis
//
//class ExpressionsInterpretersBenchmark {
// private val algebra: Field<Double> = RealField
// fun functionalExpression() {
// val expr = algebra.expressionInField {
// variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
// }
//
// invokeAndSum(expr)
// }
//
// fun mstExpression() {
// val expr = algebra.mstInField {
// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
// }
//
// invokeAndSum(expr)
// }
//
// fun asmExpression() {
// val expr = algebra.mstInField {
// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
// }.compile()
//
// invokeAndSum(expr)
// }
//
// private fun invokeAndSum(expr: Expression<Double>) {
// val random = Random(0)
// var sum = 0.0
//
// repeat(1000000) {
// sum += expr("x" to random.nextDouble())
// }
//
// println(sum)
// }
//}
//
//fun main() {
// val benchmark = ExpressionsInterpretersBenchmark()
//
// val fe = measureTimeMillis {
// benchmark.functionalExpression()
// }
//
// println("fe=$fe")
//
// val mst = measureTimeMillis {
// benchmark.mstExpression()
// }
//
// println("mst=$mst")
//
// val asm = measureTimeMillis {
// benchmark.asmExpression()
// }
//
// println("asm=$asm")
//}
import kscience.kmath.asm.compile
import kscience.kmath.expressions.Expression
import kscience.kmath.expressions.expressionInField
import kscience.kmath.expressions.invoke
import kscience.kmath.operations.Field
import kscience.kmath.operations.RealField
import kotlin.random.Random
import kotlin.system.measureTimeMillis
internal class ExpressionsInterpretersBenchmark {
private val algebra: Field<Double> = RealField
fun functionalExpression() {
val expr = algebra.expressionInField {
variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
}
invokeAndSum(expr)
}
fun mstExpression() {
val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
}
invokeAndSum(expr)
}
fun asmExpression() {
val expr = algebra.mstInField {
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
}.compile()
invokeAndSum(expr)
}
private fun invokeAndSum(expr: Expression<Double>) {
val random = Random(0)
var sum = 0.0
repeat(1000000) {
sum += expr("x" to random.nextDouble())
}
println(sum)
}
}
/**
* This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and
* core FunctionalExpressions API.
*
* The expected rating is:
*
* 1. ASM.
* 2. MST.
* 3. FE.
*/
fun main() {
val benchmark = ExpressionsInterpretersBenchmark()
val fe = measureTimeMillis {
benchmark.functionalExpression()
}
println("fe=$fe")
val mst = measureTimeMillis {
benchmark.mstExpression()
}
println("mst=$mst")
val asm = measureTimeMillis {
benchmark.asmExpression()
}
println("asm=$asm")
}

View File

@ -0,0 +1,22 @@
package kscience.kmath.ast
import edu.umontreal.kotlingrad.experimental.DoublePrecision
import kscience.kmath.asm.compile
import kscience.kmath.ast.kotlingrad.mst
import kscience.kmath.ast.kotlingrad.sfun
import kscience.kmath.ast.kotlingrad.svar
import kscience.kmath.expressions.invoke
import kscience.kmath.operations.RealField
/**
* In this example, x^2-4*x-44 function is differentiated with Kotlin, and the autodiff result is compared with
* valid derivative.
*/
fun main() {
val proto = DoublePrecision.prototype
val x by MstAlgebra.symbol("x").svar(proto)
val quadratic = "x^2-4*x-44".parseMath().sfun(proto)
val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile()
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
}