Very experimental WASM code generation by MST in contexts of Int and Real #158
@ -1,16 +1,21 @@
|
|||||||
package kscience.kmath.ast
|
package kscience.kmath.ast
|
||||||
|
|
||||||
|
import WebAssembly.Instance
|
||||||
|
import binaryen.*
|
||||||
import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
import kscience.kmath.expressions.StringSymbol
|
import kscience.kmath.expressions.StringSymbol
|
||||||
import kscience.kmath.operations.*
|
import kscience.kmath.operations.*
|
||||||
|
import WebAssembly.Module as WasmModule
|
||||||
|
import binaryen.Module as BinaryenModule
|
||||||
|
|
||||||
private val spreader = eval("(obj, args) => obj(...args)")
|
private val spreader = eval("(obj, args) => obj(...args)")
|
||||||
|
|
||||||
internal sealed class WasmBuilder<T>(val binaryenType: binaryen.Type, val kmathAlgebra: Algebra<T>) where T : Number {
|
@Suppress("UnsafeCastFromDynamic")
|
||||||
|
internal sealed class WasmBuilder<T>(val binaryenType: Type, val kmathAlgebra: Algebra<T>) where T : Number {
|
||||||
val keys: MutableList<String> = mutableListOf()
|
val keys: MutableList<String> = mutableListOf()
|
||||||
lateinit var ctx: binaryen.Module
|
lateinit var ctx: BinaryenModule
|
||||||
|
|
||||||
open fun visitSymbolic(mst: MST.Symbolic): binaryen.ExpressionRef {
|
open fun visitSymbolic(mst: MST.Symbolic): ExpressionRef {
|
||||||
try {
|
try {
|
||||||
kmathAlgebra.symbol(mst.value)
|
kmathAlgebra.symbol(mst.value)
|
||||||
} catch (ignored: Throwable) {
|
} catch (ignored: Throwable) {
|
||||||
@ -27,17 +32,17 @@ internal sealed class WasmBuilder<T>(val binaryenType: binaryen.Type, val kmathA
|
|||||||
return ctx.local.get(idx, binaryenType)
|
return ctx.local.get(idx, binaryenType)
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract fun visitNumeric(mst: MST.Numeric): binaryen.ExpressionRef
|
abstract fun visitNumeric(mst: MST.Numeric): ExpressionRef
|
||||||
|
|
||||||
open fun visitUnary(mst: MST.Unary): binaryen.ExpressionRef =
|
open fun visitUnary(mst: MST.Unary): ExpressionRef =
|
||||||
error("Unary operation ${mst.operation} not defined in $this")
|
error("Unary operation ${mst.operation} not defined in $this")
|
||||||
|
|
||||||
open fun visitBinary(mst: MST.Binary): binaryen.ExpressionRef =
|
open fun visitBinary(mst: MST.Binary): ExpressionRef =
|
||||||
error("Binary operation ${mst.operation} not defined in $this")
|
error("Binary operation ${mst.operation} not defined in $this")
|
||||||
|
|
||||||
open fun createModule(): binaryen.Module = binaryen.Module()
|
open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
|
||||||
|
|
||||||
fun visit(mst: MST): binaryen.ExpressionRef = when (mst) {
|
fun visit(mst: MST): ExpressionRef = when (mst) {
|
||||||
is MST.Symbolic -> visitSymbolic(mst)
|
is MST.Symbolic -> visitSymbolic(mst)
|
||||||
is MST.Numeric -> visitNumeric(mst)
|
is MST.Numeric -> visitNumeric(mst)
|
||||||
is MST.Unary -> visitUnary(mst)
|
is MST.Unary -> visitUnary(mst)
|
||||||
@ -45,21 +50,19 @@ internal sealed class WasmBuilder<T>(val binaryenType: binaryen.Type, val kmathA
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun compile(mst: MST): Expression<T> {
|
fun compile(mst: MST): Expression<T> {
|
||||||
val keys = mutableListOf<String>()
|
|
||||||
|
|
||||||
val bin = with(createModule()) {
|
val bin = with(createModule()) {
|
||||||
ctx = this
|
ctx = this
|
||||||
val expr = visit(mst)
|
val expr = visit(mst)
|
||||||
|
|
||||||
addFunction(
|
addFunction(
|
||||||
"executable",
|
"executable",
|
||||||
binaryen.createType(Array(keys.size) { binaryenType }),
|
createType(Array(keys.size) { binaryenType }),
|
||||||
binaryenType,
|
binaryenType,
|
||||||
arrayOf(),
|
arrayOf(),
|
||||||
expr
|
expr
|
||||||
)
|
)
|
||||||
|
|
||||||
binaryen.setOptimizeLevel(3)
|
setOptimizeLevel(3)
|
||||||
// optimizeFunction("executable")
|
// optimizeFunction("executable")
|
||||||
addFunctionExport("executable", "executable")
|
addFunctionExport("executable", "executable")
|
||||||
val res = emitBinary()
|
val res = emitBinary()
|
||||||
@ -67,8 +70,8 @@ internal sealed class WasmBuilder<T>(val binaryenType: binaryen.Type, val kmathA
|
|||||||
res
|
res
|
||||||
}
|
}
|
||||||
|
|
||||||
val c = WebAssembly.Module(bin)
|
val c = WasmModule(bin)
|
||||||
val i = WebAssembly.Instance(c, js("{}") as Any)
|
val i = Instance(c, js("{}") as Any)
|
||||||
|
|
||||||
return Expression { args ->
|
return Expression { args ->
|
||||||
val params = keys.map { StringSymbol(it) }.map { args.getValue(it) }.toTypedArray()
|
val params = keys.map { StringSymbol(it) }.map { args.getValue(it) }.toTypedArray()
|
||||||
@ -78,74 +81,55 @@ internal sealed class WasmBuilder<T>(val binaryenType: binaryen.Type, val kmathA
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class RealWasmBuilder : WasmBuilder<Double>(binaryen.f64, RealField) {
|
internal class RealWasmBuilder : WasmBuilder<Double>(f64, RealField) {
|
||||||
override fun createModule(): binaryen.Module = binaryen.readBinary(f64StandardFunctions)
|
override fun createModule(): BinaryenModule = readBinary(f64StandardFunctions)
|
||||||
|
|
||||||
override fun visitNumeric(mst: MST.Numeric): binaryen.ExpressionRef = ctx.f64.const(mst.value)
|
override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.f64.const(mst.value)
|
||||||
|
|
||||||
override fun visitUnary(mst: MST.Unary): binaryen.ExpressionRef = when (mst.operation) {
|
override fun visitUnary(mst: MST.Unary): ExpressionRef = when (mst.operation) {
|
||||||
SpaceOperations.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
|
SpaceOperations.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
|
||||||
SpaceOperations.PLUS_OPERATION -> visit(mst.value)
|
SpaceOperations.PLUS_OPERATION -> visit(mst.value)
|
||||||
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value))
|
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value))
|
||||||
TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), binaryen.f64)
|
TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64)
|
||||||
TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), binaryen.f64)
|
TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64)
|
||||||
TrigonometricOperations.TAN_OPERATION -> ctx.call("tan", arrayOf(visit(mst.value)), binaryen.f64)
|
TrigonometricOperations.TAN_OPERATION -> ctx.call("tan", arrayOf(visit(mst.value)), f64)
|
||||||
TrigonometricOperations.ASIN_OPERATION -> ctx.call("asin", arrayOf(visit(mst.value)), binaryen.f64)
|
TrigonometricOperations.ASIN_OPERATION -> ctx.call("asin", arrayOf(visit(mst.value)), f64)
|
||||||
TrigonometricOperations.ACOS_OPERATION -> ctx.call("acos", arrayOf(visit(mst.value)), binaryen.f64)
|
TrigonometricOperations.ACOS_OPERATION -> ctx.call("acos", arrayOf(visit(mst.value)), f64)
|
||||||
TrigonometricOperations.ATAN_OPERATION -> ctx.call("atan", arrayOf(visit(mst.value)), binaryen.f64)
|
TrigonometricOperations.ATAN_OPERATION -> ctx.call("atan", arrayOf(visit(mst.value)), f64)
|
||||||
HyperbolicOperations.SINH_OPERATION -> ctx.call("sinh", arrayOf(visit(mst.value)), binaryen.f64)
|
HyperbolicOperations.SINH_OPERATION -> ctx.call("sinh", arrayOf(visit(mst.value)), f64)
|
||||||
HyperbolicOperations.COSH_OPERATION -> ctx.call("cosh", arrayOf(visit(mst.value)), binaryen.f64)
|
HyperbolicOperations.COSH_OPERATION -> ctx.call("cosh", arrayOf(visit(mst.value)), f64)
|
||||||
HyperbolicOperations.TANH_OPERATION -> ctx.call("tanh", arrayOf(visit(mst.value)), binaryen.f64)
|
HyperbolicOperations.TANH_OPERATION -> ctx.call("tanh", arrayOf(visit(mst.value)), f64)
|
||||||
HyperbolicOperations.ASINH_OPERATION -> ctx.call("asinh", arrayOf(visit(mst.value)), binaryen.f64)
|
HyperbolicOperations.ASINH_OPERATION -> ctx.call("asinh", arrayOf(visit(mst.value)), f64)
|
||||||
HyperbolicOperations.ACOSH_OPERATION -> ctx.call("acosh", arrayOf(visit(mst.value)), binaryen.f64)
|
HyperbolicOperations.ACOSH_OPERATION -> ctx.call("acosh", arrayOf(visit(mst.value)), f64)
|
||||||
HyperbolicOperations.ATANH_OPERATION -> ctx.call("atanh", arrayOf(visit(mst.value)), binaryen.f64)
|
HyperbolicOperations.ATANH_OPERATION -> ctx.call("atanh", arrayOf(visit(mst.value)), f64)
|
||||||
ExponentialOperations.EXP_OPERATION -> ctx.call("exp", arrayOf(visit(mst.value)), binaryen.f64)
|
ExponentialOperations.EXP_OPERATION -> ctx.call("exp", arrayOf(visit(mst.value)), f64)
|
||||||
ExponentialOperations.LN_OPERATION -> ctx.call("log", arrayOf(visit(mst.value)), binaryen.f64)
|
ExponentialOperations.LN_OPERATION -> ctx.call("log", arrayOf(visit(mst.value)), f64)
|
||||||
else -> super.visitUnary(mst)
|
else -> super.visitUnary(mst)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: MST.Binary): binaryen.ExpressionRef = when (mst.operation) {
|
override fun visitBinary(mst: MST.Binary): ExpressionRef = when (mst.operation) {
|
||||||
SpaceOperations.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
|
SpaceOperations.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
|
||||||
SpaceOperations.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
|
SpaceOperations.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
|
||||||
RingOperations.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
|
RingOperations.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
|
||||||
FieldOperations.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right))
|
FieldOperations.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right))
|
||||||
PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), binaryen.f64)
|
PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), f64)
|
||||||
else -> super.visitBinary(mst)
|
else -> super.visitBinary(mst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class IntWasmBuilder : WasmBuilder<Int>(binaryen.i32, IntRing) {
|
internal class IntWasmBuilder : WasmBuilder<Int>(i32, IntRing) {
|
||||||
override fun visitNumeric(mst: MST.Numeric): binaryen.ExpressionRef = ctx.i32.const(mst.value)
|
override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.i32.const(mst.value)
|
||||||
|
|
||||||
override fun visitUnary(mst: MST.Unary): binaryen.ExpressionRef = when (mst.operation) {
|
override fun visitUnary(mst: MST.Unary): ExpressionRef = when (mst.operation) {
|
||||||
SpaceOperations.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value))
|
SpaceOperations.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value))
|
||||||
SpaceOperations.PLUS_OPERATION -> visit(mst.value)
|
SpaceOperations.PLUS_OPERATION -> visit(mst.value)
|
||||||
else -> super.visitUnary(mst)
|
else -> super.visitUnary(mst)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: MST.Binary): binaryen.ExpressionRef = when (mst.operation) {
|
override fun visitBinary(mst: MST.Binary): ExpressionRef = when (mst.operation) {
|
||||||
SpaceOperations.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
|
SpaceOperations.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
|
||||||
SpaceOperations.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
|
SpaceOperations.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
|
||||||
RingOperations.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
|
RingOperations.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
|
||||||
FieldOperations.DIV_OPERATION -> ctx.i32.div_s(visit(mst.left), visit(mst.right))
|
|
||||||
else -> super.visitBinary(mst)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
internal class LongWasmBuilder : WasmBuilder<Long>(binaryen.i64, LongRing) {
|
|
||||||
override fun visitNumeric(mst: MST.Numeric): binaryen.ExpressionRef = ctx.i64.const(mst.value)
|
|
||||||
|
|
||||||
override fun visitUnary(mst: MST.Unary): binaryen.ExpressionRef = when (mst.operation) {
|
|
||||||
SpaceOperations.MINUS_OPERATION -> ctx.i64.sub(ctx.i64.const(0, 0), visit(mst.value))
|
|
||||||
SpaceOperations.PLUS_OPERATION -> visit(mst.value)
|
|
||||||
else -> super.visitUnary(mst)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun visitBinary(mst: MST.Binary): binaryen.ExpressionRef = when (mst.operation) {
|
|
||||||
SpaceOperations.PLUS_OPERATION -> ctx.i64.add(visit(mst.left), visit(mst.right))
|
|
||||||
SpaceOperations.MINUS_OPERATION -> ctx.i64.sub(visit(mst.left), visit(mst.right))
|
|
||||||
RingOperations.TIMES_OPERATION -> ctx.i64.mul(visit(mst.left), visit(mst.right))
|
|
||||||
FieldOperations.DIV_OPERATION -> ctx.i64.div_s(visit(mst.left), visit(mst.right))
|
|
||||||
else -> super.visitBinary(mst)
|
else -> super.visitBinary(mst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,14 +3,26 @@ package kscience.kmath.ast
|
|||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.invoke
|
||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.time.measureTime
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class Test {
|
internal class Test {
|
||||||
@Test
|
@Test
|
||||||
fun c() {
|
fun int() {
|
||||||
measureTime {
|
val res = IntWasmBuilder().compile(MstRing { number(100000000) + number(10000000) })()
|
||||||
val expr = compileMstToWasmF64(MstExtendedField { sin(symbol("x")) + cos(symbol("x")).pow(2) })
|
assertEquals(110000000, res)
|
||||||
println(expr("x" to 3.0))
|
}
|
||||||
}.also { println(it) }
|
|
||||||
|
@Test
|
||||||
|
fun real() {
|
||||||
|
val res = RealWasmBuilder().compile(MstExtendedField { number(100000000) + number(2).pow(10) })()
|
||||||
|
assertEquals(100001024.0, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun argsPassing() {
|
||||||
|
val res = RealWasmBuilder()
|
||||||
|
.compile(MstExtendedField { symbol("y") + symbol("x").pow(10) })("x" to 2.0, "y" to 100000000.0)
|
||||||
|
|
||||||
|
assertEquals(100001024.0, res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user