Completely specialized expression types for Int, Long, Double and compilation of MST to it

This commit is contained in:
Iaroslav Postovalov 2021-11-25 23:26:08 +07:00
parent c6a4721d64
commit 5ba7d74bd2
15 changed files with 615 additions and 301 deletions

View File

@ -38,6 +38,22 @@ internal class ExpressionsInterpretersBenchmark {
@Benchmark @Benchmark
fun asmGenericExpression(blackhole: Blackhole) = invokeAndSum(asmGeneric, blackhole) fun asmGenericExpression(blackhole: Blackhole) = invokeAndSum(asmGeneric, blackhole)
/**
* Benchmark case for [Expression] created with [compileToExpression].
*/
@Benchmark
fun asmPrimitiveExpressionArray(blackhole: Blackhole) {
val random = Random(0)
var sum = 0.0
val m = DoubleArray(1)
repeat(times) {
m[xIdx] = random.nextDouble()
sum += asmPrimitive(m)
}
blackhole.consume(sum)
}
/** /**
* Benchmark case for [Expression] created with [compileToExpression]. * Benchmark case for [Expression] created with [compileToExpression].
@ -82,7 +98,6 @@ internal class ExpressionsInterpretersBenchmark {
private companion object { private companion object {
private val x by symbol private val x by symbol
private val algebra = DoubleField
private const val times = 1_000_000 private const val times = 1_000_000
private val functional = DoubleField.expression { private val functional = DoubleField.expression {
@ -95,7 +110,10 @@ internal class ExpressionsInterpretersBenchmark {
} }
private val mst = node.toExpression(DoubleField) private val mst = node.toExpression(DoubleField)
private val asmPrimitive = node.compileToExpression(DoubleField) private val asmPrimitive = node.compileToExpression(DoubleField)
private val xIdx = asmPrimitive.indexer.indexOf(x)
private val asmGeneric = node.compileToExpression(DoubleField as Algebra<Double>) private val asmGeneric = node.compileToExpression(DoubleField as Algebra<Double>)
private val raw = Expression<Double> { args -> private val raw = Expression<Double> { args ->

View File

@ -7,7 +7,6 @@ package space.kscience.kmath.ast
import space.kscience.kmath.asm.compileToExpression import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.MstExtendedField import space.kscience.kmath.expressions.MstExtendedField
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
@ -17,10 +16,11 @@ fun main() {
x * 2.0 + number(2.0) / x - number(16.0) + asinh(x) / sin(x) x * 2.0 + number(2.0) / x - number(16.0) + asinh(x) / sin(x)
}.compileToExpression(DoubleField) }.compileToExpression(DoubleField)
val m = HashMap<Symbol, Double>() val m = DoubleArray(expr.indexer.symbols.size)
val xIdx = expr.indexer.indexOf(x)
repeat(10000000) { repeat(10000000) {
m[x] = 1.0 m[xIdx] = 1.0
expr(m) expr(m)
} }
} }

View File

@ -7,7 +7,9 @@ package space.kscience.kmath.ast
import space.kscience.kmath.expressions.MstRing import space.kscience.kmath.expressions.MstRing
import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.Symbol.Companion.y
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
@ -16,11 +18,23 @@ import kotlin.test.assertFailsWith
internal class TestCompilerVariables { internal class TestCompilerVariables {
@Test @Test
fun testVariable() = runCompilerTest { fun testNoVariables() = runCompilerTest {
val expr = "0".parseMath().compileToExpression(IntRing)
assertEquals(0, expr())
}
@Test
fun testOneVariable() = runCompilerTest {
val expr = MstRing { x }.compileToExpression(IntRing) val expr = MstRing { x }.compileToExpression(IntRing)
assertEquals(1, expr(x to 1)) assertEquals(1, expr(x to 1))
} }
@Test
fun testTwoVariables() = runCompilerTest {
val expr = "y+x/y+x".parseMath().compileToExpression(DoubleField)
assertEquals(8.0, expr(x to 4.0, y to 2.0))
}
@Test @Test
fun testUndefinedVariableFails() = runCompilerTest { fun testUndefinedVariableFails() = runCompilerTest {
val expr = MstRing { x }.compileToExpression(IntRing) val expr = MstRing { x }.compileToExpression(IntRing)

View File

@ -201,8 +201,8 @@ internal open external class Module {
} }
@JsName("Instance") @JsName("Instance")
internal open external class Instance(module: Module, importObject: Any = definedExternally) { internal open external class Instance(module: Module, importObject: dynamic = definedExternally) {
open var exports: Any open var exports: dynamic
} }
@JsName("Memory") @JsName("Memory")

View File

@ -5,12 +5,11 @@
package space.kscience.kmath.wasm.internal package space.kscience.kmath.wasm.internal
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.*
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.MST.* import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.internal.binaryen.* import space.kscience.kmath.internal.binaryen.*
import space.kscience.kmath.internal.webassembly.Instance import space.kscience.kmath.internal.webassembly.Instance
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.internal.binaryen.Module as BinaryenModule import space.kscience.kmath.internal.binaryen.Module as BinaryenModule
import space.kscience.kmath.internal.webassembly.Module as WasmModule import space.kscience.kmath.internal.webassembly.Module as WasmModule
@ -18,65 +17,17 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule
private val spreader = eval("(obj, args) => obj(...args)") private val spreader = eval("(obj, args) => obj(...args)")
@Suppress("UnsafeCastFromDynamic") @Suppress("UnsafeCastFromDynamic")
internal sealed class WasmBuilder<T : Number>( internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
protected val binaryenType: Type, protected val binaryenType: Type,
protected val algebra: Algebra<T>, protected val algebra: Algebra<T>,
protected val target: MST, protected val target: MST,
) { ) {
protected val keys: MutableList<Symbol> = mutableListOf() protected val keys: MutableList<Symbol> = mutableListOf()
lateinit var ctx: BinaryenModule protected lateinit var ctx: BinaryenModule
open fun visitSymbolic(mst: Symbol): ExpressionRef { abstract val instance: E
algebra.bindSymbolOrNull(mst)?.let { return visitNumeric(Numeric(it)) }
var idx = keys.indexOf(mst) protected val executable = run {
if (idx == -1) {
keys += mst
idx = keys.lastIndex
}
return ctx.local.get(idx, binaryenType)
}
abstract fun visitNumeric(mst: Numeric): ExpressionRef
protected open fun visitUnary(mst: Unary): ExpressionRef =
error("Unary operation ${mst.operation} not defined in $this")
protected open fun visitBinary(mst: Binary): ExpressionRef =
error("Binary operation ${mst.operation} not defined in $this")
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
protected fun visit(mst: MST): ExpressionRef = when (mst) {
is Symbol -> visitSymbolic(mst)
is Numeric -> visitNumeric(mst)
is Unary -> when {
algebra is NumericAlgebra && mst.value is Numeric -> visitNumeric(
Numeric(algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as Numeric).value)))
)
else -> visitUnary(mst)
}
is Binary -> when {
algebra is NumericAlgebra && mst.left is Numeric && mst.right is Numeric -> visitNumeric(
Numeric(
algebra.binaryOperationFunction(mst.operation)
.invoke(
algebra.number((mst.left as Numeric).value),
algebra.number((mst.right as Numeric).value)
)
)
)
else -> visitBinary(mst)
}
}
val instance by lazy {
val c = WasmModule(with(createModule()) { val c = WasmModule(with(createModule()) {
ctx = this ctx = this
val expr = visit(target) val expr = visit(target)
@ -97,41 +48,93 @@ internal sealed class WasmBuilder<T : Number>(
res res
}) })
val i = Instance(c, js("{}") as Any) Instance(c, js("{}")).exports.executable
val symbols = keys }
keys.clear()
Expression<T> { args -> protected open fun visitSymbol(node: Symbol): ExpressionRef {
val params = symbols.map(args::getValue).toTypedArray() algebra.bindSymbolOrNull(node)?.let { return visitNumeric(Numeric(it)) }
spreader(i.exports.asDynamic().executable, params) as T
var idx = keys.indexOf(node)
if (idx == -1) {
keys += node
idx = keys.lastIndex
}
return ctx.local.get(idx, binaryenType)
}
protected abstract fun visitNumeric(node: Numeric): ExpressionRef
protected open fun visitUnary(node: Unary): ExpressionRef =
error("Unary operation ${node.operation} not defined in $this")
protected open fun visitBinary(mst: Binary): ExpressionRef =
error("Binary operation ${mst.operation} not defined in $this")
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
protected fun visit(node: MST): ExpressionRef = when (node) {
is Symbol -> visitSymbol(node)
is Numeric -> visitNumeric(node)
is Unary -> when {
algebra is NumericAlgebra && node.value is Numeric -> visitNumeric(
Numeric(algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value)))
)
else -> visitUnary(node)
}
is Binary -> when {
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> visitNumeric(
Numeric(
algebra.binaryOperationFunction(node.operation)
.invoke(
algebra.number((node.left as Numeric).value),
algebra.number((node.right as Numeric).value)
)
)
)
else -> visitBinary(node)
} }
} }
} }
internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleField, target) { @UnstableKMathAPI
override fun createModule(): BinaryenModule = readBinary(f64StandardFunctions) internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double, DoubleExpression>(f64, DoubleField, target) {
override val instance by lazy {
object : DoubleExpression {
override val indexer = SimpleSymbolIndexer(keys)
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value) override fun invoke(arguments: DoubleArray) = spreader(executable, arguments).unsafeCast<Double>()
}
}
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) { override fun createModule() = readBinary(f64StandardFunctions)
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
GroupOps.PLUS_OPERATION -> visit(mst.value) override fun visitNumeric(node: Numeric) = ctx.f64.const(node.value)
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value))
TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64) override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) {
TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64) GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value))
TrigonometricOperations.TAN_OPERATION -> ctx.call("tan", arrayOf(visit(mst.value)), f64) GroupOps.PLUS_OPERATION -> visit(node.value)
TrigonometricOperations.ASIN_OPERATION -> ctx.call("asin", arrayOf(visit(mst.value)), f64) PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value))
TrigonometricOperations.ACOS_OPERATION -> ctx.call("acos", arrayOf(visit(mst.value)), f64) TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(node.value)), f64)
TrigonometricOperations.ATAN_OPERATION -> ctx.call("atan", arrayOf(visit(mst.value)), f64) TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(node.value)), f64)
ExponentialOperations.SINH_OPERATION -> ctx.call("sinh", arrayOf(visit(mst.value)), f64) TrigonometricOperations.TAN_OPERATION -> ctx.call("tan", arrayOf(visit(node.value)), f64)
ExponentialOperations.COSH_OPERATION -> ctx.call("cosh", arrayOf(visit(mst.value)), f64) TrigonometricOperations.ASIN_OPERATION -> ctx.call("asin", arrayOf(visit(node.value)), f64)
ExponentialOperations.TANH_OPERATION -> ctx.call("tanh", arrayOf(visit(mst.value)), f64) TrigonometricOperations.ACOS_OPERATION -> ctx.call("acos", arrayOf(visit(node.value)), f64)
ExponentialOperations.ASINH_OPERATION -> ctx.call("asinh", arrayOf(visit(mst.value)), f64) TrigonometricOperations.ATAN_OPERATION -> ctx.call("atan", arrayOf(visit(node.value)), f64)
ExponentialOperations.ACOSH_OPERATION -> ctx.call("acosh", arrayOf(visit(mst.value)), f64) ExponentialOperations.SINH_OPERATION -> ctx.call("sinh", arrayOf(visit(node.value)), f64)
ExponentialOperations.ATANH_OPERATION -> ctx.call("atanh", arrayOf(visit(mst.value)), f64) ExponentialOperations.COSH_OPERATION -> ctx.call("cosh", arrayOf(visit(node.value)), f64)
ExponentialOperations.EXP_OPERATION -> ctx.call("exp", arrayOf(visit(mst.value)), f64) ExponentialOperations.TANH_OPERATION -> ctx.call("tanh", arrayOf(visit(node.value)), f64)
ExponentialOperations.LN_OPERATION -> ctx.call("log", arrayOf(visit(mst.value)), f64) ExponentialOperations.ASINH_OPERATION -> ctx.call("asinh", arrayOf(visit(node.value)), f64)
else -> super.visitUnary(mst) ExponentialOperations.ACOSH_OPERATION -> ctx.call("acosh", arrayOf(visit(node.value)), f64)
ExponentialOperations.ATANH_OPERATION -> ctx.call("atanh", arrayOf(visit(node.value)), f64)
ExponentialOperations.EXP_OPERATION -> ctx.call("exp", arrayOf(visit(node.value)), f64)
ExponentialOperations.LN_OPERATION -> ctx.call("log", arrayOf(visit(node.value)), f64)
else -> super.visitUnary(node)
} }
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
@ -144,13 +147,22 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleF
} }
} }
internal class IntWasmBuilder(target: MST) : WasmBuilder<Int>(i32, IntRing, target) { @UnstableKMathAPI
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value) internal class IntWasmBuilder(target: MST) : WasmBuilder<Int, IntExpression>(i32, IntRing, target) {
override val instance by lazy {
object : IntExpression {
override val indexer = SimpleSymbolIndexer(keys)
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) { override fun invoke(arguments: IntArray) = spreader(executable, arguments).unsafeCast<Int>()
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value)) }
GroupOps.PLUS_OPERATION -> visit(mst.value) }
else -> super.visitUnary(mst)
override fun visitNumeric(node: Numeric) = ctx.i32.const(node.value)
override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) {
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value))
GroupOps.PLUS_OPERATION -> visit(node.value)
else -> super.visitUnary(node)
} }
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {

View File

@ -3,13 +3,12 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
@file:Suppress("UNUSED_PARAMETER")
package space.kscience.kmath.wasm package space.kscience.kmath.wasm
import space.kscience.kmath.estree.compileWith import space.kscience.kmath.estree.compileWith
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.*
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.IntRing
@ -22,7 +21,7 @@ import space.kscience.kmath.wasm.internal.IntWasmBuilder
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compileToExpression(algebra: IntRing): Expression<Int> = compileWith(algebra) public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBuilder(this).instance
/** /**
@ -50,7 +49,7 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = compileWith(algebra) public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = DoubleWasmBuilder(this).instance
/** /**

View File

@ -8,11 +8,8 @@
package space.kscience.kmath.asm package space.kscience.kmath.asm
import space.kscience.kmath.asm.internal.* import space.kscience.kmath.asm.internal.*
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.*
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.MST.* import space.kscience.kmath.expressions.MST.*
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
@ -48,7 +45,13 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
loadVariable(node.identity) loadVariable(node.identity)
} }
is Numeric -> loadNumberConstant(node.value) is Numeric -> if (algebra is NumericAlgebra) {
if (Number::class.java.isAssignableFrom(type))
loadNumberConstant(algebra.number(node.value) as Number)
else
loadObjectConstant(algebra.number(node.value))
} else
error("Numeric nodes are not supported by $this")
is Unary -> when { is Unary -> when {
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant( algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
@ -121,13 +124,15 @@ public inline fun <reified T : Any> MST.compile(algebra: Algebra<T>, vararg argu
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
public fun MST.compileToExpression(algebra: IntRing): Expression<Int> = IntAsmBuilder(this).instance @UnstableKMathAPI
public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuilder(this).instance
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI
public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra).invoke(arguments)
@ -136,6 +141,7 @@ public fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int =
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI
public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): Int = public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): Int =
compileToExpression(algebra)(*arguments) compileToExpression(algebra)(*arguments)
@ -145,7 +151,8 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): I
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
public fun MST.compileToExpression(algebra: LongRing): Expression<Long> = LongAsmBuilder(this).instance @UnstableKMathAPI
public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmBuilder(this).instance
/** /**
@ -153,6 +160,7 @@ public fun MST.compileToExpression(algebra: LongRing): Expression<Long> = LongAs
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI
public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long = public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra).invoke(arguments)
@ -162,6 +170,7 @@ public fun MST.compile(algebra: LongRing, arguments: Map<Symbol, Long>): Long =
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI
public fun MST.compile(algebra: LongRing, vararg arguments: Pair<Symbol, Long>): Long = public fun MST.compile(algebra: LongRing, vararg arguments: Pair<Symbol, Long>): Long =
compileToExpression(algebra)(*arguments) compileToExpression(algebra)(*arguments)
@ -171,13 +180,15 @@ public fun MST.compile(algebra: LongRing, vararg arguments: Pair<Symbol, Long>):
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
public fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = DoubleAsmBuilder(this).instance @UnstableKMathAPI
public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = DoubleAsmBuilder(this).instance
/** /**
* Compile given MST to expression and evaluate it against [arguments]. * Compile given MST to expression and evaluate it against [arguments].
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double = public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
compileToExpression(algebra).invoke(arguments) compileToExpression(algebra).invoke(arguments)
@ -186,5 +197,6 @@ public fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Do
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
@UnstableKMathAPI
public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double = public fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
compileToExpression(algebra).invoke(*arguments) compileToExpression(algebra).invoke(*arguments)

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.asm.internal package space.kscience.kmath.asm.internal
import org.objectweb.asm.Type import org.objectweb.asm.Type
import org.objectweb.asm.Type.getObjectType
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Expression
internal abstract class AsmBuilder { internal abstract class AsmBuilder {
@ -22,31 +23,31 @@ internal abstract class AsmBuilder {
/** /**
* ASM type for [Expression]. * ASM type for [Expression].
*/ */
val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/expressions/Expression") } val EXPRESSION_TYPE: Type = getObjectType("space/kscience/kmath/expressions/Expression")
/** /**
* ASM type for [java.util.Map]. * ASM type for [java.util.Map].
*/ */
val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") } val MAP_TYPE: Type = getObjectType("java/util/Map")
/** /**
* ASM type for [java.lang.Object]. * ASM type for [java.lang.Object].
*/ */
val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") } val OBJECT_TYPE: Type = getObjectType("java/lang/Object")
/** /**
* ASM type for [java.lang.String]. * ASM type for [java.lang.String].
*/ */
val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") } val STRING_TYPE: Type = getObjectType("java/lang/String")
/** /**
* ASM type for MapIntrinsics. * ASM type for MapIntrinsics.
*/ */
val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/asm/internal/MapIntrinsics") } val MAP_INTRINSICS_TYPE: Type = getObjectType("space/kscience/kmath/asm/internal/MapIntrinsics")
/** /**
* ASM Type for [space.kscience.kmath.expressions.Symbol]. * ASM Type for [space.kscience.kmath.expressions.Symbol].
*/ */
val SYMBOL_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/expressions/Symbol") } val SYMBOL_TYPE: Type = getObjectType("space/kscience/kmath/expressions/Symbol")
} }
} }

