Fix bugs, add test for massive run

This commit is contained in:
Iaroslav Postovalov 2020-11-14 21:29:28 +07:00
parent ca219fa91b
commit 8de67157e9
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
4 changed files with 57 additions and 24 deletions

View File

@ -5,17 +5,13 @@ plugins {
kotlin.js {
nodejs { // or `browser`
testTask {
useMocha {
timeout = "0"// mochaTimeout here as string
}
useMocha().timeout = "0"
}
}
browser { // or `browser`
browser {
testTask {
useMocha {
timeout = "0"// mochaTimeout here as string
}
useMocha().timeout = "0"
}
}
}
@ -37,7 +33,7 @@ kotlin.sourceSets {
jsMain {
dependencies {
implementation(npm("binaryen", "98.0.0"))
implementation(npm("binaryen", "98.0.0-nightly.20201113"))
implementation(npm("js-base64", "3.6.0"))
implementation(npm("webassembly", "0.11.0"))
}

View File

@ -11,7 +11,11 @@ import binaryen.Module as BinaryenModule
private val spreader = eval("(obj, args) => obj(...args)")
@Suppress("UnsafeCastFromDynamic")
internal sealed class WasmBuilder<T>(val binaryenType: Type, val kmathAlgebra: Algebra<T>) where T : Number {
internal sealed class WasmBuilder<T>(
val binaryenType: Type,
val kmathAlgebra: Algebra<T>,
val target: MST
) where T : Number {
val keys: MutableList<String> = mutableListOf()
lateinit var ctx: BinaryenModule
@ -49,10 +53,10 @@ internal sealed class WasmBuilder<T>(val binaryenType: Type, val kmathAlgebra: A
is MST.Binary -> visitBinary(mst)
}
fun compile(mst: MST): Expression<T> {
val bin = with(createModule()) {
val instance by lazy {
val c = WasmModule(with(createModule()) {
ctx = this
val expr = visit(mst)
val expr = visit(target)
addFunction(
"executable",
@ -68,20 +72,20 @@ internal sealed class WasmBuilder<T>(val binaryenType: Type, val kmathAlgebra: A
val res = emitBinary()
dispose()
res
}
})
val c = WasmModule(bin)
val i = Instance(c, js("{}") as Any)
val symbols = keys.map(::StringSymbol)
keys.clear()
return Expression { args ->
val params = keys.map(::StringSymbol).map(args::getValue).toTypedArray()
keys.clear()
Expression<T> { args ->
val params = symbols.map(args::getValue).toTypedArray()
spreader(i.exports.asDynamic().executable, params) as T
}
}
}
internal class RealWasmBuilder : WasmBuilder<Double>(f64, RealField) {
internal class RealWasmBuilder(target: MST) : WasmBuilder<Double>(f64, RealField, target) {
override fun createModule(): BinaryenModule = readBinary(f64StandardFunctions)
override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.f64.const(mst.value)
@ -117,7 +121,7 @@ internal class RealWasmBuilder : WasmBuilder<Double>(f64, RealField) {
}
}
internal class IntWasmBuilder : WasmBuilder<Int>(i32, IntRing) {
internal class IntWasmBuilder(target: MST) : WasmBuilder<Int>(i32, IntRing, target) {
override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.i32.const(mst.value)
override fun visitUnary(mst: MST.Unary): ExpressionRef = when (mst.operation) {

View File

@ -1,28 +1,61 @@
package kscience.kmath.ast
import kscience.kmath.expressions.invoke
import kscience.kmath.operations.RealField
import kscience.kmath.operations.invoke
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.time.measureTime
internal class Test {
@Test
fun int() {
val res = IntWasmBuilder().compile(MstRing { number(100000000) + number(10000000) })()
val res = IntWasmBuilder(MstRing { number(100000000) + number(10000000) }).instance()
assertEquals(110000000, res)
}
@Test
fun real() {
val res = RealWasmBuilder().compile(MstExtendedField { number(100000000) + number(2).pow(10) })()
val res = RealWasmBuilder(MstExtendedField { number(100000000) + number(2).pow(10) }).instance()
assertEquals(100001024.0, res)
}
@Test
fun argsPassing() {
val res = RealWasmBuilder()
.compile(MstExtendedField { symbol("y") + symbol("x").pow(10) })("x" to 2.0, "y" to 100000000.0)
val res = RealWasmBuilder(MstExtendedField { symbol("y") + symbol("x").pow(10) })
.instance("x" to 2.0, "y" to 100000000.0)
assertEquals(100001024.0, res)
}
@Test
fun powFunction() {
val expr = RealWasmBuilder(MstExtendedField { symbol("x").pow(1.0 / 6.0) }).instance
assertEquals(0.9730585187140817, expr("x" to 0.8488554755054833))
}
@Test
fun manyRuns() {
println("Compiled")
val times = 1_000_000
var rng = Random(0)
var sum1 = 0.0
var sum2 = 0.0
measureTime {
val res = RealWasmBuilder(MstExtendedField { symbol("x").pow(1.0 / 6.0) }).instance
repeat(times) { sum1 += res("x" to rng.nextDouble()) }
}.also(::println)
println("MST")
rng = Random(0)
measureTime {
val res = RealField.mstInExtendedField { symbol("x").pow(1.0 / 6.0) }
repeat(times) { sum2 += res("x" to rng.nextDouble()) }
}.also(::println)
assertEquals(sum1, sum2)
}
}

View File

@ -49,7 +49,7 @@ public interface NumericAlgebra<T> : Algebra<T> {
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
*/
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
leftSideNumberOperation(operation, right, left)
binaryOperation(operation, left, number(right))
}
/**