forked from kscience/kmath
Add example of new AST API
This commit is contained in:
parent
57bdee4936
commit
54069fd37e
@ -19,7 +19,8 @@ repositories {
|
|||||||
sourceSets.register("benchmarks")
|
sourceSets.register("benchmarks")
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
// implementation(project(":kmath-ast"))
|
implementation(project(":kmath-ast"))
|
||||||
|
implementation(project(":kmath-ast-kotlingrad"))
|
||||||
implementation(project(":kmath-core"))
|
implementation(project(":kmath-core"))
|
||||||
implementation(project(":kmath-coroutines"))
|
implementation(project(":kmath-coroutines"))
|
||||||
implementation(project(":kmath-commons"))
|
implementation(project(":kmath-commons"))
|
||||||
|
@ -1,70 +1,80 @@
|
|||||||
package kscience.kmath.ast
|
package kscience.kmath.ast
|
||||||
//
|
|
||||||
//import kscience.kmath.asm.compile
|
import kscience.kmath.asm.compile
|
||||||
//import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
//import kscience.kmath.expressions.expressionInField
|
import kscience.kmath.expressions.expressionInField
|
||||||
//import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.invoke
|
||||||
//import kscience.kmath.operations.Field
|
import kscience.kmath.operations.Field
|
||||||
//import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
//import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
//import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
//
|
|
||||||
//class ExpressionsInterpretersBenchmark {
|
internal class ExpressionsInterpretersBenchmark {
|
||||||
// private val algebra: Field<Double> = RealField
|
private val algebra: Field<Double> = RealField
|
||||||
// fun functionalExpression() {
|
fun functionalExpression() {
|
||||||
// val expr = algebra.expressionInField {
|
val expr = algebra.expressionInField {
|
||||||
// variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
|
variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// invokeAndSum(expr)
|
invokeAndSum(expr)
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// fun mstExpression() {
|
fun mstExpression() {
|
||||||
// val expr = algebra.mstInField {
|
val expr = algebra.mstInField {
|
||||||
// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// invokeAndSum(expr)
|
invokeAndSum(expr)
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// fun asmExpression() {
|
fun asmExpression() {
|
||||||
// val expr = algebra.mstInField {
|
val expr = algebra.mstInField {
|
||||||
// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||||
// }.compile()
|
}.compile()
|
||||||
//
|
|
||||||
// invokeAndSum(expr)
|
invokeAndSum(expr)
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// private fun invokeAndSum(expr: Expression<Double>) {
|
private fun invokeAndSum(expr: Expression<Double>) {
|
||||||
// val random = Random(0)
|
val random = Random(0)
|
||||||
// var sum = 0.0
|
var sum = 0.0
|
||||||
//
|
|
||||||
// repeat(1000000) {
|
repeat(1000000) {
|
||||||
// sum += expr("x" to random.nextDouble())
|
sum += expr("x" to random.nextDouble())
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// println(sum)
|
println(sum)
|
||||||
// }
|
}
|
||||||
//}
|
}
|
||||||
//
|
|
||||||
//fun main() {
|
/**
|
||||||
// val benchmark = ExpressionsInterpretersBenchmark()
|
* This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and
|
||||||
//
|
* core FunctionalExpressions API.
|
||||||
// val fe = measureTimeMillis {
|
*
|
||||||
// benchmark.functionalExpression()
|
* The expected rating is:
|
||||||
// }
|
*
|
||||||
//
|
* 1. ASM.
|
||||||
// println("fe=$fe")
|
* 2. MST.
|
||||||
//
|
* 3. FE.
|
||||||
// val mst = measureTimeMillis {
|
*/
|
||||||
// benchmark.mstExpression()
|
fun main() {
|
||||||
// }
|
val benchmark = ExpressionsInterpretersBenchmark()
|
||||||
//
|
|
||||||
// println("mst=$mst")
|
val fe = measureTimeMillis {
|
||||||
//
|
benchmark.functionalExpression()
|
||||||
// val asm = measureTimeMillis {
|
}
|
||||||
// benchmark.asmExpression()
|
|
||||||
// }
|
println("fe=$fe")
|
||||||
//
|
|
||||||
// println("asm=$asm")
|
val mst = measureTimeMillis {
|
||||||
//}
|
benchmark.mstExpression()
|
||||||
|
}
|
||||||
|
|
||||||
|
println("mst=$mst")
|
||||||
|
|
||||||
|
val asm = measureTimeMillis {
|
||||||
|
benchmark.asmExpression()
|
||||||
|
}
|
||||||
|
|
||||||
|
println("asm=$asm")
|
||||||
|
}
|
||||||
|
@ -0,0 +1,22 @@
|
|||||||
|
package kscience.kmath.ast
|
||||||
|
|
||||||
|
import edu.umontreal.kotlingrad.experimental.DoublePrecision
|
||||||
|
import kscience.kmath.asm.compile
|
||||||
|
import kscience.kmath.ast.kotlingrad.mst
|
||||||
|
import kscience.kmath.ast.kotlingrad.sfun
|
||||||
|
import kscience.kmath.ast.kotlingrad.svar
|
||||||
|
import kscience.kmath.expressions.invoke
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
|
||||||
|
/**
|
||||||
|
* In this example, x^2-4*x-44 function is differentiated with Kotlin∇, and the autodiff result is compared with
|
||||||
|
* valid derivative.
|
||||||
|
*/
|
||||||
|
fun main() {
|
||||||
|
val proto = DoublePrecision.prototype
|
||||||
|
val x by MstAlgebra.symbol("x").svar(proto)
|
||||||
|
val quadratic = "x^2-4*x-44".parseMath().sfun(proto)
|
||||||
|
val actualDerivative = MstExpression(RealField, quadratic.d(x).mst()).compile()
|
||||||
|
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||||
|
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user