Very experimental WASM code generation by MST in contexts of Int and Real #158
@ -5,17 +5,13 @@ plugins {
|
|||||||
kotlin.js {
|
kotlin.js {
|
||||||
nodejs { // or `browser`
|
nodejs { // or `browser`
|
||||||
testTask {
|
testTask {
|
||||||
useMocha {
|
useMocha().timeout = "0"
|
||||||
timeout = "0"// mochaTimeout here as string
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
browser { // or `browser`
|
browser {
|
||||||
testTask {
|
testTask {
|
||||||
useMocha {
|
useMocha().timeout = "0"
|
||||||
timeout = "0"// mochaTimeout here as string
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -37,7 +33,7 @@ kotlin.sourceSets {
|
|||||||
|
|
||||||
jsMain {
|
jsMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation(npm("binaryen", "98.0.0"))
|
implementation(npm("binaryen", "98.0.0-nightly.20201113"))
|
||||||
implementation(npm("js-base64", "3.6.0"))
|
implementation(npm("js-base64", "3.6.0"))
|
||||||
implementation(npm("webassembly", "0.11.0"))
|
implementation(npm("webassembly", "0.11.0"))
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,11 @@ import binaryen.Module as BinaryenModule
|
|||||||
private val spreader = eval("(obj, args) => obj(...args)")
|
private val spreader = eval("(obj, args) => obj(...args)")
|
||||||
|
|
||||||
@Suppress("UnsafeCastFromDynamic")
|
@Suppress("UnsafeCastFromDynamic")
|
||||||
internal sealed class WasmBuilder<T>(val binaryenType: Type, val kmathAlgebra: Algebra<T>) where T : Number {
|
internal sealed class WasmBuilder<T>(
|
||||||
|
val binaryenType: Type,
|
||||||
|
val kmathAlgebra: Algebra<T>,
|
||||||
|
val target: MST
|
||||||
|
) where T : Number {
|
||||||
val keys: MutableList<String> = mutableListOf()
|
val keys: MutableList<String> = mutableListOf()
|
||||||
lateinit var ctx: BinaryenModule
|
lateinit var ctx: BinaryenModule
|
||||||
|
|
||||||
@ -49,10 +53,10 @@ internal sealed class WasmBuilder<T>(val binaryenType: Type, val kmathAlgebra: A
|
|||||||
is MST.Binary -> visitBinary(mst)
|
is MST.Binary -> visitBinary(mst)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun compile(mst: MST): Expression<T> {
|
val instance by lazy {
|
||||||
val bin = with(createModule()) {
|
val c = WasmModule(with(createModule()) {
|
||||||
ctx = this
|
ctx = this
|
||||||
val expr = visit(mst)
|
val expr = visit(target)
|
||||||
|
|
||||||
addFunction(
|
addFunction(
|
||||||
"executable",
|
"executable",
|
||||||
@ -68,20 +72,20 @@ internal sealed class WasmBuilder<T>(val binaryenType: Type, val kmathAlgebra: A
|
|||||||
val res = emitBinary()
|
val res = emitBinary()
|
||||||
dispose()
|
dispose()
|
||||||
res
|
res
|
||||||
}
|
})
|
||||||
|
|
||||||
val c = WasmModule(bin)
|
|
||||||
val i = Instance(c, js("{}") as Any)
|
val i = Instance(c, js("{}") as Any)
|
||||||
|
val symbols = keys.map(::StringSymbol)
|
||||||
return Expression { args ->
|
|
||||||
val params = keys.map(::StringSymbol).map(args::getValue).toTypedArray()
|
|
||||||
keys.clear()
|
keys.clear()
|
||||||
|
|
||||||
|
Expression<T> { args ->
|
||||||
|
val params = symbols.map(args::getValue).toTypedArray()
|
||||||
spreader(i.exports.asDynamic().executable, params) as T
|
spreader(i.exports.asDynamic().executable, params) as T
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class RealWasmBuilder : WasmBuilder<Double>(f64, RealField) {
|
internal class RealWasmBuilder(target: MST) : WasmBuilder<Double>(f64, RealField, target) {
|
||||||
override fun createModule(): BinaryenModule = readBinary(f64StandardFunctions)
|
override fun createModule(): BinaryenModule = readBinary(f64StandardFunctions)
|
||||||
|
|
||||||
override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.f64.const(mst.value)
|
override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.f64.const(mst.value)
|
||||||
@ -117,7 +121,7 @@ internal class RealWasmBuilder : WasmBuilder<Double>(f64, RealField) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class IntWasmBuilder : WasmBuilder<Int>(i32, IntRing) {
|
internal class IntWasmBuilder(target: MST) : WasmBuilder<Int>(i32, IntRing, target) {
|
||||||
override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.i32.const(mst.value)
|
override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.i32.const(mst.value)
|
||||||
|
|
||||||
override fun visitUnary(mst: MST.Unary): ExpressionRef = when (mst.operation) {
|
override fun visitUnary(mst: MST.Unary): ExpressionRef = when (mst.operation) {
|
||||||
|
@ -1,28 +1,61 @@
|
|||||||
package kscience.kmath.ast
|
package kscience.kmath.ast
|
||||||
|
|
||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.invoke
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
|
import kotlin.random.Random
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.time.measureTime
|
||||||
|
|
||||||
internal class Test {
|
internal class Test {
|
||||||
@Test
|
@Test
|
||||||
fun int() {
|
fun int() {
|
||||||
val res = IntWasmBuilder().compile(MstRing { number(100000000) + number(10000000) })()
|
val res = IntWasmBuilder(MstRing { number(100000000) + number(10000000) }).instance()
|
||||||
assertEquals(110000000, res)
|
assertEquals(110000000, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun real() {
|
fun real() {
|
||||||
val res = RealWasmBuilder().compile(MstExtendedField { number(100000000) + number(2).pow(10) })()
|
val res = RealWasmBuilder(MstExtendedField { number(100000000) + number(2).pow(10) }).instance()
|
||||||
assertEquals(100001024.0, res)
|
assertEquals(100001024.0, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun argsPassing() {
|
fun argsPassing() {
|
||||||
val res = RealWasmBuilder()
|
val res = RealWasmBuilder(MstExtendedField { symbol("y") + symbol("x").pow(10) })
|
||||||
.compile(MstExtendedField { symbol("y") + symbol("x").pow(10) })("x" to 2.0, "y" to 100000000.0)
|
.instance("x" to 2.0, "y" to 100000000.0)
|
||||||
|
|
||||||
assertEquals(100001024.0, res)
|
assertEquals(100001024.0, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun powFunction() {
|
||||||
|
val expr = RealWasmBuilder(MstExtendedField { symbol("x").pow(1.0 / 6.0) }).instance
|
||||||
|
assertEquals(0.9730585187140817, expr("x" to 0.8488554755054833))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun manyRuns() {
|
||||||
|
println("Compiled")
|
||||||
|
val times = 1_000_000
|
||||||
|
var rng = Random(0)
|
||||||
|
var sum1 = 0.0
|
||||||
|
var sum2 = 0.0
|
||||||
|
|
||||||
|
measureTime {
|
||||||
|
val res = RealWasmBuilder(MstExtendedField { symbol("x").pow(1.0 / 6.0) }).instance
|
||||||
|
repeat(times) { sum1 += res("x" to rng.nextDouble()) }
|
||||||
|
}.also(::println)
|
||||||
|
|
||||||
|
println("MST")
|
||||||
|
rng = Random(0)
|
||||||
|
|
||||||
|
measureTime {
|
||||||
|
val res = RealField.mstInExtendedField { symbol("x").pow(1.0 / 6.0) }
|
||||||
|
repeat(times) { sum2 += res("x" to rng.nextDouble()) }
|
||||||
|
}.also(::println)
|
||||||
|
|
||||||
|
assertEquals(sum1, sum2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -49,7 +49,7 @@ public interface NumericAlgebra<T> : Algebra<T> {
|
|||||||
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
|
* Dynamic call of binary operation [operation] on [left] and [right] where right element is [Number].
|
||||||
*/
|
*/
|
||||||
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
||||||
leftSideNumberOperation(operation, right, left)
|
binaryOperation(operation, left, number(right))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
Loading…
Reference in New Issue
Block a user