Add support for Long and Int codegen

This commit is contained in:
Iaroslav Postovalov 2020-11-11 20:46:18 +07:00
parent 4fce91ae59
commit 4f22688f6f
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
2 changed files with 134 additions and 65 deletions

File diff suppressed because one or more lines are too long

View File

@ -6,60 +6,55 @@ import kscience.kmath.operations.*
private val spreader = eval("(obj, args) => obj(...args)") private val spreader = eval("(obj, args) => obj(...args)")
public fun compileMstToWasmF64(mst: MST): Expression<Double> { internal sealed class WasmBuilder<T>(val binaryenType: binaryen.Type, val kmathAlgebra: Algebra<T>) where T : Number {
val keys = mutableListOf<String>() val keys: MutableList<String> = mutableListOf()
lateinit var ctx: binaryen.Module
val bin = with(binaryen.readBinary(INITIAL)) { open fun visitSymbolic(mst: MST.Symbolic): binaryen.ExpressionRef {
fun MST.visit(): binaryen.ExpressionRef = when (this) { try {
is MST.Symbolic -> { kmathAlgebra.symbol(mst.value)
var idx = keys.indexOf(value) } catch (ignored: Throwable) {
null
}?.let { return visitNumeric(MST.Numeric(it)) }
var idx = keys.indexOf(mst.value)
if (idx == -1) { if (idx == -1) {
keys += value keys += mst.value
idx = keys.lastIndex idx = keys.lastIndex
} }
local.get(idx, binaryen.f64) return ctx.local.get(idx, binaryenType)
} }
is MST.Numeric -> f64.const(value) abstract fun visitNumeric(mst: MST.Numeric): binaryen.ExpressionRef
is MST.Unary -> when (operation) { open fun visitUnary(mst: MST.Unary): binaryen.ExpressionRef =
SpaceOperations.MINUS_OPERATION -> f64.neg(value.visit()) error("Unary operation ${mst.operation} not defined in $this")
SpaceOperations.PLUS_OPERATION -> value.visit()
PowerOperations.SQRT_OPERATION -> f64.sqrt(value.visit()) open fun visitBinary(mst: MST.Binary): binaryen.ExpressionRef =
TrigonometricOperations.SIN_OPERATION -> call("sin", arrayOf(value.visit()), binaryen.f64) error("Binary operation ${mst.operation} not defined in $this")
TrigonometricOperations.COS_OPERATION -> call("cos", arrayOf(value.visit()), binaryen.f64)
TrigonometricOperations.TAN_OPERATION -> call("tan", arrayOf(value.visit()), binaryen.f64) open fun createModule(): binaryen.Module = binaryen.Module()
TrigonometricOperations.ASIN_OPERATION -> call("asin", arrayOf(value.visit()), binaryen.f64)
TrigonometricOperations.ACOS_OPERATION -> call("acos", arrayOf(value.visit()), binaryen.f64) fun visit(mst: MST): binaryen.ExpressionRef = when (mst) {
TrigonometricOperations.ATAN_OPERATION -> call("atan", arrayOf(value.visit()), binaryen.f64) is MST.Symbolic -> visitSymbolic(mst)
HyperbolicOperations.SINH_OPERATION -> call("sinh", arrayOf(value.visit()), binaryen.f64) is MST.Numeric -> visitNumeric(mst)
HyperbolicOperations.COSH_OPERATION -> call("cosh", arrayOf(value.visit()), binaryen.f64) is MST.Unary -> visitUnary(mst)
HyperbolicOperations.TANH_OPERATION -> call("tanh", arrayOf(value.visit()), binaryen.f64) is MST.Binary -> visitBinary(mst)
HyperbolicOperations.ASINH_OPERATION -> call("asinh", arrayOf(value.visit()), binaryen.f64)
HyperbolicOperations.ACOSH_OPERATION -> call("acosh", arrayOf(value.visit()), binaryen.f64)
HyperbolicOperations.ATANH_OPERATION -> call("atanh", arrayOf(value.visit()), binaryen.f64)
ExponentialOperations.EXP_OPERATION -> call("exp", arrayOf(value.visit()), binaryen.f64)
ExponentialOperations.LN_OPERATION -> call("log", arrayOf(value.visit()), binaryen.f64)
else -> throw UnsupportedOperationException()
} }
is MST.Binary -> when (operation) { fun compile(mst: MST): Expression<T> {
SpaceOperations.PLUS_OPERATION -> f64.add(left.visit(), right.visit()) val keys = mutableListOf<String>()
RingOperations.TIMES_OPERATION -> f64.mul(left.visit(), right.visit())
FieldOperations.DIV_OPERATION -> f64.div(left.visit(), right.visit())
PowerOperations.POW_OPERATION -> call("pow", arrayOf(left.visit(), right.visit()), binaryen.f64)
else -> throw UnsupportedOperationException()
}
}
val expr = mst.visit() val bin = with(createModule()) {
ctx = this
val expr = visit(mst)
addFunction( addFunction(
"executable", "executable",
binaryen.createType(Array(keys.size) { binaryen.f64 }), binaryen.createType(Array(keys.size) { binaryenType }),
binaryen.f64, binaryenType,
arrayOf(), arrayOf(),
expr expr
) )
@ -77,6 +72,80 @@ public fun compileMstToWasmF64(mst: MST): Expression<Double> {
return Expression { args -> return Expression { args ->
val params = keys.map { StringSymbol(it) }.map { args.getValue(it) }.toTypedArray() val params = keys.map { StringSymbol(it) }.map { args.getValue(it) }.toTypedArray()
spreader(i.exports.asDynamic().executable, params) as Double keys.clear()
spreader(i.exports.asDynamic().executable, params) as T
}
}
}
internal class RealWasmBuilder : WasmBuilder<Double>(binaryen.f64, RealField) {
override fun createModule(): binaryen.Module = binaryen.readBinary(f64StandardFunctions)
override fun visitNumeric(mst: MST.Numeric): binaryen.ExpressionRef = ctx.f64.const(mst.value)
override fun visitUnary(mst: MST.Unary): binaryen.ExpressionRef = when (mst.operation) {
SpaceOperations.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
SpaceOperations.PLUS_OPERATION -> visit(mst.value)
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value))
TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), binaryen.f64)
TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), binaryen.f64)
TrigonometricOperations.TAN_OPERATION -> ctx.call("tan", arrayOf(visit(mst.value)), binaryen.f64)
TrigonometricOperations.ASIN_OPERATION -> ctx.call("asin", arrayOf(visit(mst.value)), binaryen.f64)
TrigonometricOperations.ACOS_OPERATION -> ctx.call("acos", arrayOf(visit(mst.value)), binaryen.f64)
TrigonometricOperations.ATAN_OPERATION -> ctx.call("atan", arrayOf(visit(mst.value)), binaryen.f64)
HyperbolicOperations.SINH_OPERATION -> ctx.call("sinh", arrayOf(visit(mst.value)), binaryen.f64)
HyperbolicOperations.COSH_OPERATION -> ctx.call("cosh", arrayOf(visit(mst.value)), binaryen.f64)
HyperbolicOperations.TANH_OPERATION -> ctx.call("tanh", arrayOf(visit(mst.value)), binaryen.f64)
HyperbolicOperations.ASINH_OPERATION -> ctx.call("asinh", arrayOf(visit(mst.value)), binaryen.f64)
HyperbolicOperations.ACOSH_OPERATION -> ctx.call("acosh", arrayOf(visit(mst.value)), binaryen.f64)
HyperbolicOperations.ATANH_OPERATION -> ctx.call("atanh", arrayOf(visit(mst.value)), binaryen.f64)
ExponentialOperations.EXP_OPERATION -> ctx.call("exp", arrayOf(visit(mst.value)), binaryen.f64)
ExponentialOperations.LN_OPERATION -> ctx.call("log", arrayOf(visit(mst.value)), binaryen.f64)
else -> super.visitUnary(mst)
}
override fun visitBinary(mst: MST.Binary): binaryen.ExpressionRef = when (mst.operation) {
SpaceOperations.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
SpaceOperations.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
RingOperations.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
FieldOperations.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right))
PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), binaryen.f64)
else -> super.visitBinary(mst)
}
}
internal class IntWasmBuilder : WasmBuilder<Int>(binaryen.i32, IntRing) {
override fun visitNumeric(mst: MST.Numeric): binaryen.ExpressionRef = ctx.i32.const(mst.value)
override fun visitUnary(mst: MST.Unary): binaryen.ExpressionRef = when (mst.operation) {
SpaceOperations.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value))
SpaceOperations.PLUS_OPERATION -> visit(mst.value)
else -> super.visitUnary(mst)
}
override fun visitBinary(mst: MST.Binary): binaryen.ExpressionRef = when (mst.operation) {
SpaceOperations.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
SpaceOperations.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
RingOperations.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
FieldOperations.DIV_OPERATION -> ctx.i32.div_s(visit(mst.left), visit(mst.right))
else -> super.visitBinary(mst)
}
}
internal class LongWasmBuilder : WasmBuilder<Long>(binaryen.i64, LongRing) {
override fun visitNumeric(mst: MST.Numeric): binaryen.ExpressionRef = ctx.i64.const(mst.value)
override fun visitUnary(mst: MST.Unary): binaryen.ExpressionRef = when (mst.operation) {
SpaceOperations.MINUS_OPERATION -> ctx.i64.sub(ctx.i64.const(0, 0), visit(mst.value))
SpaceOperations.PLUS_OPERATION -> visit(mst.value)
else -> super.visitUnary(mst)
}
override fun visitBinary(mst: MST.Binary): binaryen.ExpressionRef = when (mst.operation) {
SpaceOperations.PLUS_OPERATION -> ctx.i64.add(visit(mst.left), visit(mst.right))
SpaceOperations.MINUS_OPERATION -> ctx.i64.sub(visit(mst.left), visit(mst.right))
RingOperations.TIMES_OPERATION -> ctx.i64.mul(visit(mst.left), visit(mst.right))
FieldOperations.DIV_OPERATION -> ctx.i64.div_s(visit(mst.left), visit(mst.right))
else -> super.visitBinary(mst)
} }
} }