Very experimental WASM code generation by MST in contexts of Int and Real #158

Closed
CommanderTvis wants to merge 16 commits from feature/binaryen into dev
4 changed files with 57 additions and 24 deletions
Showing only changes of commit 8de67157e9 - Show all commits

View File

@ -5,17 +5,13 @@ plugins {
kotlin.js { kotlin.js {
nodejs { // or `browser` nodejs { // or `browser`
testTask { testTask {
useMocha { useMocha().timeout = "0"
timeout = "0"// mochaTimeout here as string
}
} }
} }
browser { // or `browser` browser {
testTask { testTask {
useMocha { useMocha().timeout = "0"
timeout = "0"// mochaTimeout here as string
}
} }
} }
} }
@ -37,7 +33,7 @@ kotlin.sourceSets {
jsMain { jsMain {
dependencies { dependencies {
implementation(npm("binaryen", "98.0.0")) implementation(npm("binaryen", "98.0.0-nightly.20201113"))
implementation(npm("js-base64", "3.6.0")) implementation(npm("js-base64", "3.6.0"))
implementation(npm("webassembly", "0.11.0")) implementation(npm("webassembly", "0.11.0"))
} }

View File

@ -11,7 +11,11 @@ import binaryen.Module as BinaryenModule
private val spreader = eval("(obj, args) => obj(...args)") private val spreader = eval("(obj, args) => obj(...args)")
@Suppress("UnsafeCastFromDynamic") @Suppress("UnsafeCastFromDynamic")
internal sealed class WasmBuilder<T>(val binaryenType: Type, val kmathAlgebra: Algebra<T>) where T : Number { internal sealed class WasmBuilder<T>(
val binaryenType: Type,
val kmathAlgebra: Algebra<T>,
val target: MST
) where T : Number {
val keys: MutableList<String> = mutableListOf() val keys: MutableList<String> = mutableListOf()
lateinit var ctx: BinaryenModule lateinit var ctx: BinaryenModule
@ -49,10 +53,10 @@ internal sealed class WasmBuilder<T>(val binaryenType: Type, val kmathAlgebra: A
is MST.Binary -> visitBinary(mst) is MST.Binary -> visitBinary(mst)
} }
fun compile(mst: MST): Expression<T> { val instance by lazy {
val bin = with(createModule()) { val c = WasmModule(with(createModule()) {
ctx = this ctx = this
val expr = visit(mst) val expr = visit(target)
addFunction( addFunction(
"executable", "executable",
@ -68,20 +72,20 @@ internal sealed class WasmBuilder<T>(val binaryenType: Type, val kmathAlgebra: A
val res = emitBinary() val res = emitBinary()
dispose() dispose()
res res
} })
val c = WasmModule(bin)
val i = Instance(c, js("{}") as Any) val i = Instance(c, js("{}") as Any)
val symbols = keys.map(::StringSymbol)
keys.clear()
return Expression { args -> Expression<T> { args ->
val params = keys.map(::StringSymbol).map(args::getValue).toTypedArray() val params = symbols.map(args::getValue).toTypedArray()
keys.clear()
spreader(i.exports.asDynamic().executable, params) as T spreader(i.exports.asDynamic().executable, params) as T
} }
} }
} }
internal class RealWasmBuilder : WasmBuilder<Double>(f64, RealField) { internal class RealWasmBuilder(target: MST) : WasmBuilder<Double>(f64, RealField, target) {
override fun createModule(): BinaryenModule = readBinary(f64StandardFunctions) override fun createModule(): BinaryenModule = readBinary(f64StandardFunctions)
override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.f64.const(mst.value) override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.f64.const(mst.value)
@ -117,7 +121,7 @@ internal class RealWasmBuilder : WasmBuilder<Double>(f64, RealField) {
} }
} }
internal class IntWasmBuilder : WasmBuilder<Int>(i32, IntRing) { internal class IntWasmBuilder(target: MST) : WasmBuilder<Int>(i32, IntRing, target) {
override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.i32.const(mst.value) override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.i32.const(mst.value)
override fun visitUnary(mst: MST.Unary): ExpressionRef = when (mst.operation) { override fun visitUnary(mst: MST.Unary): ExpressionRef = when (mst.operation) {

View File

@ -1,28 +1,61 @@
package kscience.kmath.ast package kscience.kmath.ast
import kscience.kmath.expressions.invoke import kscience.kmath.expressions.invoke
import kscience.kmath.operations.RealField
import kscience.kmath.operations.invoke import kscience.kmath.operations.invoke
import kotlin.random.Random
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.time.measureTime
internal class Test { internal class Test {
@Test @Test
fun int() { fun int() {
val res = IntWasmBuilder().compile(MstRing { number(100000000) + number(10000000) })() val res = IntWasmBuilder(MstRing { number(100000000) + number(10000000) }).instance()
assertEquals(110000000, res) assertEquals(110000000, res)
} }
@Test @Test
fun real() { fun real() {
val res = RealWasmBuilder().compile(MstExtendedField { number(100000000) + number(2).pow(10) })() val res = RealWasmBuilder(MstExtendedField { number(100000000) + number(2).pow(10) }).instance()
assertEquals(100001024.0, res) assertEquals(100001024.0, res)
} }
@Test @Test
fun argsPassing() { fun argsPassing() {
val res = RealWasmBuilder() val res = RealWasmBuilder(MstExtendedField { symbol("y") + symbol("x").pow(10) })
.compile(MstExtendedField { symbol("y") + symbol("x").pow(10) })("x" to 2.0, "y" to 100000000.0) .instance("x" to 2.0, "y" to 100000000.0)
assertEquals(100001024.0, res) assertEquals(100001024.0, res)
} }
@Test
fun powFunction() {
val expr = RealWasmBuilder(MstExtendedField { symbol("x").pow(1.0 / 6.0) }).instance
assertEquals(0.9730585187140817, expr("x" to 0.8488554755054833))
}
@Test
fun manyRuns() {
println("Compiled")
val times = 1_000_000
var rng = Random(0)
var sum1 = 0.0
var sum2 = 0.0
measureTime {
val res = RealWasmBuilder(MstExtendedField { symbol("x").pow(1.0 / 6.0) }).instance
repeat(times) { sum1 += res("x" to rng.nextDouble()) }
}.also(::println)
println("MST")
rng = Random(0)
measureTime {
val res = RealField.mstInExtendedField { symbol("x").pow(1.0 / 6.0) }
repeat(times) { sum2 += res("x" to rng.nextDouble()) }
}.also(::println)
assertEquals(sum1, sum2)
}
} }

View File

@ -49,7 +49,7 @@ public interface NumericAlgebra<T> : Algebra<T> {
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number]. * Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
*/ */
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
leftSideNumberOperation(operation, right, left) binaryOperation(operation, left, number(right))
} }
/** /**