diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/ast/RealWasmBuilder.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/ast/RealWasmBuilder.kt index 59871ce1d..9f01f53fa 100644 --- a/kmath-ast/src/jsMain/kotlin/kscience/kmath/ast/RealWasmBuilder.kt +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/ast/RealWasmBuilder.kt @@ -1,16 +1,21 @@ package kscience.kmath.ast +import WebAssembly.Instance +import binaryen.* import kscience.kmath.expressions.Expression import kscience.kmath.expressions.StringSymbol import kscience.kmath.operations.* +import WebAssembly.Module as WasmModule +import binaryen.Module as BinaryenModule private val spreader = eval("(obj, args) => obj(...args)") -internal sealed class WasmBuilder(val binaryenType: binaryen.Type, val kmathAlgebra: Algebra) where T : Number { +@Suppress("UnsafeCastFromDynamic") +internal sealed class WasmBuilder(val binaryenType: Type, val kmathAlgebra: Algebra) where T : Number { val keys: MutableList = mutableListOf() - lateinit var ctx: binaryen.Module + lateinit var ctx: BinaryenModule - open fun visitSymbolic(mst: MST.Symbolic): binaryen.ExpressionRef { + open fun visitSymbolic(mst: MST.Symbolic): ExpressionRef { try { kmathAlgebra.symbol(mst.value) } catch (ignored: Throwable) { @@ -27,17 +32,17 @@ internal sealed class WasmBuilder(val binaryenType: binaryen.Type, val kmathA return ctx.local.get(idx, binaryenType) } - abstract fun visitNumeric(mst: MST.Numeric): binaryen.ExpressionRef + abstract fun visitNumeric(mst: MST.Numeric): ExpressionRef - open fun visitUnary(mst: MST.Unary): binaryen.ExpressionRef = + open fun visitUnary(mst: MST.Unary): ExpressionRef = error("Unary operation ${mst.operation} not defined in $this") - open fun visitBinary(mst: MST.Binary): binaryen.ExpressionRef = + open fun visitBinary(mst: MST.Binary): ExpressionRef = error("Binary operation ${mst.operation} not defined in $this") - open fun createModule(): binaryen.Module = binaryen.Module() + open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()") - fun visit(mst: MST): binaryen.ExpressionRef = when (mst) { + fun visit(mst: MST): ExpressionRef = when (mst) { is MST.Symbolic -> visitSymbolic(mst) is MST.Numeric -> visitNumeric(mst) is MST.Unary -> visitUnary(mst) @@ -45,21 +50,19 @@ internal sealed class WasmBuilder(val binaryenType: binaryen.Type, val kmathA } fun compile(mst: MST): Expression { - val keys = mutableListOf() - val bin = with(createModule()) { ctx = this val expr = visit(mst) addFunction( "executable", - binaryen.createType(Array(keys.size) { binaryenType }), + createType(Array(keys.size) { binaryenType }), binaryenType, arrayOf(), expr ) - binaryen.setOptimizeLevel(3) + setOptimizeLevel(3) // optimizeFunction("executable") addFunctionExport("executable", "executable") val res = emitBinary() @@ -67,8 +70,8 @@ internal sealed class WasmBuilder(val binaryenType: binaryen.Type, val kmathA res } - val c = WebAssembly.Module(bin) - val i = WebAssembly.Instance(c, js("{}") as Any) + val c = WasmModule(bin) + val i = Instance(c, js("{}") as Any) return Expression { args -> val params = keys.map { StringSymbol(it) }.map { args.getValue(it) }.toTypedArray() @@ -78,74 +81,55 @@ internal sealed class WasmBuilder(val binaryenType: binaryen.Type, val kmathA } } -internal class RealWasmBuilder : WasmBuilder(binaryen.f64, RealField) { - override fun createModule(): binaryen.Module = binaryen.readBinary(f64StandardFunctions) +internal class RealWasmBuilder : WasmBuilder(f64, RealField) { + override fun createModule(): BinaryenModule = readBinary(f64StandardFunctions) - override fun visitNumeric(mst: MST.Numeric): binaryen.ExpressionRef = ctx.f64.const(mst.value) + override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.f64.const(mst.value) - override fun visitUnary(mst: MST.Unary): binaryen.ExpressionRef = when (mst.operation) { + override fun visitUnary(mst: MST.Unary): ExpressionRef = when (mst.operation) { SpaceOperations.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value)) SpaceOperations.PLUS_OPERATION -> visit(mst.value) PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value)) - TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), binaryen.f64) - TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), binaryen.f64) - TrigonometricOperations.TAN_OPERATION -> ctx.call("tan", arrayOf(visit(mst.value)), binaryen.f64) - TrigonometricOperations.ASIN_OPERATION -> ctx.call("asin", arrayOf(visit(mst.value)), binaryen.f64) - TrigonometricOperations.ACOS_OPERATION -> ctx.call("acos", arrayOf(visit(mst.value)), binaryen.f64) - TrigonometricOperations.ATAN_OPERATION -> ctx.call("atan", arrayOf(visit(mst.value)), binaryen.f64) - HyperbolicOperations.SINH_OPERATION -> ctx.call("sinh", arrayOf(visit(mst.value)), binaryen.f64) - HyperbolicOperations.COSH_OPERATION -> ctx.call("cosh", arrayOf(visit(mst.value)), binaryen.f64) - HyperbolicOperations.TANH_OPERATION -> ctx.call("tanh", arrayOf(visit(mst.value)), binaryen.f64) - HyperbolicOperations.ASINH_OPERATION -> ctx.call("asinh", arrayOf(visit(mst.value)), binaryen.f64) - HyperbolicOperations.ACOSH_OPERATION -> ctx.call("acosh", arrayOf(visit(mst.value)), binaryen.f64) - HyperbolicOperations.ATANH_OPERATION -> ctx.call("atanh", arrayOf(visit(mst.value)), binaryen.f64) - ExponentialOperations.EXP_OPERATION -> ctx.call("exp", arrayOf(visit(mst.value)), binaryen.f64) - ExponentialOperations.LN_OPERATION -> ctx.call("log", arrayOf(visit(mst.value)), binaryen.f64) + TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64) + TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64) + TrigonometricOperations.TAN_OPERATION -> ctx.call("tan", arrayOf(visit(mst.value)), f64) + TrigonometricOperations.ASIN_OPERATION -> ctx.call("asin", arrayOf(visit(mst.value)), f64) + TrigonometricOperations.ACOS_OPERATION -> ctx.call("acos", arrayOf(visit(mst.value)), f64) + TrigonometricOperations.ATAN_OPERATION -> ctx.call("atan", arrayOf(visit(mst.value)), f64) + HyperbolicOperations.SINH_OPERATION -> ctx.call("sinh", arrayOf(visit(mst.value)), f64) + HyperbolicOperations.COSH_OPERATION -> ctx.call("cosh", arrayOf(visit(mst.value)), f64) + HyperbolicOperations.TANH_OPERATION -> ctx.call("tanh", arrayOf(visit(mst.value)), f64) + HyperbolicOperations.ASINH_OPERATION -> ctx.call("asinh", arrayOf(visit(mst.value)), f64) + HyperbolicOperations.ACOSH_OPERATION -> ctx.call("acosh", arrayOf(visit(mst.value)), f64) + HyperbolicOperations.ATANH_OPERATION -> ctx.call("atanh", arrayOf(visit(mst.value)), f64) + ExponentialOperations.EXP_OPERATION -> ctx.call("exp", arrayOf(visit(mst.value)), f64) + ExponentialOperations.LN_OPERATION -> ctx.call("log", arrayOf(visit(mst.value)), f64) else -> super.visitUnary(mst) } - override fun visitBinary(mst: MST.Binary): binaryen.ExpressionRef = when (mst.operation) { + override fun visitBinary(mst: MST.Binary): ExpressionRef = when (mst.operation) { SpaceOperations.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right)) SpaceOperations.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right)) RingOperations.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right)) FieldOperations.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right)) - PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), binaryen.f64) + PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), f64) else -> super.visitBinary(mst) } } -internal class IntWasmBuilder : WasmBuilder(binaryen.i32, IntRing) { - override fun visitNumeric(mst: MST.Numeric): binaryen.ExpressionRef = ctx.i32.const(mst.value) +internal class IntWasmBuilder : WasmBuilder(i32, IntRing) { + override fun visitNumeric(mst: MST.Numeric): ExpressionRef = ctx.i32.const(mst.value) - override fun visitUnary(mst: MST.Unary): binaryen.ExpressionRef = when (mst.operation) { + override fun visitUnary(mst: MST.Unary): ExpressionRef = when (mst.operation) { SpaceOperations.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value)) SpaceOperations.PLUS_OPERATION -> visit(mst.value) else -> super.visitUnary(mst) } - override fun visitBinary(mst: MST.Binary): binaryen.ExpressionRef = when (mst.operation) { + override fun visitBinary(mst: MST.Binary): ExpressionRef = when (mst.operation) { SpaceOperations.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right)) SpaceOperations.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right)) RingOperations.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right)) - FieldOperations.DIV_OPERATION -> ctx.i32.div_s(visit(mst.left), visit(mst.right)) - else -> super.visitBinary(mst) - } -} - -internal class LongWasmBuilder : WasmBuilder(binaryen.i64, LongRing) { - override fun visitNumeric(mst: MST.Numeric): binaryen.ExpressionRef = ctx.i64.const(mst.value) - - override fun visitUnary(mst: MST.Unary): binaryen.ExpressionRef = when (mst.operation) { - SpaceOperations.MINUS_OPERATION -> ctx.i64.sub(ctx.i64.const(0, 0), visit(mst.value)) - SpaceOperations.PLUS_OPERATION -> visit(mst.value) - else -> super.visitUnary(mst) - } - - override fun visitBinary(mst: MST.Binary): binaryen.ExpressionRef = when (mst.operation) { - SpaceOperations.PLUS_OPERATION -> ctx.i64.add(visit(mst.left), visit(mst.right)) - SpaceOperations.MINUS_OPERATION -> ctx.i64.sub(visit(mst.left), visit(mst.right)) - RingOperations.TIMES_OPERATION -> ctx.i64.mul(visit(mst.left), visit(mst.right)) - FieldOperations.DIV_OPERATION -> ctx.i64.div_s(visit(mst.left), visit(mst.right)) else -> super.visitBinary(mst) } } diff --git a/kmath-ast/src/jsTest/kotlin/Test.kt b/kmath-ast/src/jsTest/kotlin/Test.kt index 357b1ac93..ee5751211 100644 --- a/kmath-ast/src/jsTest/kotlin/Test.kt +++ b/kmath-ast/src/jsTest/kotlin/Test.kt @@ -3,14 +3,26 @@ package kscience.kmath.ast import kscience.kmath.expressions.invoke import kscience.kmath.operations.invoke import kotlin.test.Test -import kotlin.time.measureTime +import kotlin.test.assertEquals internal class Test { @Test - fun c() { - measureTime { - val expr = compileMstToWasmF64(MstExtendedField { sin(symbol("x")) + cos(symbol("x")).pow(2) }) - println(expr("x" to 3.0)) - }.also { println(it) } + fun int() { + val res = IntWasmBuilder().compile(MstRing { number(100000000) + number(10000000) })() + assertEquals(110000000, res) + } + + @Test + fun real() { + val res = RealWasmBuilder().compile(MstExtendedField { number(100000000) + number(2).pow(10) })() + assertEquals(100001024.0, res) + } + + @Test + fun argsPassing() { + val res = RealWasmBuilder() + .compile(MstExtendedField { symbol("y") + symbol("x").pow(10) })("x" to 2.0, "y" to 100000000.0) + + assertEquals(100001024.0, res) } }