View File

@ -1,6 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.asm.internal

View File

@ -78,7 +78,7 @@ internal class GenericAsmBuilder<T>(
) )
visitMethod( visitMethod(
ACC_PUBLIC or ACC_FINAL, ACC_PUBLIC,
"invoke", "invoke",
getMethodDescriptor(tType, MAP_TYPE), getMethodDescriptor(tType, MAP_TYPE),
"(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", "(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
@ -116,7 +116,7 @@ internal class GenericAsmBuilder<T>(
} }
visitMethod( visitMethod(
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, ACC_PUBLIC or ACC_BRIDGE or ACC_SYNTHETIC,
"invoke", "invoke",
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
null, null,
@ -156,7 +156,7 @@ internal class GenericAsmBuilder<T>(
) )
visitMethod( visitMethod(
ACC_PUBLIC, ACC_PUBLIC or ACC_SYNTHETIC,
"<init>", "<init>",
getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }),
null, null,
@ -176,7 +176,7 @@ internal class GenericAsmBuilder<T>(
} }
label() label()
visitInsn(RETURN) areturn(VOID_TYPE)
val l4 = label() val l4 = label()
visitLocalVariable("this", classType.descriptor, null, l0, l4, 0) visitLocalVariable("this", classType.descriptor, null, l0, l4, 0)
@ -209,10 +209,10 @@ internal class GenericAsmBuilder<T>(
*/ */
fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run { fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run {
val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex
invokeMethodVisitor.load(0, classType) load(0, classType)
getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor)
iconst(idx) iconst(idx)
visitInsn(AALOAD) aload(OBJECT_TYPE)
if (type != OBJECT_TYPE) checkcast(type) if (type != OBJECT_TYPE) checkcast(type)
} }
@ -320,6 +320,6 @@ internal class GenericAsmBuilder<T>(
/** /**
* ASM type for array of [java.lang.Object]. * ASM type for array of [java.lang.Object].
*/ */
val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") } val OBJECT_ARRAY_TYPE: Type = getType("[Ljava/lang/Object;")
} }
} }

