Completely specialized expression types for Int
, Long
, Double
and compilation of MST to it
#444
@ -38,6 +38,22 @@ internal class ExpressionsInterpretersBenchmark {
|
|||||||
@Benchmark
|
@Benchmark
|
||||||
fun asmGenericExpression(blackhole: Blackhole) = invokeAndSum(asmGeneric, blackhole)
|
fun asmGenericExpression(blackhole: Blackhole) = invokeAndSum(asmGeneric, blackhole)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Benchmark case for [Expression] created with [compileToExpression].
|
||||||
|
*/
|
||||||
|
@Benchmark
|
||||||
|
fun asmPrimitiveExpressionArray(blackhole: Blackhole) {
|
||||||
|
val random = Random(0)
|
||||||
|
var sum = 0.0
|
||||||
|
val m = DoubleArray(1)
|
||||||
|
|
||||||
|
repeat(times) {
|
||||||
|
m[xIdx] = random.nextDouble()
|
||||||
|
sum += asmPrimitive(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
blackhole.consume(sum)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Benchmark case for [Expression] created with [compileToExpression].
|
* Benchmark case for [Expression] created with [compileToExpression].
|
||||||
@ -82,7 +98,6 @@ internal class ExpressionsInterpretersBenchmark {
|
|||||||
|
|
||||||
private companion object {
|
private companion object {
|
||||||
private val x by symbol
|
private val x by symbol
|
||||||
private val algebra = DoubleField
|
|
||||||
private const val times = 1_000_000
|
private const val times = 1_000_000
|
||||||
|
|
||||||
private val functional = DoubleField.expression {
|
private val functional = DoubleField.expression {
|
||||||
@ -95,7 +110,10 @@ internal class ExpressionsInterpretersBenchmark {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private val mst = node.toExpression(DoubleField)
|
private val mst = node.toExpression(DoubleField)
|
||||||
|
|
||||||
private val asmPrimitive = node.compileToExpression(DoubleField)
|
private val asmPrimitive = node.compileToExpression(DoubleField)
|
||||||
|
private val xIdx = asmPrimitive.indexer.indexOf(x)
|
||||||
|
|
||||||
private val asmGeneric = node.compileToExpression(DoubleField as Algebra<Double>)
|
private val asmGeneric = node.compileToExpression(DoubleField as Algebra<Double>)
|
||||||
|
|
||||||
private val raw = Expression<Double> { args ->
|
private val raw = Expression<Double> { args ->
|
||||||
|
@ -7,7 +7,6 @@ package space.kscience.kmath.ast
|
|||||||
|
|
||||||
import space.kscience.kmath.asm.compileToExpression
|
import space.kscience.kmath.asm.compileToExpression
|
||||||
import space.kscience.kmath.expressions.MstExtendedField
|
import space.kscience.kmath.expressions.MstExtendedField
|
||||||
import space.kscience.kmath.expressions.Symbol
|
|
||||||
import space.kscience.kmath.expressions.Symbol.Companion.x
|
import space.kscience.kmath.expressions.Symbol.Companion.x
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
@ -17,10 +16,11 @@ fun main() {
|
|||||||
x * 2.0 + number(2.0) / x - number(16.0) + asinh(x) / sin(x)
|
x * 2.0 + number(2.0) / x - number(16.0) + asinh(x) / sin(x)
|
||||||
}.compileToExpression(DoubleField)
|
}.compileToExpression(DoubleField)
|
||||||
|
|
||||||
val m = HashMap<Symbol, Double>()
|
val m = DoubleArray(expr.indexer.symbols.size)
|
||||||
|
val xIdx = expr.indexer.indexOf(x)
|
||||||
|
|
||||||
repeat(10000000) {
|
repeat(10000000) {
|
||||||
m[x] = 1.0
|
m[xIdx] = 1.0
|
||||||
expr(m)
|
expr(m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,9 @@ package space.kscience.kmath.ast
|
|||||||
|
|
||||||
import space.kscience.kmath.expressions.MstRing
|
import space.kscience.kmath.expressions.MstRing
|
||||||
import space.kscience.kmath.expressions.Symbol.Companion.x
|
import space.kscience.kmath.expressions.Symbol.Companion.x
|
||||||
|
import space.kscience.kmath.expressions.Symbol.Companion.y
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.expressions.invoke
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.IntRing
|
import space.kscience.kmath.operations.IntRing
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -16,11 +18,23 @@ import kotlin.test.assertFailsWith
|
|||||||
|
|
||||||
internal class TestCompilerVariables {
|
internal class TestCompilerVariables {
|
||||||
@Test
|
@Test
|
||||||
fun testVariable() = runCompilerTest {
|
fun testNoVariables() = runCompilerTest {
|
||||||
|
val expr = "0".parseMath().compileToExpression(IntRing)
|
||||||
|
assertEquals(0, expr())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testOneVariable() = runCompilerTest {
|
||||||
val expr = MstRing { x }.compileToExpression(IntRing)
|
val expr = MstRing { x }.compileToExpression(IntRing)
|
||||||
assertEquals(1, expr(x to 1))
|
assertEquals(1, expr(x to 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTwoVariables() = runCompilerTest {
|
||||||
|
val expr = "y+x/y+x".parseMath().compileToExpression(DoubleField)
|
||||||
|
assertEquals(8.0, expr(x to 4.0, y to 2.0))
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testUndefinedVariableFails() = runCompilerTest {
|
fun testUndefinedVariableFails() = runCompilerTest {
|
||||||
val expr = MstRing { x }.compileToExpression(IntRing)
|
val expr = MstRing { x }.compileToExpression(IntRing)
|
||||||
|
@ -201,8 +201,8 @@ internal open external class Module {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@JsName("Instance")
|
@JsName("Instance")
|
||||||
internal open external class Instance(module: Module, importObject: Any = definedExternally) {
|
internal open external class Instance(module: Module, importObject: dynamic = definedExternally) {
|
||||||
open var exports: Any
|
open var exports: dynamic
|
||||||
}
|
}
|
||||||
|
|
||||||
@JsName("Memory")
|
@JsName("Memory")
|
||||||
|
@ -5,12 +5,11 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.wasm.internal
|
package space.kscience.kmath.wasm.internal
|
||||||
|
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.expressions.MST
|
|
||||||
import space.kscience.kmath.expressions.MST.*
|
import space.kscience.kmath.expressions.MST.*
|
||||||
import space.kscience.kmath.expressions.Symbol
|
|
||||||
import space.kscience.kmath.internal.binaryen.*
|
import space.kscience.kmath.internal.binaryen.*
|
||||||
import space.kscience.kmath.internal.webassembly.Instance
|
import space.kscience.kmath.internal.webassembly.Instance
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.internal.binaryen.Module as BinaryenModule
|
import space.kscience.kmath.internal.binaryen.Module as BinaryenModule
|
||||||
import space.kscience.kmath.internal.webassembly.Module as WasmModule
|
import space.kscience.kmath.internal.webassembly.Module as WasmModule
|
||||||
@ -18,65 +17,17 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule
|
|||||||
private val spreader = eval("(obj, args) => obj(...args)")
|
private val spreader = eval("(obj, args) => obj(...args)")
|
||||||
|
|
||||||
@Suppress("UnsafeCastFromDynamic")
|
@Suppress("UnsafeCastFromDynamic")
|
||||||
internal sealed class WasmBuilder<T : Number>(
|
internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
|
||||||
protected val binaryenType: Type,
|
protected val binaryenType: Type,
|
||||||
protected val algebra: Algebra<T>,
|
protected val algebra: Algebra<T>,
|
||||||
protected val target: MST,
|
protected val target: MST,
|
||||||
) {
|
) {
|
||||||
protected val keys: MutableList<Symbol> = mutableListOf()
|
protected val keys: MutableList<Symbol> = mutableListOf()
|
||||||
lateinit var ctx: BinaryenModule
|
protected lateinit var ctx: BinaryenModule
|
||||||
|
|
||||||
open fun visitSymbolic(mst: Symbol): ExpressionRef {
|
abstract val instance: E
|
||||||
algebra.bindSymbolOrNull(mst)?.let { return visitNumeric(Numeric(it)) }
|
|
||||||
|
|
||||||
var idx = keys.indexOf(mst)
|
protected val executable = run {
|
||||||
|
|
||||||
if (idx == -1) {
|
|
||||||
keys += mst
|
|
||||||
idx = keys.lastIndex
|
|
||||||
}
|
|
||||||
|
|
||||||
return ctx.local.get(idx, binaryenType)
|
|
||||||
}
|
|
||||||
|
|
||||||
abstract fun visitNumeric(mst: Numeric): ExpressionRef
|
|
||||||
|
|
||||||
protected open fun visitUnary(mst: Unary): ExpressionRef =
|
|
||||||
error("Unary operation ${mst.operation} not defined in $this")
|
|
||||||
|
|
||||||
protected open fun visitBinary(mst: Binary): ExpressionRef =
|
|
||||||
error("Binary operation ${mst.operation} not defined in $this")
|
|
||||||
|
|
||||||
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
|
|
||||||
|
|
||||||
protected fun visit(mst: MST): ExpressionRef = when (mst) {
|
|
||||||
is Symbol -> visitSymbolic(mst)
|
|
||||||
is Numeric -> visitNumeric(mst)
|
|
||||||
|
|
||||||
is Unary -> when {
|
|
||||||
algebra is NumericAlgebra && mst.value is Numeric -> visitNumeric(
|
|
||||||
Numeric(algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as Numeric).value)))
|
|
||||||
)
|
|
||||||
|
|
||||||
else -> visitUnary(mst)
|
|
||||||
}
|
|
||||||
|
|
||||||
is Binary -> when {
|
|
||||||
algebra is NumericAlgebra && mst.left is Numeric && mst.right is Numeric -> visitNumeric(
|
|
||||||
Numeric(
|
|
||||||
algebra.binaryOperationFunction(mst.operation)
|
|
||||||
.invoke(
|
|
||||||
algebra.number((mst.left as Numeric).value),
|
|
||||||
algebra.number((mst.right as Numeric).value)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
else -> visitBinary(mst)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
val instance by lazy {
|
|
||||||
val c = WasmModule(with(createModule()) {
|
val c = WasmModule(with(createModule()) {
|
||||||
ctx = this
|
ctx = this
|
||||||
val expr = visit(target)
|
val expr = visit(target)
|
||||||
@ -97,41 +48,93 @@ internal sealed class WasmBuilder<T : Number>(
|
|||||||
res
|
res
|
||||||
})
|
})
|
||||||
|
|
||||||
val i = Instance(c, js("{}") as Any)
|
Instance(c, js("{}")).exports.executable
|
||||||
val symbols = keys
|
}
|
||||||
keys.clear()
|
|
||||||
|
|
||||||
Expression<T> { args ->
|
protected open fun visitSymbol(node: Symbol): ExpressionRef {
|
||||||
val params = symbols.map(args::getValue).toTypedArray()
|
algebra.bindSymbolOrNull(node)?.let { return visitNumeric(Numeric(it)) }
|
||||||
spreader(i.exports.asDynamic().executable, params) as T
|
|
||||||
|
var idx = keys.indexOf(node)
|
||||||
|
|
||||||
|
if (idx == -1) {
|
||||||
|
keys += node
|
||||||
|
idx = keys.lastIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx.local.get(idx, binaryenType)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract fun visitNumeric(node: Numeric): ExpressionRef
|
||||||
|
|
||||||
|
protected open fun visitUnary(node: Unary): ExpressionRef =
|
||||||
|
error("Unary operation ${node.operation} not defined in $this")
|
||||||
|
|
||||||
|
protected open fun visitBinary(mst: Binary): ExpressionRef =
|
||||||
|
error("Binary operation ${mst.operation} not defined in $this")
|
||||||
|
|
||||||
|
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
|
||||||
|
|
||||||
|
protected fun visit(node: MST): ExpressionRef = when (node) {
|
||||||
|
is Symbol -> visitSymbol(node)
|
||||||
|
is Numeric -> visitNumeric(node)
|
||||||
|
|
||||||
|
is Unary -> when {
|
||||||
|
algebra is NumericAlgebra && node.value is Numeric -> visitNumeric(
|
||||||
|
Numeric(algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)))
|
||||||
|
)
|
||||||
|
|
||||||
|
else -> visitUnary(node)
|
||||||
|
}
|
||||||
|
|
||||||
|
is Binary -> when {
|
||||||
|
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> visitNumeric(
|
||||||
|
Numeric(
|
||||||
|
algebra.binaryOperationFunction(node.operation)
|
||||||
|
.invoke(
|
||||||
|
algebra.number((node.left as Numeric).value),
|
||||||
|
algebra.number((node.right as Numeric).value)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
else -> visitBinary(node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleField, target) {
|
@UnstableKMathAPI
|
||||||
override fun createModule(): BinaryenModule = readBinary(f64StandardFunctions)
|
internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpression>(f64, DoubleField, target) {
|
||||||
|
override val instance by lazy {
|
||||||
|
object : DoubleExpression {
|
||||||
|
override val indexer = SimpleSymbolIndexer(keys)
|
||||||
|
|
||||||
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value)
|
override fun invoke(arguments: DoubleArray) = spreader(executable, arguments).unsafeCast<Double>()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
|
override fun createModule() = readBinary(f64StandardFunctions)
|
||||||
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
|
|
||||||
GroupOps.PLUS_OPERATION -> visit(mst.value)
|
override fun visitNumeric(node: Numeric) = ctx.f64.const(node.value)
|
||||||
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value))
|
|
||||||
TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64)
|
override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) {
|
||||||
TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64)
|
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value))
|
||||||
TrigonometricOperations.TAN_OPERATION -> ctx.call("tan", arrayOf(visit(mst.value)), f64)
|
GroupOps.PLUS_OPERATION -> visit(node.value)
|
||||||
TrigonometricOperations.ASIN_OPERATION -> ctx.call("asin", arrayOf(visit(mst.value)), f64)
|
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value))
|
||||||
TrigonometricOperations.ACOS_OPERATION -> ctx.call("acos", arrayOf(visit(mst.value)), f64)
|
TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(node.value)), f64)
|
||||||
TrigonometricOperations.ATAN_OPERATION -> ctx.call("atan", arrayOf(visit(mst.value)), f64)
|
TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(node.value)), f64)
|
||||||
ExponentialOperations.SINH_OPERATION -> ctx.call("sinh", arrayOf(visit(mst.value)), f64)
|
TrigonometricOperations.TAN_OPERATION -> ctx.call("tan", arrayOf(visit(node.value)), f64)
|
||||||
ExponentialOperations.COSH_OPERATION -> ctx.call("cosh", arrayOf(visit(mst.value)), f64)
|
TrigonometricOperations.ASIN_OPERATION -> ctx.call("asin", arrayOf(visit(node.value)), f64)
|
||||||
ExponentialOperations.TANH_OPERATION -> ctx.call("tanh", arrayOf(visit(mst.value)), f64)
|
TrigonometricOperations.ACOS_OPERATION -> ctx.call("acos", arrayOf(visit(node.value)), f64)
|
||||||
ExponentialOperations.ASINH_OPERATION -> ctx.call("asinh", arrayOf(visit(mst.value)), f64)
|
TrigonometricOperations.ATAN_OPERATION -> ctx.call("atan", arrayOf(visit(node.value)), f64)
|
||||||
ExponentialOperations.ACOSH_OPERATION -> ctx.call("acosh", arrayOf(visit(mst.value)), f64)
|
ExponentialOperations.SINH_OPERATION -> ctx.call("sinh", arrayOf(visit(node.value)), f64)
|
||||||
ExponentialOperations.ATANH_OPERATION -> ctx.call("atanh", arrayOf(visit(mst.value)), f64)
|
ExponentialOperations.COSH_OPERATION -> ctx.call("cosh", arrayOf(visit(node.value)), f64)
|
||||||
ExponentialOperations.EXP_OPERATION -> ctx.call("exp", arrayOf(visit(mst.value)), f64)
|
ExponentialOperations.TANH_OPERATION -> ctx.call("tanh", arrayOf(visit(node.value)), f64)
|
||||||
ExponentialOperations.LN_OPERATION -> ctx.call("log", arrayOf(visit(mst.value)), f64)
|
ExponentialOperations.ASINH_OPERATION -> ctx.call("asinh", arrayOf(visit(node.value)), f64)
|
||||||
else -> super.visitUnary(mst)
|
ExponentialOperations.ACOSH_OPERATION -> ctx.call("acosh", arrayOf(visit(node.value)), f64)
|
||||||
|
ExponentialOperations.ATANH_OPERATION -> ctx.call("atanh", arrayOf(visit(node.value)), f64)
|
||||||
|
ExponentialOperations.EXP_OPERATION -> ctx.call("exp", arrayOf(visit(node.value)), f64)
|
||||||
|
ExponentialOperations.LN_OPERATION -> ctx.call("log", arrayOf(visit(node.value)), f64)
|
||||||
|
else -> super.visitUnary(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
||||||
@ -144,13 +147,22 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleF
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class IntWasmBuilder(target: MST) : WasmBuilder<Int>(i32, IntRing, target) {
|
@UnstableKMathAPI
|
||||||
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value)
|
internal class IntWasmBuilder(target: MST) : WasmBuilder<Int, IntExpression>(i32, IntRing, target) {
|
||||||
|
override val instance by lazy {
|
||||||
|
object : IntExpression {
|
||||||
|
override val indexer = SimpleSymbolIndexer(keys)
|
||||||
|
|
||||||
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
|
override fun invoke(arguments: IntArray) = spreader(executable, arguments).unsafeCast<Int>()
|
||||||
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value))
|
}
|
||||||
GroupOps.PLUS_OPERATION -> visit(mst.value)
|
}
|
||||||
else -> super.visitUnary(mst)
|
|
||||||
|
override fun visitNumeric(node: Numeric) = ctx.i32.const(node.value)
|
||||||
|
|
||||||
|
override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) {
|
||||||
|
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value))
|
||||||
|
GroupOps.PLUS_OPERATION -> visit(node.value)
|
||||||
|
else -> super.visitUnary(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
||||||
|
@ -3,13 +3,12 @@
|
|||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@file:Suppress("UNUSED_PARAMETER")
|
||||||
|
|
||||||
package space.kscience.kmath.wasm
|
package space.kscience.kmath.wasm
|
||||||
|
|
||||||
import space.kscience.kmath.estree.compileWith
|
import space.kscience.kmath.estree.compileWith
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.expressions.MST
|
|
||||||
import space.kscience.kmath.expressions.Symbol
|
|
||||||
import space.kscience.kmath.expressions.invoke
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.IntRing
|
import space.kscience.kmath.operations.IntRing
|
||||||
@ -22,7 +21,7 @@ import space.kscience.kmath.wasm.internal.IntWasmBuilder
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compileToExpression(algebra: IntRing): Expression<Int> = compileWith(algebra)
|
public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBuilder(this).instance
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -50,7 +49,7 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
|
|||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = compileWith(algebra)
|
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = DoubleWasmBuilder(this).instance
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -8,11 +8,8 @@
|
|||||||
package space.kscience.kmath.asm
|
package space.kscience.kmath.asm
|
||||||
|
|
||||||
import space.kscience.kmath.asm.internal.*
|
import space.kscience.kmath.asm.internal.*
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.expressions.MST
|
|
||||||
import space.kscience.kmath.expressions.MST.*
|
import space.kscience.kmath.expressions.MST.*
|
||||||
import space.kscience.kmath.expressions.Symbol
|
|
||||||
import space.kscience.kmath.expressions.invoke
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
|
|
||||||
@ -48,7 +45,13 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
|
|||||||
loadVariable(node.identity)
|
loadVariable(node.identity)
|
||||||
}
|
}
|
||||||
|
|
||||||
is Numeric -> loadNumberConstant(node.value)
|
is Numeric -> if (algebra is NumericAlgebra) {
|
||||||
|
if (Number::class.java.isAssignableFrom(type))
|
||||||
|
loadNumberConstant(algebra.number(node.value) as Number)
|
||||||
|
else
|
||||||
|
loadObjectConstant(algebra.number(node.value))
|
||||||
|
} else
|
||||||
|
error("Numeric nodes are not supported by $this")
|
||||||
|
|
||||||
is Unary -> when {
|
is Unary -> when {
|
||||||
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
|
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
|
||||||
@ -121,13 +124,15 @@ public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg argu
|
|||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
public fun MST.compileToExpression(algebra: IntRing): Expression<Int> = IntAsmBuilder(this).instance
|
@UnstableKMathAPI
|
||||||
|
public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuilder(this).instance
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra).invoke(arguments)
|
||||||
|
|
||||||
@ -136,6 +141,7 @@ public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
|
|||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): Int =
|
public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): Int =
|
||||||
compileToExpression(algebra)(*arguments)
|
compileToExpression(algebra)(*arguments)
|
||||||
|
|
||||||
@ -145,7 +151,8 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
|
|||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
public fun MST.compileToExpression(algebra: LongRing): Expression<Long> = LongAsmBuilder(this).instance
|
@UnstableKMathAPI
|
||||||
|
public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmBuilder(this).instance
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -153,6 +160,7 @@ public fun MST.compileToExpression(algebra: LongRing): Expression<Long> = LongAs
|
|||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
|
public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra).invoke(arguments)
|
||||||
|
|
||||||
@ -162,6 +170,7 @@ public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
|
|||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: LongRing, vararg arguments: Pair<Symbol, Long>): Long =
|
public fun MST.compile(algebra: LongRing, vararg arguments: Pair<Symbol, Long>): Long =
|
||||||
compileToExpression(algebra)(*arguments)
|
compileToExpression(algebra)(*arguments)
|
||||||
|
|
||||||
@ -171,13 +180,15 @@ public fun MST.compile(algebra: LongRing, vararg arguments: Pair<Symbol, Long>):
|
|||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = DoubleAsmBuilder(this).instance
|
@UnstableKMathAPI
|
||||||
|
public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = DoubleAsmBuilder(this).instance
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compile given MST to expression and evaluate it against [arguments].
|
* Compile given MST to expression and evaluate it against [arguments].
|
||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
|
||||||
compileToExpression(algebra).invoke(arguments)
|
compileToExpression(algebra).invoke(arguments)
|
||||||
|
|
||||||
@ -186,5 +197,6 @@ public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Do
|
|||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
|
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
|
||||||
compileToExpression(algebra).invoke(*arguments)
|
compileToExpression(algebra).invoke(*arguments)
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package space.kscience.kmath.asm.internal
|
package space.kscience.kmath.asm.internal
|
||||||
|
|
||||||
import org.objectweb.asm.Type
|
import org.objectweb.asm.Type
|
||||||
|
import org.objectweb.asm.Type.getObjectType
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.Expression
|
||||||
|
|
||||||
internal abstract class AsmBuilder {
|
internal abstract class AsmBuilder {
|
||||||
@ -22,31 +23,31 @@ internal abstract class AsmBuilder {
|
|||||||
/**
|
/**
|
||||||
* ASM type for [Expression].
|
* ASM type for [Expression].
|
||||||
*/
|
*/
|
||||||
val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/expressions/Expression") }
|
val EXPRESSION_TYPE: Type = getObjectType("space/kscience/kmath/expressions/Expression")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.util.Map].
|
* ASM type for [java.util.Map].
|
||||||
*/
|
*/
|
||||||
val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") }
|
val MAP_TYPE: Type = getObjectType("java/util/Map")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.lang.Object].
|
* ASM type for [java.lang.Object].
|
||||||
*/
|
*/
|
||||||
val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") }
|
val OBJECT_TYPE: Type = getObjectType("java/lang/Object")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.lang.String].
|
* ASM type for [java.lang.String].
|
||||||
*/
|
*/
|
||||||
val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") }
|
val STRING_TYPE: Type = getObjectType("java/lang/String")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for MapIntrinsics.
|
* ASM type for MapIntrinsics.
|
||||||
*/
|
*/
|
||||||
val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/asm/internal/MapIntrinsics") }
|
val MAP_INTRINSICS_TYPE: Type = getObjectType("space/kscience/kmath/asm/internal/MapIntrinsics")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM Type for [space.kscience.kmath.expressions.Symbol].
|
* ASM Type for [space.kscience.kmath.expressions.Symbol].
|
||||||
*/
|
*/
|
||||||
val SYMBOL_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/expressions/Symbol") }
|
val SYMBOL_TYPE: Type = getObjectType("space/kscience/kmath/expressions/Symbol")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2018-2021 KMath contributors.
|
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package space.kscience.kmath.asm.internal
|
|
@ -78,7 +78,7 @@ internal class GenericAsmBuilder<T>(
|
|||||||
)
|
)
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
ACC_PUBLIC or ACC_FINAL,
|
ACC_PUBLIC,
|
||||||
"invoke",
|
"invoke",
|
||||||
getMethodDescriptor(tType, MAP_TYPE),
|
getMethodDescriptor(tType, MAP_TYPE),
|
||||||
"(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
|
"(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
|
||||||
@ -116,7 +116,7 @@ internal class GenericAsmBuilder<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
|
ACC_PUBLIC or ACC_BRIDGE or ACC_SYNTHETIC,
|
||||||
"invoke",
|
"invoke",
|
||||||
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
||||||
null,
|
null,
|
||||||
@ -156,7 +156,7 @@ internal class GenericAsmBuilder<T>(
|
|||||||
)
|
)
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
ACC_PUBLIC,
|
ACC_PUBLIC or ACC_SYNTHETIC,
|
||||||
"<init>",
|
"<init>",
|
||||||
getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
|
getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
|
||||||
null,
|
null,
|
||||||
@ -176,7 +176,7 @@ internal class GenericAsmBuilder<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
label()
|
label()
|
||||||
visitInsn(RETURN)
|
areturn(VOID_TYPE)
|
||||||
val l4 = label()
|
val l4 = label()
|
||||||
visitLocalVariable("this", classType.descriptor, null, l0, l4, 0)
|
visitLocalVariable("this", classType.descriptor, null, l0, l4, 0)
|
||||||
|
|
||||||
@ -209,10 +209,10 @@ internal class GenericAsmBuilder<T>(
|
|||||||
*/
|
*/
|
||||||
fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run {
|
fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run {
|
||||||
val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex
|
val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex
|
||||||
invokeMethodVisitor.load(0, classType)
|
load(0, classType)
|
||||||
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
|
||||||
iconst(idx)
|
iconst(idx)
|
||||||
visitInsn(AALOAD)
|
aload(OBJECT_TYPE)
|
||||||
if (type != OBJECT_TYPE) checkcast(type)
|
if (type != OBJECT_TYPE) checkcast(type)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -320,6 +320,6 @@ internal class GenericAsmBuilder<T>(
|
|||||||
/**
|
/**
|
||||||
* ASM type for array of [java.lang.Object].
|
* ASM type for array of [java.lang.Object].
|
||||||
*/
|
*/
|
||||||
val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") }
|
val OBJECT_ARRAY_TYPE: Type = getType("[Ljava/lang/Object;")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,38 +6,49 @@
|
|||||||
package space.kscience.kmath.asm.internal
|
package space.kscience.kmath.asm.internal
|
||||||
|
|
||||||
import org.objectweb.asm.ClassWriter
|
import org.objectweb.asm.ClassWriter
|
||||||
import org.objectweb.asm.Opcodes
|
import org.objectweb.asm.FieldVisitor
|
||||||
|
import org.objectweb.asm.Opcodes.*
|
||||||
import org.objectweb.asm.Type
|
import org.objectweb.asm.Type
|
||||||
import org.objectweb.asm.Type.*
|
import org.objectweb.asm.Type.*
|
||||||
import org.objectweb.asm.commons.InstructionAdapter
|
import org.objectweb.asm.commons.InstructionAdapter
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.expressions.MST
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.expressions.Symbol
|
|
||||||
import space.kscience.kmath.expressions.invoke
|
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import java.lang.invoke.MethodHandles
|
import java.lang.invoke.MethodHandles
|
||||||
import java.lang.invoke.MethodType
|
import java.lang.invoke.MethodType
|
||||||
import java.nio.file.Paths
|
import java.nio.file.Paths
|
||||||
import kotlin.io.path.writeBytes
|
import kotlin.io.path.writeBytes
|
||||||
|
|
||||||
internal sealed class PrimitiveAsmBuilder<T : Number>(
|
@UnstableKMathAPI
|
||||||
protected val algebra: Algebra<T>,
|
internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
|
||||||
|
protected val algebra: NumericAlgebra<T>,
|
||||||
classOfT: Class<*>,
|
classOfT: Class<*>,
|
||||||
protected val classOfTPrimitive: Class<*>,
|
protected val classOfTPrimitive: Class<*>,
|
||||||
|
expressionParent: Class<E>,
|
||||||
protected val target: MST,
|
protected val target: MST,
|
||||||
) : AsmBuilder() {
|
) : AsmBuilder() {
|
||||||
private val className: String = buildName(target)
|
private val className: String = buildName(target)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [T].
|
* ASM type for [tType].
|
||||||
*/
|
*/
|
||||||
private val tType: Type = classOfT.asm
|
private val tType: Type = classOfT.asm
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for [T].
|
* ASM type for [classOfTPrimitive].
|
||||||
*/
|
*/
|
||||||
protected val tTypePrimitive: Type = classOfTPrimitive.asm
|
protected val tTypePrimitive: Type = classOfTPrimitive.asm
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for array of [classOfTPrimitive].
|
||||||
|
*/
|
||||||
|
protected val tTypePrimitiveArray: Type = getType("[" + classOfTPrimitive.asm.descriptor)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for expression parent.
|
||||||
|
*/
|
||||||
|
private val expressionParentType = expressionParent.asm
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ASM type for new class.
|
* ASM type for new class.
|
||||||
*/
|
*/
|
||||||
@ -49,58 +60,91 @@ internal sealed class PrimitiveAsmBuilder<T : Number>(
|
|||||||
protected lateinit var invokeMethodVisitor: InstructionAdapter
|
protected lateinit var invokeMethodVisitor: InstructionAdapter
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Local variables indices are indices of symbols in this list.
|
* Indexer for arguments in [target].
|
||||||
*/
|
*/
|
||||||
private val argumentsLocals = mutableListOf<String>()
|
private val argumentsIndexer = mutableListOf<Symbol>()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subclasses, loads and instantiates [Expression] for given parameters.
|
* Subclasses, loads and instantiates [Expression] for given parameters.
|
||||||
*
|
*
|
||||||
* The built instance is cached.
|
* The built instance is cached.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE")
|
@Suppress("UNCHECKED_CAST")
|
||||||
val instance: Expression<T> by lazy {
|
val instance: E by lazy {
|
||||||
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
|
||||||
visit(
|
visit(
|
||||||
Opcodes.V1_8,
|
V1_8,
|
||||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER,
|
ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
|
||||||
classType.internalName,
|
classType.internalName,
|
||||||
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;",
|
"${OBJECT_TYPE.descriptor}${expressionParentType.descriptor}",
|
||||||
OBJECT_TYPE.internalName,
|
OBJECT_TYPE.internalName,
|
||||||
arrayOf(EXPRESSION_TYPE.internalName),
|
arrayOf(expressionParentType.internalName),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
visitField(
|
||||||
|
access = ACC_PRIVATE or ACC_FINAL,
|
||||||
|
name = "indexer",
|
||||||
|
descriptor = SYMBOL_INDEXER_TYPE.descriptor,
|
||||||
|
signature = null,
|
||||||
|
value = null,
|
||||||
|
block = FieldVisitor::visitEnd,
|
||||||
|
)
|
||||||
visitMethod(
|
visitMethod(
|
||||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL,
|
ACC_PUBLIC,
|
||||||
"invoke",
|
"getIndexer",
|
||||||
getMethodDescriptor(tType, MAP_TYPE),
|
getMethodDescriptor(SYMBOL_INDEXER_TYPE),
|
||||||
"(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
|
null,
|
||||||
null,
|
null,
|
||||||
).instructionAdapter {
|
).instructionAdapter {
|
||||||
invokeMethodVisitor = this
|
|
||||||
visitCode()
|
visitCode()
|
||||||
val preparingVariables = label()
|
val start = label()
|
||||||
visitVariables(target)
|
load(0, classType)
|
||||||
val expressionResult = label()
|
getfield(classType.internalName, "indexer", SYMBOL_INDEXER_TYPE.descriptor)
|
||||||
visitExpression(target)
|
areturn(SYMBOL_INDEXER_TYPE)
|
||||||
box()
|
|
||||||
areturn(tType)
|
|
||||||
val end = label()
|
val end = label()
|
||||||
|
|
||||||
visitLocalVariable(
|
visitLocalVariable(
|
||||||
"this",
|
"this",
|
||||||
classType.descriptor,
|
classType.descriptor,
|
||||||
null,
|
null,
|
||||||
preparingVariables,
|
start,
|
||||||
|
end,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
visitMaxs(0, 0)
|
||||||
|
visitEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
visitMethod(
|
||||||
|
ACC_PUBLIC,
|
||||||
|
"invoke",
|
||||||
|
getMethodDescriptor(tTypePrimitive, tTypePrimitiveArray),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
).instructionAdapter {
|
||||||
|
invokeMethodVisitor = this
|
||||||
|
visitCode()
|
||||||
|
val start = label()
|
||||||
|
visitVariables(target, arrayMode = true)
|
||||||
|
visitExpression(target)
|
||||||
|
areturn(tTypePrimitive)
|
||||||
|
val end = label()
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"this",
|
||||||
|
classType.descriptor,
|
||||||
|
null,
|
||||||
|
start,
|
||||||
end,
|
end,
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
visitLocalVariable(
|
visitLocalVariable(
|
||||||
"arguments",
|
"arguments",
|
||||||
MAP_TYPE.descriptor,
|
tTypePrimitiveArray.descriptor,
|
||||||
"L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;",
|
null,
|
||||||
preparingVariables,
|
start,
|
||||||
end,
|
end,
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
@ -110,7 +154,45 @@ internal sealed class PrimitiveAsmBuilder<T : Number>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC,
|
ACC_PUBLIC or ACC_FINAL,
|
||||||
|
"invoke",
|
||||||
|
getMethodDescriptor(tType, MAP_TYPE),
|
||||||
|
"(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
|
||||||
|
null,
|
||||||
|
).instructionAdapter {
|
||||||
|
invokeMethodVisitor = this
|
||||||
|
visitCode()
|
||||||
|
val start = label()
|
||||||
|
visitVariables(target, arrayMode = false)
|
||||||
|
visitExpression(target)
|
||||||
|
box()
|
||||||
|
areturn(tType)
|
||||||
|
val end = label()
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"this",
|
||||||
|
classType.descriptor,
|
||||||
|
null,
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
visitLocalVariable(
|
||||||
|
"arguments",
|
||||||
|
MAP_TYPE.descriptor,
|
||||||
|
"L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;",
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
visitMaxs(0, 0)
|
||||||
|
visitEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
visitMethod(
|
||||||
|
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
|
||||||
"invoke",
|
"invoke",
|
||||||
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
|
||||||
null,
|
null,
|
||||||
@ -138,21 +220,22 @@ internal sealed class PrimitiveAsmBuilder<T : Number>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
visitMethod(
|
visitMethod(
|
||||||
Opcodes.ACC_PUBLIC,
|
ACC_PUBLIC or ACC_SYNTHETIC,
|
||||||
"<init>",
|
"<init>",
|
||||||
getMethodDescriptor(VOID_TYPE),
|
getMethodDescriptor(VOID_TYPE, SYMBOL_INDEXER_TYPE),
|
||||||
null,
|
null,
|
||||||
null,
|
null,
|
||||||
).instructionAdapter {
|
).instructionAdapter {
|
||||||
val start = label()
|
val start = label()
|
||||||
load(0, classType)
|
load(0, classType)
|
||||||
invokespecial(OBJECT_TYPE.internalName, "<init>", getMethodDescriptor(VOID_TYPE), false)
|
invokespecial(OBJECT_TYPE.internalName, "<init>", getMethodDescriptor(VOID_TYPE), false)
|
||||||
label()
|
|
||||||
load(0, classType)
|
load(0, classType)
|
||||||
label()
|
load(1, SYMBOL_INDEXER_TYPE)
|
||||||
visitInsn(Opcodes.RETURN)
|
putfield(classType.internalName, "indexer", SYMBOL_INDEXER_TYPE.descriptor)
|
||||||
|
areturn(VOID_TYPE)
|
||||||
val end = label()
|
val end = label()
|
||||||
visitLocalVariable("this", classType.descriptor, null, start, end, 0)
|
visitLocalVariable("this", classType.descriptor, null, start, end, 0)
|
||||||
|
visitLocalVariable("indexer", SYMBOL_INDEXER_TYPE.descriptor, null, start, end, 1)
|
||||||
visitMaxs(0, 0)
|
visitMaxs(0, 0)
|
||||||
visitEnd()
|
visitEnd()
|
||||||
}
|
}
|
||||||
@ -166,14 +249,16 @@ internal sealed class PrimitiveAsmBuilder<T : Number>(
|
|||||||
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
|
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
|
||||||
Paths.get("${className.split('.').last()}.class").writeBytes(binary)
|
Paths.get("${className.split('.').last()}.class").writeBytes(binary)
|
||||||
|
|
||||||
MethodHandles.publicLookup().findConstructor(cls, MethodType.methodType(Void.TYPE))() as Expression<T>
|
MethodHandles
|
||||||
|
.publicLookup()
|
||||||
|
.findConstructor(cls, MethodType.methodType(Void.TYPE, SymbolIndexer::class.java))
|
||||||
|
.invoke(SimpleSymbolIndexer(argumentsIndexer)) as E
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive
|
* Loads a numeric constant [value] from the class's constants.
|
||||||
* constant from the constant pool.
|
|
||||||
*/
|
*/
|
||||||
fun loadNumberConstant(value: Number) {
|
protected fun loadNumberConstant(value: Number) {
|
||||||
when (tTypePrimitive) {
|
when (tTypePrimitive) {
|
||||||
BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
|
||||||
DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
|
DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
|
||||||
@ -185,38 +270,50 @@ internal sealed class PrimitiveAsmBuilder<T : Number>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using
|
* Stores value variable [name] into a local. Should be called before using [loadVariable]. Should be called only
|
||||||
* [loadVariable].
|
* once for a variable.
|
||||||
*/
|
*/
|
||||||
fun prepareVariable(name: String): Unit = invokeMethodVisitor.run {
|
protected fun prepareVariable(name: Symbol, arrayMode: Boolean): Unit = invokeMethodVisitor.run {
|
||||||
if (name in argumentsLocals) return@run
|
var argumentIndex = argumentsIndexer.indexOf(name)
|
||||||
load(1, MAP_TYPE)
|
|
||||||
aconst(name)
|
|
||||||
|
|
||||||
invokestatic(
|
if (argumentIndex == -1) {
|
||||||
MAP_INTRINSICS_TYPE.internalName,
|
argumentsIndexer += name
|
||||||
"getOrFail",
|
argumentIndex = argumentsIndexer.lastIndex
|
||||||
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
|
|
||||||
false,
|
|
||||||
)
|
|
||||||
|
|
||||||
checkcast(tType)
|
|
||||||
var idx = argumentsLocals.indexOf(name)
|
|
||||||
|
|
||||||
if (idx == -1) {
|
|
||||||
argumentsLocals += name
|
|
||||||
idx = argumentsLocals.lastIndex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unbox()
|
val localIndex = 2 + argumentIndex * tTypePrimitive.size
|
||||||
store(2 + idx, tTypePrimitive)
|
|
||||||
|
if (arrayMode) {
|
||||||
|
load(1, tTypePrimitiveArray)
|
||||||
|
iconst(argumentIndex)
|
||||||
|
aload(tTypePrimitive)
|
||||||
|
store(localIndex, tTypePrimitive)
|
||||||
|
} else {
|
||||||
|
load(1, MAP_TYPE)
|
||||||
|
aconst(name.identity)
|
||||||
|
|
||||||
|
invokestatic(
|
||||||
|
MAP_INTRINSICS_TYPE.internalName,
|
||||||
|
"getOrFail",
|
||||||
|
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkcast(tType)
|
||||||
|
unbox()
|
||||||
|
store(localIndex, tTypePrimitive)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
|
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
|
||||||
* with [prepareVariable] first.
|
* with [prepareVariable] first.
|
||||||
*/
|
*/
|
||||||
fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tTypePrimitive)
|
protected fun loadVariable(name: Symbol) {
|
||||||
|
val argumentIndex = argumentsIndexer.indexOf(name)
|
||||||
|
val localIndex = 2 + argumentIndex * tTypePrimitive.size
|
||||||
|
invokeMethodVisitor.load(localIndex, tTypePrimitive)
|
||||||
|
}
|
||||||
|
|
||||||
private fun unbox() = invokeMethodVisitor.run {
|
private fun unbox() = invokeMethodVisitor.run {
|
||||||
invokevirtual(
|
invokevirtual(
|
||||||
@ -231,102 +328,117 @@ internal sealed class PrimitiveAsmBuilder<T : Number>(
|
|||||||
invokestatic(tType.internalName, "valueOf", getMethodDescriptor(tType, tTypePrimitive), false)
|
invokestatic(tType.internalName, "valueOf", getMethodDescriptor(tType, tTypePrimitive), false)
|
||||||
}
|
}
|
||||||
|
|
||||||
protected fun visitVariables(node: MST): Unit = when (node) {
|
private fun visitVariables(
|
||||||
is Symbol -> prepareVariable(node.identity)
|
node: MST,
|
||||||
is MST.Unary -> visitVariables(node.value)
|
arrayMode: Boolean,
|
||||||
|
alreadyLoaded: MutableList<Symbol> = mutableListOf()
|
||||||
|
): Unit = when (node) {
|
||||||
|
is Symbol -> when (node) {
|
||||||
|
!in alreadyLoaded -> {
|
||||||
|
alreadyLoaded += node
|
||||||
|
prepareVariable(node, arrayMode)
|
||||||
|
}
|
||||||
|
else -> {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
is MST.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded)
|
||||||
|
|
||||||
is MST.Binary -> {
|
is MST.Binary -> {
|
||||||
visitVariables(node.left)
|
visitVariables(node.left, arrayMode, alreadyLoaded)
|
||||||
visitVariables(node.right)
|
visitVariables(node.right, arrayMode, alreadyLoaded)
|
||||||
}
|
}
|
||||||
|
|
||||||
else -> Unit
|
else -> Unit
|
||||||
}
|
}
|
||||||
|
|
||||||
protected fun visitExpression(mst: MST): Unit = when (mst) {
|
private fun visitExpression(node: MST): Unit = when (node) {
|
||||||
is Symbol -> loadVariable(mst.identity)
|
is Symbol -> {
|
||||||
is MST.Numeric -> loadNumberConstant(mst.value)
|
val symbol = algebra.bindSymbolOrNull(node)
|
||||||
|
|
||||||
is MST.Unary -> when {
|
if (symbol != null)
|
||||||
algebra is NumericAlgebra && mst.value is MST.Numeric -> {
|
loadNumberConstant(symbol)
|
||||||
loadNumberConstant(
|
else
|
||||||
MST.Numeric(
|
loadVariable(node)
|
||||||
algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as MST.Numeric).value)),
|
|
||||||
).value,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
else -> visitUnary(mst)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
is MST.Numeric -> loadNumberConstant(algebra.number(node.value))
|
||||||
|
|
||||||
|
is MST.Unary -> if (node.value is MST.Numeric)
|
||||||
|
loadNumberConstant(
|
||||||
|
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as MST.Numeric).value)),
|
||||||
|
)
|
||||||
|
else
|
||||||
|
visitUnary(node)
|
||||||
|
|
||||||
is MST.Binary -> when {
|
is MST.Binary -> when {
|
||||||
algebra is NumericAlgebra && mst.left is MST.Numeric && mst.right is MST.Numeric -> {
|
node.left is MST.Numeric && node.right is MST.Numeric -> loadNumberConstant(
|
||||||
loadNumberConstant(
|
algebra.binaryOperationFunction(node.operation)(
|
||||||
MST.Numeric(
|
algebra.number((node.left as MST.Numeric).value),
|
||||||
algebra.binaryOperationFunction(mst.operation)(
|
algebra.number((node.right as MST.Numeric).value),
|
||||||
algebra.number((mst.left as MST.Numeric).value),
|
),
|
||||||
algebra.number((mst.right as MST.Numeric).value),
|
)
|
||||||
),
|
|
||||||
).value,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
else -> visitBinary(mst)
|
else -> visitBinary(node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected open fun visitUnary(mst: MST.Unary) {
|
protected open fun visitUnary(node: MST.Unary) = visitExpression(node.value)
|
||||||
visitExpression(mst.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
protected open fun visitBinary(mst: MST.Binary) {
|
protected open fun visitBinary(node: MST.Binary) {
|
||||||
visitExpression(mst.left)
|
visitExpression(node.left)
|
||||||
visitExpression(mst.right)
|
visitExpression(node.right)
|
||||||
}
|
}
|
||||||
|
|
||||||
protected companion object {
|
protected companion object {
|
||||||
/**
|
/**
|
||||||
* ASM type for [java.lang.Number].
|
* ASM type for [java.lang.Number].
|
||||||
*/
|
*/
|
||||||
val NUMBER_TYPE: Type by lazy { getObjectType("java/lang/Number") }
|
val NUMBER_TYPE: Type = getObjectType("java/lang/Number")
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ASM type for [SymbolIndexer].
|
||||||
|
*/
|
||||||
|
val SYMBOL_INDEXER_TYPE: Type = getObjectType("space/kscience/kmath/expressions/SymbolIndexer")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class DoubleAsmBuilder(target: MST) :
|
@UnstableKMathAPI
|
||||||
PrimitiveAsmBuilder<Double>(DoubleField, java.lang.Double::class.java, java.lang.Double.TYPE, target) {
|
internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, DoubleExpression>(
|
||||||
|
DoubleField,
|
||||||
|
java.lang.Double::class.java,
|
||||||
|
java.lang.Double.TYPE,
|
||||||
|
DoubleExpression::class.java,
|
||||||
|
target,
|
||||||
|
) {
|
||||||
|
|
||||||
private fun buildUnaryJavaMathCall(name: String) {
|
private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic(
|
||||||
invokeMethodVisitor.invokestatic(
|
MATH_TYPE.internalName,
|
||||||
MATH_TYPE.internalName,
|
name,
|
||||||
name,
|
getMethodDescriptor(tTypePrimitive, tTypePrimitive),
|
||||||
getMethodDescriptor(tTypePrimitive, tTypePrimitive),
|
false,
|
||||||
false,
|
)
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun buildBinaryJavaMathCall(name: String) {
|
@Suppress("SameParameterValue")
|
||||||
invokeMethodVisitor.invokestatic(
|
private fun buildBinaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic(
|
||||||
MATH_TYPE.internalName,
|
MATH_TYPE.internalName,
|
||||||
name,
|
name,
|
||||||
getMethodDescriptor(tTypePrimitive, tTypePrimitive, tTypePrimitive),
|
getMethodDescriptor(tTypePrimitive, tTypePrimitive, tTypePrimitive),
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
}
|
|
||||||
|
|
||||||
private fun buildUnaryKotlinMathCall(name: String) {
|
private fun buildUnaryKotlinMathCall(name: String) = invokeMethodVisitor.invokestatic(
|
||||||
invokeMethodVisitor.invokestatic(
|
MATH_KT_TYPE.internalName,
|
||||||
MATH_KT_TYPE.internalName,
|
name,
|
||||||
name,
|
getMethodDescriptor(tTypePrimitive, tTypePrimitive),
|
||||||
getMethodDescriptor(tTypePrimitive, tTypePrimitive),
|
false,
|
||||||
false,
|
)
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun visitUnary(mst: MST.Unary) {
|
override fun visitUnary(node: MST.Unary) {
|
||||||
super.visitUnary(mst)
|
super.visitUnary(node)
|
||||||
|
|
||||||
when (mst.operation) {
|
when (node.operation) {
|
||||||
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DNEG)
|
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(DNEG)
|
||||||
GroupOps.PLUS_OPERATION -> Unit
|
GroupOps.PLUS_OPERATION -> Unit
|
||||||
PowerOperations.SQRT_OPERATION -> buildUnaryJavaMathCall("sqrt")
|
PowerOperations.SQRT_OPERATION -> buildUnaryJavaMathCall("sqrt")
|
||||||
TrigonometricOperations.SIN_OPERATION -> buildUnaryJavaMathCall("sin")
|
TrigonometricOperations.SIN_OPERATION -> buildUnaryJavaMathCall("sin")
|
||||||
@ -343,74 +455,86 @@ internal class DoubleAsmBuilder(target: MST) :
|
|||||||
ExponentialOperations.ATANH_OPERATION -> buildUnaryKotlinMathCall("atanh")
|
ExponentialOperations.ATANH_OPERATION -> buildUnaryKotlinMathCall("atanh")
|
||||||
ExponentialOperations.EXP_OPERATION -> buildUnaryJavaMathCall("exp")
|
ExponentialOperations.EXP_OPERATION -> buildUnaryJavaMathCall("exp")
|
||||||
ExponentialOperations.LN_OPERATION -> buildUnaryJavaMathCall("log")
|
ExponentialOperations.LN_OPERATION -> buildUnaryJavaMathCall("log")
|
||||||
else -> super.visitUnary(mst)
|
else -> super.visitUnary(node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: MST.Binary) {
|
override fun visitBinary(node: MST.Binary) {
|
||||||
super.visitBinary(mst)
|
super.visitBinary(node)
|
||||||
|
|
||||||
when (mst.operation) {
|
when (node.operation) {
|
||||||
GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DADD)
|
GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(DADD)
|
||||||
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DSUB)
|
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(DSUB)
|
||||||
RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DMUL)
|
RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(DMUL)
|
||||||
FieldOps.DIV_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DDIV)
|
FieldOps.DIV_OPERATION -> invokeMethodVisitor.visitInsn(DDIV)
|
||||||
PowerOperations.POW_OPERATION -> buildBinaryJavaMathCall("pow")
|
PowerOperations.POW_OPERATION -> buildBinaryJavaMathCall("pow")
|
||||||
else -> super.visitBinary(mst)
|
else -> super.visitBinary(node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
private companion object {
|
||||||
val MATH_TYPE: Type by lazy { getObjectType("java/lang/Math") }
|
val MATH_TYPE: Type = getObjectType("java/lang/Math")
|
||||||
val MATH_KT_TYPE: Type by lazy { getObjectType("kotlin/math/MathKt") }
|
val MATH_KT_TYPE: Type = getObjectType("kotlin/math/MathKt")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
internal class IntAsmBuilder(target: MST) :
|
internal class IntAsmBuilder(target: MST) :
|
||||||
PrimitiveAsmBuilder<Int>(IntRing, Integer::class.java, Integer.TYPE, target) {
|
PrimitiveAsmBuilder<Int, IntExpression>(
|
||||||
override fun visitUnary(mst: MST.Unary) {
|
IntRing,
|
||||||
super.visitUnary(mst)
|
Integer::class.java,
|
||||||
|
Integer.TYPE,
|
||||||
|
IntExpression::class.java,
|
||||||
|
target
|
||||||
|
) {
|
||||||
|
override fun visitUnary(node: MST.Unary) {
|
||||||
|
super.visitUnary(node)
|
||||||
|
|
||||||
when (mst.operation) {
|
when (node.operation) {
|
||||||
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.INEG)
|
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(INEG)
|
||||||
GroupOps.PLUS_OPERATION -> Unit
|
GroupOps.PLUS_OPERATION -> Unit
|
||||||
else -> super.visitUnary(mst)
|
else -> super.visitUnary(node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: MST.Binary) {
|
override fun visitBinary(node: MST.Binary) {
|
||||||
super.visitBinary(mst)
|
super.visitBinary(node)
|
||||||
|
|
||||||
when (mst.operation) {
|
when (node.operation) {
|
||||||
GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.IADD)
|
GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(IADD)
|
||||||
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.ISUB)
|
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(ISUB)
|
||||||
RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.IMUL)
|
RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(IMUL)
|
||||||
else -> super.visitBinary(mst)
|
else -> super.visitBinary(node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class LongAsmBuilder(target: MST) :
|
@UnstableKMathAPI
|
||||||
PrimitiveAsmBuilder<Long>(LongRing, java.lang.Long::class.java, java.lang.Long.TYPE, target) {
|
internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder<Long, LongExpression>(
|
||||||
override fun visitUnary(mst: MST.Unary) {
|
LongRing,
|
||||||
super.visitUnary(mst)
|
java.lang.Long::class.java,
|
||||||
|
java.lang.Long.TYPE,
|
||||||
|
LongExpression::class.java,
|
||||||
|
target,
|
||||||
|
) {
|
||||||
|
override fun visitUnary(node: MST.Unary) {
|
||||||
|
super.visitUnary(node)
|
||||||
|
|
||||||
when (mst.operation) {
|
when (node.operation) {
|
||||||
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LNEG)
|
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(LNEG)
|
||||||
GroupOps.PLUS_OPERATION -> Unit
|
GroupOps.PLUS_OPERATION -> Unit
|
||||||
else -> super.visitUnary(mst)
|
else -> super.visitUnary(node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: MST.Binary) {
|
override fun visitBinary(node: MST.Binary) {
|
||||||
super.visitBinary(mst)
|
super.visitBinary(node)
|
||||||
|
|
||||||
when (mst.operation) {
|
when (node.operation) {
|
||||||
GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LADD)
|
GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(LADD)
|
||||||
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LSUB)
|
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(LSUB)
|
||||||
RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LMUL)
|
RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(LMUL)
|
||||||
else -> super.visitBinary(mst)
|
else -> super.visitBinary(node)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,7 +8,10 @@ plugins {
|
|||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
filter { it.name.contains("test", true) }
|
filter { it.name.contains("test", true) }
|
||||||
.map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings)
|
.map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings)
|
||||||
.forEach { it.optIn("space.kscience.kmath.misc.PerformancePitfall") }
|
.forEach {
|
||||||
|
it.optIn("space.kscience.kmath.misc.PerformancePitfall")
|
||||||
|
it.optIn("space.kscience.kmath.misc.UnstableKMathAPI")
|
||||||
|
}
|
||||||
|
|
||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.expressions
|
package space.kscience.kmath.expressions
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.Algebra
|
||||||
import kotlin.jvm.JvmName
|
import kotlin.jvm.JvmName
|
||||||
import kotlin.properties.ReadOnlyProperty
|
import kotlin.properties.ReadOnlyProperty
|
||||||
@ -24,6 +25,81 @@ public fun interface Expression<T> {
|
|||||||
public operator fun invoke(arguments: Map<Symbol, T>): T
|
public operator fun invoke(arguments: Map<Symbol, T>): T
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specialization of [Expression] for [Double] allowing better performance because of using array.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public interface DoubleExpression : Expression<Double> {
|
||||||
|
/**
|
||||||
|
* The indexer of this expression's arguments that should be used to build array for [invoke].
|
||||||
|
*
|
||||||
|
* Implementations must fulfil the following requirement: for any argument symbol `x` and its value `y`,
|
||||||
|
* `indexer.indexOf(x) == arguments.indexOf(y)` if `arguments` is the array passed to [invoke].
|
||||||
|
*/
|
||||||
|
public val indexer: SymbolIndexer
|
||||||
|
|
||||||
|
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||||
|
this(DoubleArray(indexer.symbols.size) { arguments.getValue(indexer.symbols[it]) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls this expression from arguments.
|
||||||
|
*
|
||||||
|
* @param arguments the array of arguments.
|
||||||
|
* @return the value.
|
||||||
|
*/
|
||||||
|
public operator fun invoke(arguments: DoubleArray): Double
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specialization of [Expression] for [Int] allowing better performance because of using array.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public interface IntExpression : Expression<Int> {
|
||||||
|
/**
|
||||||
|
* The indexer of this expression's arguments that should be used to build array for [invoke].
|
||||||
|
*
|
||||||
|
* Implementations must fulfil the following requirement: for any argument symbol `x` and its value `y`,
|
||||||
|
* `indexer.indexOf(x) == arguments.indexOf(y)` if `arguments` is the array passed to [invoke].
|
||||||
|
*/
|
||||||
|
public val indexer: SymbolIndexer
|
||||||
|
|
||||||
|
public override operator fun invoke(arguments: Map<Symbol, Int>): Int =
|
||||||
|
this(IntArray(indexer.symbols.size) { arguments.getValue(indexer.symbols[it]) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls this expression from arguments.
|
||||||
|
*
|
||||||
|
* @param arguments the array of arguments.
|
||||||
|
* @return the value.
|
||||||
|
*/
|
||||||
|
public operator fun invoke(arguments: IntArray): Int
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specialization of [Expression] for [Long] allowing better performance because of using array.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public interface LongExpression : Expression<Long> {
|
||||||
|
/**
|
||||||
|
* The indexer of this expression's arguments that should be used to build array for [invoke].
|
||||||
|
*
|
||||||
|
* Implementations must fulfil the following requirement: for any argument symbol `x` and its value `y`,
|
||||||
|
* `indexer.indexOf(x) == arguments.indexOf(y)` if `arguments` is the array passed to [invoke].
|
||||||
|
*/
|
||||||
|
public val indexer: SymbolIndexer
|
||||||
|
|
||||||
|
public override operator fun invoke(arguments: Map<Symbol, Long>): Long =
|
||||||
|
this(LongArray(indexer.symbols.size) { arguments.getValue(indexer.symbols[it]) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls this expression from arguments.
|
||||||
|
*
|
||||||
|
* @param arguments the array of arguments.
|
||||||
|
* @return the value.
|
||||||
|
*/
|
||||||
|
public operator fun invoke(arguments: LongArray): Long
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls this expression without providing any arguments.
|
* Calls this expression without providing any arguments.
|
||||||
*
|
*
|
||||||
@ -69,6 +145,62 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
private val EMPTY_DOUBLE_ARRAY = DoubleArray(0)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls this expression without providing any arguments.
|
||||||
|
*
|
||||||
|
* @return a value.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public operator fun DoubleExpression.invoke(): Double = this(EMPTY_DOUBLE_ARRAY)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls this expression from arguments.
|
||||||
|
*
|
||||||
|
* @param pairs the pairs of arguments to values.
|
||||||
|
* @return a value.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public operator fun DoubleExpression.invoke(vararg arguments: Double): Double = this(arguments)
|
||||||
|
|
||||||
|
private val EMPTY_INT_ARRAY = IntArray(0)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls this expression without providing any arguments.
|
||||||
|
*
|
||||||
|
* @return a value.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public operator fun IntExpression.invoke(): Int = this(EMPTY_INT_ARRAY)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls this expression from arguments.
|
||||||
|
*
|
||||||
|
* @param pairs the pairs of arguments to values.
|
||||||
|
* @return a value.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public operator fun IntExpression.invoke(vararg arguments: Int): Int = this(arguments)
|
||||||
|
|
||||||
|
private val EMPTY_LONG_ARRAY = LongArray(0)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls this expression without providing any arguments.
|
||||||
|
*
|
||||||
|
* @return a value.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public operator fun LongExpression.invoke(): Long = this(EMPTY_LONG_ARRAY)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calls this expression from arguments.
|
||||||
|
*
|
||||||
|
* @param pairs the pairs of arguments to values.
|
||||||
|
* @return a value.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public operator fun LongExpression.invoke(vararg arguments: Long): Long = this(arguments)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A context for expression construction
|
* A context for expression construction
|
||||||
|
@ -3,6 +3,11 @@ plugins {
|
|||||||
id("ru.mipt.npm.gradle.common")
|
id("ru.mipt.npm.gradle.common")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kotlin.sourceSets
|
||||||
|
.filter { it.name.contains("test", true) }
|
||||||
|
.map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings)
|
||||||
|
.forEach { it.optIn("space.kscience.kmath.misc.UnstableKMathAPI") }
|
||||||
|
|
||||||
description = "Kotlin∇ integration module"
|
description = "Kotlin∇ integration module"
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
|
@ -145,7 +145,7 @@ internal object InternalGamma {
|
|||||||
}
|
}
|
||||||
|
|
||||||
when {
|
when {
|
||||||
n >= maxIterations -> throw error("Maximal iterations is exceeded $maxIterations")
|
n >= maxIterations -> error("Maximal iterations is exceeded $maxIterations")
|
||||||
sum.isInfinite() -> 1.0
|
sum.isInfinite() -> 1.0
|
||||||
else -> exp(-x + a * ln(x) - logGamma(a)) * sum
|
else -> exp(-x + a * ln(x) - logGamma(a)) * sum
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user