View File

@ -6,38 +6,49 @@
package space.kscience.kmath.asm.internal package space.kscience.kmath.asm.internal
import org.objectweb.asm.ClassWriter import org.objectweb.asm.ClassWriter
import org.objectweb.asm.Opcodes import org.objectweb.asm.FieldVisitor
import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type import org.objectweb.asm.Type
import org.objectweb.asm.Type.* import org.objectweb.asm.Type.*
import org.objectweb.asm.commons.InstructionAdapter import org.objectweb.asm.commons.InstructionAdapter
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.*
import space.kscience.kmath.expressions.MST import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import java.lang.invoke.MethodHandles import java.lang.invoke.MethodHandles
import java.lang.invoke.MethodType import java.lang.invoke.MethodType
import java.nio.file.Paths import java.nio.file.Paths
import kotlin.io.path.writeBytes import kotlin.io.path.writeBytes
internal sealed class PrimitiveAsmBuilder<T : Number>( @UnstableKMathAPI
protected val algebra: Algebra<T>, internal sealed class PrimitiveAsmBuilder<T : Number, out E : Expression<T>>(
protected val algebra: NumericAlgebra<T>,
classOfT: Class<*>, classOfT: Class<*>,
protected val classOfTPrimitive: Class<*>, protected val classOfTPrimitive: Class<*>,
expressionParent: Class<E>,
protected val target: MST, protected val target: MST,
) : AsmBuilder() { ) : AsmBuilder() {
private val className: String = buildName(target) private val className: String = buildName(target)
/** /**
* ASM type for [T]. * ASM type for [tType].
*/ */
private val tType: Type = classOfT.asm private val tType: Type = classOfT.asm
/** /**
* ASM type for [T]. * ASM type for [classOfTPrimitive].
*/ */
protected val tTypePrimitive: Type = classOfTPrimitive.asm protected val tTypePrimitive: Type = classOfTPrimitive.asm
/**
* ASM type for array of [classOfTPrimitive].
*/
protected val tTypePrimitiveArray: Type = getType("[" + classOfTPrimitive.asm.descriptor)
/**
* ASM type for expression parent.
*/
private val expressionParentType = expressionParent.asm
/** /**
* ASM type for new class. * ASM type for new class.
*/ */
@ -49,58 +60,91 @@ internal sealed class PrimitiveAsmBuilder<T : Number>(
protected lateinit var invokeMethodVisitor: InstructionAdapter protected lateinit var invokeMethodVisitor: InstructionAdapter
/** /**
* Local variables indices are indices of symbols in this list. * Indexer for arguments in [target].
*/ */
private val argumentsLocals = mutableListOf<String>() private val argumentsIndexer = mutableListOf<Symbol>()
/** /**
* Subclasses, loads and instantiates [Expression] for given parameters. * Subclasses, loads and instantiates [Expression] for given parameters.
* *
* The built instance is cached. * The built instance is cached.
*/ */
@Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE") @Suppress("UNCHECKED_CAST")
val instance: Expression<T> by lazy { val instance: E by lazy {
val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) {
visit( visit(
Opcodes.V1_8, V1_8,
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, ACC_PUBLIC or ACC_FINAL or ACC_SUPER,
classType.internalName, classType.internalName,
"${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", "${OBJECT_TYPE.descriptor}${expressionParentType.descriptor}",
OBJECT_TYPE.internalName, OBJECT_TYPE.internalName,
arrayOf(EXPRESSION_TYPE.internalName), arrayOf(expressionParentType.internalName),
) )
visitField(
access = ACC_PRIVATE or ACC_FINAL,
name = "indexer",
descriptor = SYMBOL_INDEXER_TYPE.descriptor,
signature = null,
value = null,
block = FieldVisitor::visitEnd,
)
visitMethod( visitMethod(
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, ACC_PUBLIC,
"invoke", "getIndexer",
getMethodDescriptor(tType, MAP_TYPE), getMethodDescriptor(SYMBOL_INDEXER_TYPE),
"(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", null,
null, null,
).instructionAdapter { ).instructionAdapter {
invokeMethodVisitor = this
visitCode() visitCode()
val preparingVariables = label() val start = label()
visitVariables(target) load(0, classType)
val expressionResult = label() getfield(classType.internalName, "indexer", SYMBOL_INDEXER_TYPE.descriptor)
visitExpression(target) areturn(SYMBOL_INDEXER_TYPE)
box()
areturn(tType)
val end = label() val end = label()
visitLocalVariable( visitLocalVariable(
"this", "this",
classType.descriptor, classType.descriptor,
null, null,
preparingVariables, start,
end,
0,
)
visitMaxs(0, 0)
visitEnd()
}
visitMethod(
ACC_PUBLIC,
"invoke",
getMethodDescriptor(tTypePrimitive, tTypePrimitiveArray),
null,
null,
).instructionAdapter {
invokeMethodVisitor = this
visitCode()
val start = label()
visitVariables(target, arrayMode = true)
visitExpression(target)
areturn(tTypePrimitive)
val end = label()
visitLocalVariable(
"this",
classType.descriptor,
null,
start,
end, end,
0, 0,
) )
visitLocalVariable( visitLocalVariable(
"arguments", "arguments",
MAP_TYPE.descriptor, tTypePrimitiveArray.descriptor,
"L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;", null,
preparingVariables, start,
end, end,
1, 1,
) )
@ -110,7 +154,45 @@ internal sealed class PrimitiveAsmBuilder<T : Number>(
} }
visitMethod( visitMethod(
Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, ACC_PUBLIC or ACC_FINAL,
"invoke",
getMethodDescriptor(tType, MAP_TYPE),
"(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}",
null,
).instructionAdapter {
invokeMethodVisitor = this
visitCode()
val start = label()
visitVariables(target, arrayMode = false)
visitExpression(target)
box()
areturn(tType)
val end = label()
visitLocalVariable(
"this",
classType.descriptor,
null,
start,
end,
0,
)
visitLocalVariable(
"arguments",
MAP_TYPE.descriptor,
"L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;",
start,
end,
1,
)
visitMaxs(0, 0)
visitEnd()
}
visitMethod(
ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC,
"invoke", "invoke",
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), getMethodDescriptor(OBJECT_TYPE, MAP_TYPE),
null, null,
@ -138,21 +220,22 @@ internal sealed class PrimitiveAsmBuilder<T : Number>(
} }
visitMethod( visitMethod(
Opcodes.ACC_PUBLIC, ACC_PUBLIC or ACC_SYNTHETIC,
"<init>", "<init>",
getMethodDescriptor(VOID_TYPE), getMethodDescriptor(VOID_TYPE, SYMBOL_INDEXER_TYPE),
null, null,
null, null,
).instructionAdapter { ).instructionAdapter {
val start = label() val start = label()
load(0, classType) load(0, classType)
invokespecial(OBJECT_TYPE.internalName, "<init>", getMethodDescriptor(VOID_TYPE), false) invokespecial(OBJECT_TYPE.internalName, "<init>", getMethodDescriptor(VOID_TYPE), false)
label()
load(0, classType) load(0, classType)
label() load(1, SYMBOL_INDEXER_TYPE)
visitInsn(Opcodes.RETURN) putfield(classType.internalName, "indexer", SYMBOL_INDEXER_TYPE.descriptor)
areturn(VOID_TYPE)
val end = label() val end = label()
visitLocalVariable("this", classType.descriptor, null, start, end, 0) visitLocalVariable("this", classType.descriptor, null, start, end, 0)
visitLocalVariable("indexer", SYMBOL_INDEXER_TYPE.descriptor, null, start, end, 1)
visitMaxs(0, 0) visitMaxs(0, 0)
visitEnd() visitEnd()
} }
@ -166,14 +249,16 @@ internal sealed class PrimitiveAsmBuilder<T : Number>(
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1") if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
Paths.get("${className.split('.').last()}.class").writeBytes(binary) Paths.get("${className.split('.').last()}.class").writeBytes(binary)
MethodHandles.publicLookup().findConstructor(cls, MethodType.methodType(Void.TYPE))() as Expression<T> MethodHandles
.publicLookup()
.findConstructor(cls, MethodType.methodType(Void.TYPE, SymbolIndexer::class.java))
.invoke(SimpleSymbolIndexer(argumentsIndexer)) as E
} }
/** /**
* Either loads a numeric constant [value] from the class's constants field or boxes a primitive * Loads a numeric constant [value] from the class's constants.
* constant from the constant pool.
*/ */
fun loadNumberConstant(value: Number) { protected fun loadNumberConstant(value: Number) {
when (tTypePrimitive) { when (tTypePrimitive) {
BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt())
DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble())
@ -185,38 +270,50 @@ internal sealed class PrimitiveAsmBuilder<T : Number>(
} }
/** /**
* Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using * Stores value variable [name] into a local. Should be called before using [loadVariable]. Should be called only
* [loadVariable]. * once for a variable.
*/ */
fun prepareVariable(name: String): Unit = invokeMethodVisitor.run { protected fun prepareVariable(name: Symbol, arrayMode: Boolean): Unit = invokeMethodVisitor.run {
if (name in argumentsLocals) return@run var argumentIndex = argumentsIndexer.indexOf(name)
load(1, MAP_TYPE)
aconst(name)
invokestatic( if (argumentIndex == -1) {
MAP_INTRINSICS_TYPE.internalName, argumentsIndexer += name
"getOrFail", argumentIndex = argumentsIndexer.lastIndex
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
false,
)
checkcast(tType)
var idx = argumentsLocals.indexOf(name)
if (idx == -1) {
argumentsLocals += name
idx = argumentsLocals.lastIndex
} }
unbox() val localIndex = 2 + argumentIndex * tTypePrimitive.size
store(2 + idx, tTypePrimitive)
if (arrayMode) {
load(1, tTypePrimitiveArray)
iconst(argumentIndex)
aload(tTypePrimitive)
store(localIndex, tTypePrimitive)
} else {
load(1, MAP_TYPE)
aconst(name.identity)
invokestatic(
MAP_INTRINSICS_TYPE.internalName,
"getOrFail",
getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE),
false,
)
checkcast(tType)
unbox()
store(localIndex, tTypePrimitive)
}
} }
/** /**
* Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored
* with [prepareVariable] first. * with [prepareVariable] first.
*/ */
fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tTypePrimitive) protected fun loadVariable(name: Symbol) {
val argumentIndex = argumentsIndexer.indexOf(name)
val localIndex = 2 + argumentIndex * tTypePrimitive.size
invokeMethodVisitor.load(localIndex, tTypePrimitive)
}
private fun unbox() = invokeMethodVisitor.run { private fun unbox() = invokeMethodVisitor.run {
invokevirtual( invokevirtual(
@ -231,102 +328,117 @@ internal sealed class PrimitiveAsmBuilder<T : Number>(
invokestatic(tType.internalName, "valueOf", getMethodDescriptor(tType, tTypePrimitive), false) invokestatic(tType.internalName, "valueOf", getMethodDescriptor(tType, tTypePrimitive), false)
} }
protected fun visitVariables(node: MST): Unit = when (node) { private fun visitVariables(
is Symbol -> prepareVariable(node.identity) node: MST,
is MST.Unary -> visitVariables(node.value) arrayMode: Boolean,
alreadyLoaded: MutableList<Symbol> = mutableListOf()
): Unit = when (node) {
is Symbol -> when (node) {
!in alreadyLoaded -> {
alreadyLoaded += node
prepareVariable(node, arrayMode)
}
else -> {
}
}
is MST.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded)
is MST.Binary -> { is MST.Binary -> {
visitVariables(node.left) visitVariables(node.left, arrayMode, alreadyLoaded)
visitVariables(node.right) visitVariables(node.right, arrayMode, alreadyLoaded)
} }
else -> Unit else -> Unit
} }
protected fun visitExpression(mst: MST): Unit = when (mst) { private fun visitExpression(node: MST): Unit = when (node) {
is Symbol -> loadVariable(mst.identity) is Symbol -> {
is MST.Numeric -> loadNumberConstant(mst.value) val symbol = algebra.bindSymbolOrNull(node)
is MST.Unary -> when { if (symbol != null)
algebra is NumericAlgebra && mst.value is MST.Numeric -> { loadNumberConstant(symbol)
loadNumberConstant( else
MST.Numeric( loadVariable(node)
algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as MST.Numeric).value)),
).value,
)
}
else -> visitUnary(mst)
} }
is MST.Numeric -> loadNumberConstant(algebra.number(node.value))
is MST.Unary -> if (node.value is MST.Numeric)
loadNumberConstant(
algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as MST.Numeric).value)),
)
else
visitUnary(node)
is MST.Binary -> when { is MST.Binary -> when {
algebra is NumericAlgebra && mst.left is MST.Numeric && mst.right is MST.Numeric -> { node.left is MST.Numeric && node.right is MST.Numeric -> loadNumberConstant(
loadNumberConstant( algebra.binaryOperationFunction(node.operation)(
MST.Numeric( algebra.number((node.left as MST.Numeric).value),
algebra.binaryOperationFunction(mst.operation)( algebra.number((node.right as MST.Numeric).value),
algebra.number((mst.left as MST.Numeric).value), ),
algebra.number((mst.right as MST.Numeric).value), )
),
).value,
)
}
else -> visitBinary(mst) else -> visitBinary(node)
} }
} }
protected open fun visitUnary(mst: MST.Unary) { protected open fun visitUnary(node: MST.Unary) = visitExpression(node.value)
visitExpression(mst.value)
}
protected open fun visitBinary(mst: MST.Binary) { protected open fun visitBinary(node: MST.Binary) {
visitExpression(mst.left) visitExpression(node.left)
visitExpression(mst.right) visitExpression(node.right)
} }
protected companion object { protected companion object {
/** /**
* ASM type for [java.lang.Number]. * ASM type for [java.lang.Number].
*/ */
val NUMBER_TYPE: Type by lazy { getObjectType("java/lang/Number") } val NUMBER_TYPE: Type = getObjectType("java/lang/Number")
/**
* ASM type for [SymbolIndexer].
*/
val SYMBOL_INDEXER_TYPE: Type = getObjectType("space/kscience/kmath/expressions/SymbolIndexer")
} }
} }
internal class DoubleAsmBuilder(target: MST) : @UnstableKMathAPI
PrimitiveAsmBuilder<Double>(DoubleField, java.lang.Double::class.java, java.lang.Double.TYPE, target) { internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder<Double, DoubleExpression>(
DoubleField,
java.lang.Double::class.java,
java.lang.Double.TYPE,
DoubleExpression::class.java,
target,
) {
private fun buildUnaryJavaMathCall(name: String) { private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic(
invokeMethodVisitor.invokestatic( MATH_TYPE.internalName,
MATH_TYPE.internalName, name,
name, getMethodDescriptor(tTypePrimitive, tTypePrimitive),
getMethodDescriptor(tTypePrimitive, tTypePrimitive), false,
false, )
)
}
private fun buildBinaryJavaMathCall(name: String) { @Suppress("SameParameterValue")
invokeMethodVisitor.invokestatic( private fun buildBinaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic(
MATH_TYPE.internalName, MATH_TYPE.internalName,
name, name,
getMethodDescriptor(tTypePrimitive, tTypePrimitive, tTypePrimitive), getMethodDescriptor(tTypePrimitive, tTypePrimitive, tTypePrimitive),
false, false,
) )
}
private fun buildUnaryKotlinMathCall(name: String) { private fun buildUnaryKotlinMathCall(name: String) = invokeMethodVisitor.invokestatic(
invokeMethodVisitor.invokestatic( MATH_KT_TYPE.internalName,
MATH_KT_TYPE.internalName, name,
name, getMethodDescriptor(tTypePrimitive, tTypePrimitive),
getMethodDescriptor(tTypePrimitive, tTypePrimitive), false,
false, )
)
}
override fun visitUnary(mst: MST.Unary) { override fun visitUnary(node: MST.Unary) {
super.visitUnary(mst) super.visitUnary(node)
when (mst.operation) { when (node.operation) {
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DNEG) GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(DNEG)
GroupOps.PLUS_OPERATION -> Unit GroupOps.PLUS_OPERATION -> Unit
PowerOperations.SQRT_OPERATION -> buildUnaryJavaMathCall("sqrt") PowerOperations.SQRT_OPERATION -> buildUnaryJavaMathCall("sqrt")
TrigonometricOperations.SIN_OPERATION -> buildUnaryJavaMathCall("sin") TrigonometricOperations.SIN_OPERATION -> buildUnaryJavaMathCall("sin")
@ -343,74 +455,86 @@ internal class DoubleAsmBuilder(target: MST) :
ExponentialOperations.ATANH_OPERATION -> buildUnaryKotlinMathCall("atanh") ExponentialOperations.ATANH_OPERATION -> buildUnaryKotlinMathCall("atanh")
ExponentialOperations.EXP_OPERATION -> buildUnaryJavaMathCall("exp") ExponentialOperations.EXP_OPERATION -> buildUnaryJavaMathCall("exp")
ExponentialOperations.LN_OPERATION -> buildUnaryJavaMathCall("log") ExponentialOperations.LN_OPERATION -> buildUnaryJavaMathCall("log")
else -> super.visitUnary(mst) else -> super.visitUnary(node)
} }
} }
override fun visitBinary(mst: MST.Binary) { override fun visitBinary(node: MST.Binary) {
super.visitBinary(mst) super.visitBinary(node)
when (mst.operation) { when (node.operation) {
GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DADD) GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(DADD)
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DSUB) GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(DSUB)
RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DMUL) RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(DMUL)
FieldOps.DIV_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DDIV) FieldOps.DIV_OPERATION -> invokeMethodVisitor.visitInsn(DDIV)
PowerOperations.POW_OPERATION -> buildBinaryJavaMathCall("pow") PowerOperations.POW_OPERATION -> buildBinaryJavaMathCall("pow")
else -> super.visitBinary(mst) else -> super.visitBinary(node)
} }
} }
companion object { private companion object {
val MATH_TYPE: Type by lazy { getObjectType("java/lang/Math") } val MATH_TYPE: Type = getObjectType("java/lang/Math")
val MATH_KT_TYPE: Type by lazy { getObjectType("kotlin/math/MathKt") } val MATH_KT_TYPE: Type = getObjectType("kotlin/math/MathKt")
} }
} }
@UnstableKMathAPI
internal class IntAsmBuilder(target: MST) : internal class IntAsmBuilder(target: MST) :
PrimitiveAsmBuilder<Int>(IntRing, Integer::class.java, Integer.TYPE, target) { PrimitiveAsmBuilder<Int, IntExpression>(
override fun visitUnary(mst: MST.Unary) { IntRing,
super.visitUnary(mst) Integer::class.java,
Integer.TYPE,
IntExpression::class.java,
target
) {
override fun visitUnary(node: MST.Unary) {
super.visitUnary(node)
when (mst.operation) { when (node.operation) {
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.INEG) GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(INEG)
GroupOps.PLUS_OPERATION -> Unit GroupOps.PLUS_OPERATION -> Unit
else -> super.visitUnary(mst) else -> super.visitUnary(node)
} }
} }
override fun visitBinary(mst: MST.Binary) { override fun visitBinary(node: MST.Binary) {
super.visitBinary(mst) super.visitBinary(node)
when (mst.operation) { when (node.operation) {
GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.IADD) GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(IADD)
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.ISUB) GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(ISUB)
RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.IMUL) RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(IMUL)
else -> super.visitBinary(mst) else -> super.visitBinary(node)
} }
} }
} }
internal class LongAsmBuilder(target: MST) : @UnstableKMathAPI
PrimitiveAsmBuilder<Long>(LongRing, java.lang.Long::class.java, java.lang.Long.TYPE, target) { internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder<Long, LongExpression>(
override fun visitUnary(mst: MST.Unary) { LongRing,
super.visitUnary(mst) java.lang.Long::class.java,
java.lang.Long.TYPE,
LongExpression::class.java,
target,
) {
override fun visitUnary(node: MST.Unary) {
super.visitUnary(node)
when (mst.operation) { when (node.operation) {
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LNEG) GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(LNEG)
GroupOps.PLUS_OPERATION -> Unit GroupOps.PLUS_OPERATION -> Unit
else -> super.visitUnary(mst) else -> super.visitUnary(node)
} }
} }
override fun visitBinary(mst: MST.Binary) { override fun visitBinary(node: MST.Binary) {
super.visitBinary(mst) super.visitBinary(node)
when (mst.operation) { when (node.operation) {
GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LADD) GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(LADD)
GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LSUB) GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(LSUB)
RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LMUL) RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(LMUL)
else -> super.visitBinary(mst) else -> super.visitBinary(node)
} }
} }
} }

View File

@ -8,7 +8,10 @@ plugins {
kotlin.sourceSets { kotlin.sourceSets {
filter { it.name.contains("test", true) } filter { it.name.contains("test", true) }
.map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings) .map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings)
.forEach { it.optIn("space.kscience.kmath.misc.PerformancePitfall") } .forEach {
it.optIn("space.kscience.kmath.misc.PerformancePitfall")
it.optIn("space.kscience.kmath.misc.UnstableKMathAPI")
}
commonMain { commonMain {
dependencies { dependencies {

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.expressions package space.kscience.kmath.expressions
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
import kotlin.properties.ReadOnlyProperty import kotlin.properties.ReadOnlyProperty
@ -24,6 +25,81 @@ public fun interface Expression<T> {
public operator fun invoke(arguments: Map<Symbol, T>): T public operator fun invoke(arguments: Map<Symbol, T>): T
} }
/**
* Specialization of [Expression] for [Double] allowing better performance because of using array.
*/
@UnstableKMathAPI
public interface DoubleExpression : Expression<Double> {
/**
* The indexer of this expression's arguments that should be used to build array for [invoke].
*
* Implementations must fulfil the following requirement: for any argument symbol `x` and its value `y`,
* `indexer.indexOf(x) == arguments.indexOf(y)` if `arguments` is the array passed to [invoke].
*/
public val indexer: SymbolIndexer
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
this(DoubleArray(indexer.symbols.size) { arguments.getValue(indexer.symbols[it]) })
/**
* Calls this expression from arguments.
*
* @param arguments the array of arguments.
* @return the value.
*/
public operator fun invoke(arguments: DoubleArray): Double
}
/**
* Specialization of [Expression] for [Int] allowing better performance because of using array.
*/
@UnstableKMathAPI
public interface IntExpression : Expression<Int> {
/**
* The indexer of this expression's arguments that should be used to build array for [invoke].
*
* Implementations must fulfil the following requirement: for any argument symbol `x` and its value `y`,
* `indexer.indexOf(x) == arguments.indexOf(y)` if `arguments` is the array passed to [invoke].
*/
public val indexer: SymbolIndexer
public override operator fun invoke(arguments: Map<Symbol, Int>): Int =
this(IntArray(indexer.symbols.size) { arguments.getValue(indexer.symbols[it]) })
/**
* Calls this expression from arguments.
*
* @param arguments the array of arguments.
* @return the value.
*/
public operator fun invoke(arguments: IntArray): Int
}
/**
* Specialization of [Expression] for [Long] allowing better performance because of using array.
*/
@UnstableKMathAPI
public interface LongExpression : Expression<Long> {
/**
* The indexer of this expression's arguments that should be used to build array for [invoke].
*
* Implementations must fulfil the following requirement: for any argument symbol `x` and its value `y`,
* `indexer.indexOf(x) == arguments.indexOf(y)` if `arguments` is the array passed to [invoke].
*/
public val indexer: SymbolIndexer
public override operator fun invoke(arguments: Map<Symbol, Long>): Long =
this(LongArray(indexer.symbols.size) { arguments.getValue(indexer.symbols[it]) })
/**
* Calls this expression from arguments.
*
* @param arguments the array of arguments.
* @return the value.
*/
public operator fun invoke(arguments: LongArray): Long
}
/** /**
* Calls this expression without providing any arguments. * Calls this expression without providing any arguments.
* *
@ -69,6 +145,62 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
} }
) )
private val EMPTY_DOUBLE_ARRAY = DoubleArray(0)
/**
* Calls this expression without providing any arguments.
*
* @return a value.
*/
@UnstableKMathAPI
public operator fun DoubleExpression.invoke(): Double = this(EMPTY_DOUBLE_ARRAY)
/**
* Calls this expression from arguments.
*
* @param pairs the pairs of arguments to values.
* @return a value.
*/
@UnstableKMathAPI
public operator fun DoubleExpression.invoke(vararg arguments: Double): Double = this(arguments)
private val EMPTY_INT_ARRAY = IntArray(0)
/**
* Calls this expression without providing any arguments.
*
* @return a value.
*/
@UnstableKMathAPI
public operator fun IntExpression.invoke(): Int = this(EMPTY_INT_ARRAY)
/**
* Calls this expression from arguments.
*
* @param pairs the pairs of arguments to values.
* @return a value.
*/
@UnstableKMathAPI
public operator fun IntExpression.invoke(vararg arguments: Int): Int = this(arguments)
private val EMPTY_LONG_ARRAY = LongArray(0)
/**
* Calls this expression without providing any arguments.
*
* @return a value.
*/
@UnstableKMathAPI
public operator fun LongExpression.invoke(): Long = this(EMPTY_LONG_ARRAY)
/**
* Calls this expression from arguments.
*
* @param pairs the pairs of arguments to values.
* @return a value.
*/
@UnstableKMathAPI
public operator fun LongExpression.invoke(vararg arguments: Long): Long = this(arguments)
/** /**
* A context for expression construction * A context for expression construction

View File

@ -3,6 +3,11 @@ plugins {
id("ru.mipt.npm.gradle.common") id("ru.mipt.npm.gradle.common")
} }
kotlin.sourceSets
.filter { it.name.contains("test", true) }
.map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings)
.forEach { it.optIn("space.kscience.kmath.misc.UnstableKMathAPI") }
description = "Kotlin∇ integration module" description = "Kotlin∇ integration module"
dependencies { dependencies {

View File

@ -145,7 +145,7 @@ internal object InternalGamma {
} }
when { when {
n >= maxIterations -> throw error("Maximal iterations is exceeded $maxIterations") n >= maxIterations -> error("Maximal iterations is exceeded $maxIterations")
sum.isInfinite() -> 1.0 sum.isInfinite() -> 1.0
else -> exp(-x + a * ln(x) - logGamma(a)) * sum else -> exp(-x + a * ln(x) - logGamma(a)) * sum
} }