From 5ba7d74bd249a02e742ad2e0bbcde83bf4bc0858 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 25 Nov 2021 23:26:08 +0700 Subject: [PATCH] Completely specialized expression types for Int, Long, Double and compilation of MST to it --- .../ExpressionsInterpretersBenchmark.kt | 20 +- .../space/kscience/kmath/ast/expressions.kt | 6 +- .../kmath/ast/TestCompilerVariables.kt | 16 +- .../lib.dom.WebAssembly.module_dukat.kt | 4 +- .../kmath/wasm/internal/WasmBuilder.kt | 190 +++---- .../kotlin/space/kscience/kmath/wasm/wasm.kt | 11 +- .../kotlin/space/kscience/kmath/asm/asm.kt | 28 +- .../kscience/kmath/asm/internal/AsmBuilder.kt | 13 +- .../asm/internal/ByteArrayClassLoader.kt | 6 - .../kmath/asm/internal/GenericAsmBuilder.kt | 14 +- .../kmath/asm/internal/PrimitiveAsmBuilder.kt | 464 +++++++++++------- kmath-core/build.gradle.kts | 5 +- .../kscience/kmath/expressions/Expression.kt | 132 +++++ kmath-kotlingrad/build.gradle.kts | 5 + .../kscience/kmath/internal/InternalGamma.kt | 2 +- 15 files changed, 615 insertions(+), 301 deletions(-) delete mode 100644 kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/ByteArrayClassLoader.kt diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt index f3a52ab5f..db3524e67 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt @@ -38,6 +38,22 @@ internal class ExpressionsInterpretersBenchmark { @Benchmark fun asmGenericExpression(blackhole: Blackhole) = invokeAndSum(asmGeneric, blackhole) + /** + * Benchmark case for [Expression] created with [compileToExpression]. + */ + @Benchmark + fun asmPrimitiveExpressionArray(blackhole: Blackhole) { + val random = Random(0) + var sum = 0.0 + val m = DoubleArray(1) + + repeat(times) { + m[xIdx] = random.nextDouble() + sum += asmPrimitive(m) + } + + blackhole.consume(sum) + } /** * Benchmark case for [Expression] created with [compileToExpression]. @@ -82,7 +98,6 @@ internal class ExpressionsInterpretersBenchmark { private companion object { private val x by symbol - private val algebra = DoubleField private const val times = 1_000_000 private val functional = DoubleField.expression { @@ -95,7 +110,10 @@ internal class ExpressionsInterpretersBenchmark { } private val mst = node.toExpression(DoubleField) + private val asmPrimitive = node.compileToExpression(DoubleField) + private val xIdx = asmPrimitive.indexer.indexOf(x) + private val asmGeneric = node.compileToExpression(DoubleField as Algebra) private val raw = Expression { args -> diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt index ffafdc99d..907f1bbe4 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt @@ -7,7 +7,6 @@ package space.kscience.kmath.ast import space.kscience.kmath.asm.compileToExpression import space.kscience.kmath.expressions.MstExtendedField -import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol.Companion.x import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke @@ -17,10 +16,11 @@ fun main() { x * 2.0 + number(2.0) / x - number(16.0) + asinh(x) / sin(x) }.compileToExpression(DoubleField) - val m = HashMap() + val m = DoubleArray(expr.indexer.symbols.size) + val xIdx = expr.indexer.indexOf(x) repeat(10000000) { - m[x] = 1.0 + m[xIdx] = 1.0 expr(m) } } diff --git a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerVariables.kt b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerVariables.kt index af1a2e338..859e56032 100644 --- a/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerVariables.kt +++ b/kmath-ast/src/commonTest/kotlin/space/kscience/kmath/ast/TestCompilerVariables.kt @@ -7,7 +7,9 @@ package space.kscience.kmath.ast import space.kscience.kmath.expressions.MstRing import space.kscience.kmath.expressions.Symbol.Companion.x +import space.kscience.kmath.expressions.Symbol.Companion.y import space.kscience.kmath.expressions.invoke +import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.invoke import kotlin.test.Test @@ -16,11 +18,23 @@ import kotlin.test.assertFailsWith internal class TestCompilerVariables { @Test - fun testVariable() = runCompilerTest { + fun testNoVariables() = runCompilerTest { + val expr = "0".parseMath().compileToExpression(IntRing) + assertEquals(0, expr()) + } + + @Test + fun testOneVariable() = runCompilerTest { val expr = MstRing { x }.compileToExpression(IntRing) assertEquals(1, expr(x to 1)) } + @Test + fun testTwoVariables() = runCompilerTest { + val expr = "y+x/y+x".parseMath().compileToExpression(DoubleField) + assertEquals(8.0, expr(x to 4.0, y to 2.0)) + } + @Test fun testUndefinedVariableFails() = runCompilerTest { val expr = MstRing { x }.compileToExpression(IntRing) diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/webassembly/lib.dom.WebAssembly.module_dukat.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/webassembly/lib.dom.WebAssembly.module_dukat.kt index 3754c3eff..90690abed 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/webassembly/lib.dom.WebAssembly.module_dukat.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/internal/webassembly/lib.dom.WebAssembly.module_dukat.kt @@ -201,8 +201,8 @@ internal open external class Module { } @JsName("Instance") -internal open external class Instance(module: Module, importObject: Any = definedExternally) { - open var exports: Any +internal open external class Instance(module: Module, importObject: dynamic = definedExternally) { + open var exports: dynamic } @JsName("Memory") diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt index 622e6e6ec..2d6619bba 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/internal/WasmBuilder.kt @@ -5,12 +5,11 @@ package space.kscience.kmath.wasm.internal -import space.kscience.kmath.expressions.Expression -import space.kscience.kmath.expressions.MST +import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.MST.* -import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.internal.binaryen.* import space.kscience.kmath.internal.webassembly.Instance +import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.* import space.kscience.kmath.internal.binaryen.Module as BinaryenModule import space.kscience.kmath.internal.webassembly.Module as WasmModule @@ -18,65 +17,17 @@ import space.kscience.kmath.internal.webassembly.Module as WasmModule private val spreader = eval("(obj, args) => obj(...args)") @Suppress("UnsafeCastFromDynamic") -internal sealed class WasmBuilder( +internal sealed class WasmBuilder>( protected val binaryenType: Type, protected val algebra: Algebra, protected val target: MST, ) { protected val keys: MutableList = mutableListOf() - lateinit var ctx: BinaryenModule + protected lateinit var ctx: BinaryenModule - open fun visitSymbolic(mst: Symbol): ExpressionRef { - algebra.bindSymbolOrNull(mst)?.let { return visitNumeric(Numeric(it)) } + abstract val instance: E - var idx = keys.indexOf(mst) - - if (idx == -1) { - keys += mst - idx = keys.lastIndex - } - - return ctx.local.get(idx, binaryenType) - } - - abstract fun visitNumeric(mst: Numeric): ExpressionRef - - protected open fun visitUnary(mst: Unary): ExpressionRef = - error("Unary operation ${mst.operation} not defined in $this") - - protected open fun visitBinary(mst: Binary): ExpressionRef = - error("Binary operation ${mst.operation} not defined in $this") - - protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()") - - protected fun visit(mst: MST): ExpressionRef = when (mst) { - is Symbol -> visitSymbolic(mst) - is Numeric -> visitNumeric(mst) - - is Unary -> when { - algebra is NumericAlgebra && mst.value is Numeric -> visitNumeric( - Numeric(algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as Numeric).value))) - ) - - else -> visitUnary(mst) - } - - is Binary -> when { - algebra is NumericAlgebra && mst.left is Numeric && mst.right is Numeric -> visitNumeric( - Numeric( - algebra.binaryOperationFunction(mst.operation) - .invoke( - algebra.number((mst.left as Numeric).value), - algebra.number((mst.right as Numeric).value) - ) - ) - ) - - else -> visitBinary(mst) - } - } - - val instance by lazy { + protected val executable = run { val c = WasmModule(with(createModule()) { ctx = this val expr = visit(target) @@ -97,41 +48,93 @@ internal sealed class WasmBuilder( res }) - val i = Instance(c, js("{}") as Any) - val symbols = keys - keys.clear() + Instance(c, js("{}")).exports.executable + } - Expression { args -> - val params = symbols.map(args::getValue).toTypedArray() - spreader(i.exports.asDynamic().executable, params) as T + protected open fun visitSymbol(node: Symbol): ExpressionRef { + algebra.bindSymbolOrNull(node)?.let { return visitNumeric(Numeric(it)) } + + var idx = keys.indexOf(node) + + if (idx == -1) { + keys += node + idx = keys.lastIndex + } + + return ctx.local.get(idx, binaryenType) + } + + protected abstract fun visitNumeric(node: Numeric): ExpressionRef + + protected open fun visitUnary(node: Unary): ExpressionRef = + error("Unary operation ${node.operation} not defined in $this") + + protected open fun visitBinary(mst: Binary): ExpressionRef = + error("Binary operation ${mst.operation} not defined in $this") + + protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()") + + protected fun visit(node: MST): ExpressionRef = when (node) { + is Symbol -> visitSymbol(node) + is Numeric -> visitNumeric(node) + + is Unary -> when { + algebra is NumericAlgebra && node.value is Numeric -> visitNumeric( + Numeric(algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as Numeric).value))) + ) + + else -> visitUnary(node) + } + + is Binary -> when { + algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> visitNumeric( + Numeric( + algebra.binaryOperationFunction(node.operation) + .invoke( + algebra.number((node.left as Numeric).value), + algebra.number((node.right as Numeric).value) + ) + ) + ) + + else -> visitBinary(node) } } } -internal class DoubleWasmBuilder(target: MST) : WasmBuilder(f64, DoubleField, target) { - override fun createModule(): BinaryenModule = readBinary(f64StandardFunctions) +@UnstableKMathAPI +internal class DoubleWasmBuilder(target: MST) : WasmBuilder(f64, DoubleField, target) { + override val instance by lazy { + object : DoubleExpression { + override val indexer = SimpleSymbolIndexer(keys) - override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value) + override fun invoke(arguments: DoubleArray) = spreader(executable, arguments).unsafeCast() + } + } - override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) { - GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value)) - GroupOps.PLUS_OPERATION -> visit(mst.value) - PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value)) - TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64) - TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64) - TrigonometricOperations.TAN_OPERATION -> ctx.call("tan", arrayOf(visit(mst.value)), f64) - TrigonometricOperations.ASIN_OPERATION -> ctx.call("asin", arrayOf(visit(mst.value)), f64) - TrigonometricOperations.ACOS_OPERATION -> ctx.call("acos", arrayOf(visit(mst.value)), f64) - TrigonometricOperations.ATAN_OPERATION -> ctx.call("atan", arrayOf(visit(mst.value)), f64) - ExponentialOperations.SINH_OPERATION -> ctx.call("sinh", arrayOf(visit(mst.value)), f64) - ExponentialOperations.COSH_OPERATION -> ctx.call("cosh", arrayOf(visit(mst.value)), f64) - ExponentialOperations.TANH_OPERATION -> ctx.call("tanh", arrayOf(visit(mst.value)), f64) - ExponentialOperations.ASINH_OPERATION -> ctx.call("asinh", arrayOf(visit(mst.value)), f64) - ExponentialOperations.ACOSH_OPERATION -> ctx.call("acosh", arrayOf(visit(mst.value)), f64) - ExponentialOperations.ATANH_OPERATION -> ctx.call("atanh", arrayOf(visit(mst.value)), f64) - ExponentialOperations.EXP_OPERATION -> ctx.call("exp", arrayOf(visit(mst.value)), f64) - ExponentialOperations.LN_OPERATION -> ctx.call("log", arrayOf(visit(mst.value)), f64) - else -> super.visitUnary(mst) + override fun createModule() = readBinary(f64StandardFunctions) + + override fun visitNumeric(node: Numeric) = ctx.f64.const(node.value) + + override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) { + GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(node.value)) + GroupOps.PLUS_OPERATION -> visit(node.value) + PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(node.value)) + TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(node.value)), f64) + TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(node.value)), f64) + TrigonometricOperations.TAN_OPERATION -> ctx.call("tan", arrayOf(visit(node.value)), f64) + TrigonometricOperations.ASIN_OPERATION -> ctx.call("asin", arrayOf(visit(node.value)), f64) + TrigonometricOperations.ACOS_OPERATION -> ctx.call("acos", arrayOf(visit(node.value)), f64) + TrigonometricOperations.ATAN_OPERATION -> ctx.call("atan", arrayOf(visit(node.value)), f64) + ExponentialOperations.SINH_OPERATION -> ctx.call("sinh", arrayOf(visit(node.value)), f64) + ExponentialOperations.COSH_OPERATION -> ctx.call("cosh", arrayOf(visit(node.value)), f64) + ExponentialOperations.TANH_OPERATION -> ctx.call("tanh", arrayOf(visit(node.value)), f64) + ExponentialOperations.ASINH_OPERATION -> ctx.call("asinh", arrayOf(visit(node.value)), f64) + ExponentialOperations.ACOSH_OPERATION -> ctx.call("acosh", arrayOf(visit(node.value)), f64) + ExponentialOperations.ATANH_OPERATION -> ctx.call("atanh", arrayOf(visit(node.value)), f64) + ExponentialOperations.EXP_OPERATION -> ctx.call("exp", arrayOf(visit(node.value)), f64) + ExponentialOperations.LN_OPERATION -> ctx.call("log", arrayOf(visit(node.value)), f64) + else -> super.visitUnary(node) } override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { @@ -144,13 +147,22 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder(f64, DoubleF } } -internal class IntWasmBuilder(target: MST) : WasmBuilder(i32, IntRing, target) { - override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value) +@UnstableKMathAPI +internal class IntWasmBuilder(target: MST) : WasmBuilder(i32, IntRing, target) { + override val instance by lazy { + object : IntExpression { + override val indexer = SimpleSymbolIndexer(keys) - override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) { - GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value)) - GroupOps.PLUS_OPERATION -> visit(mst.value) - else -> super.visitUnary(mst) + override fun invoke(arguments: IntArray) = spreader(executable, arguments).unsafeCast() + } + } + + override fun visitNumeric(node: Numeric) = ctx.i32.const(node.value) + + override fun visitUnary(node: Unary): ExpressionRef = when (node.operation) { + GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(node.value)) + GroupOps.PLUS_OPERATION -> visit(node.value) + else -> super.visitUnary(node) } override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt index 806091715..12e6b41af 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/wasm/wasm.kt @@ -3,13 +3,12 @@ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. */ +@file:Suppress("UNUSED_PARAMETER") + package space.kscience.kmath.wasm import space.kscience.kmath.estree.compileWith -import space.kscience.kmath.expressions.Expression -import space.kscience.kmath.expressions.MST -import space.kscience.kmath.expressions.Symbol -import space.kscience.kmath.expressions.invoke +import space.kscience.kmath.expressions.* import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.IntRing @@ -22,7 +21,7 @@ import space.kscience.kmath.wasm.internal.IntWasmBuilder * @author Iaroslav Postovalov */ @UnstableKMathAPI -public fun MST.compileToExpression(algebra: IntRing): Expression = compileWith(algebra) +public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntWasmBuilder(this).instance /** @@ -50,7 +49,7 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair): I * @author Iaroslav Postovalov */ @UnstableKMathAPI -public fun MST.compileToExpression(algebra: DoubleField): Expression = compileWith(algebra) +public fun MST.compileToExpression(algebra: DoubleField): Expression = DoubleWasmBuilder(this).instance /** diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt index c39345f3a..8e426622d 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt @@ -8,11 +8,8 @@ package space.kscience.kmath.asm import space.kscience.kmath.asm.internal.* -import space.kscience.kmath.expressions.Expression -import space.kscience.kmath.expressions.MST +import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.MST.* -import space.kscience.kmath.expressions.Symbol -import space.kscience.kmath.expressions.invoke import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.* @@ -48,7 +45,13 @@ internal fun MST.compileWith(type: Class, algebra: Algebra): Exp loadVariable(node.identity) } - is Numeric -> loadNumberConstant(node.value) + is Numeric -> if (algebra is NumericAlgebra) { + if (Number::class.java.isAssignableFrom(type)) + loadNumberConstant(algebra.number(node.value) as Number) + else + loadObjectConstant(algebra.number(node.value)) + } else + error("Numeric nodes are not supported by $this") is Unary -> when { algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant( @@ -121,13 +124,15 @@ public inline fun MST.compile(algebra: Algebra, vararg argu * * @author Iaroslav Postovalov */ -public fun MST.compileToExpression(algebra: IntRing): Expression = IntAsmBuilder(this).instance +@UnstableKMathAPI +public fun MST.compileToExpression(algebra: IntRing): IntExpression = IntAsmBuilder(this).instance /** * Compile given MST to expression and evaluate it against [arguments]. * * @author Iaroslav Postovalov */ +@UnstableKMathAPI public fun MST.compile(algebra: IntRing, arguments: Map): Int = compileToExpression(algebra).invoke(arguments) @@ -136,6 +141,7 @@ public fun MST.compile(algebra: IntRing, arguments: Map): Int = * * @author Iaroslav Postovalov */ +@UnstableKMathAPI public fun MST.compile(algebra: IntRing, vararg arguments: Pair): Int = compileToExpression(algebra)(*arguments) @@ -145,7 +151,8 @@ public fun MST.compile(algebra: IntRing, vararg arguments: Pair): I * * @author Iaroslav Postovalov */ -public fun MST.compileToExpression(algebra: LongRing): Expression = LongAsmBuilder(this).instance +@UnstableKMathAPI +public fun MST.compileToExpression(algebra: LongRing): LongExpression = LongAsmBuilder(this).instance /** @@ -153,6 +160,7 @@ public fun MST.compileToExpression(algebra: LongRing): Expression = LongAs * * @author Iaroslav Postovalov */ +@UnstableKMathAPI public fun MST.compile(algebra: LongRing, arguments: Map): Long = compileToExpression(algebra).invoke(arguments) @@ -162,6 +170,7 @@ public fun MST.compile(algebra: LongRing, arguments: Map): Long = * * @author Iaroslav Postovalov */ +@UnstableKMathAPI public fun MST.compile(algebra: LongRing, vararg arguments: Pair): Long = compileToExpression(algebra)(*arguments) @@ -171,13 +180,15 @@ public fun MST.compile(algebra: LongRing, vararg arguments: Pair): * * @author Iaroslav Postovalov */ -public fun MST.compileToExpression(algebra: DoubleField): Expression = DoubleAsmBuilder(this).instance +@UnstableKMathAPI +public fun MST.compileToExpression(algebra: DoubleField): DoubleExpression = DoubleAsmBuilder(this).instance /** * Compile given MST to expression and evaluate it against [arguments]. * * @author Iaroslav Postovalov */ +@UnstableKMathAPI public fun MST.compile(algebra: DoubleField, arguments: Map): Double = compileToExpression(algebra).invoke(arguments) @@ -186,5 +197,6 @@ public fun MST.compile(algebra: DoubleField, arguments: Map): Do * * @author Iaroslav Postovalov */ +@UnstableKMathAPI public fun MST.compile(algebra: DoubleField, vararg arguments: Pair): Double = compileToExpression(algebra).invoke(*arguments) diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt index 7d3d513de..a85079fc8 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/AsmBuilder.kt @@ -6,6 +6,7 @@ package space.kscience.kmath.asm.internal import org.objectweb.asm.Type +import org.objectweb.asm.Type.getObjectType import space.kscience.kmath.expressions.Expression internal abstract class AsmBuilder { @@ -22,31 +23,31 @@ internal abstract class AsmBuilder { /** * ASM type for [Expression]. */ - val EXPRESSION_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/expressions/Expression") } + val EXPRESSION_TYPE: Type = getObjectType("space/kscience/kmath/expressions/Expression") /** * ASM type for [java.util.Map]. */ - val MAP_TYPE: Type by lazy { Type.getObjectType("java/util/Map") } + val MAP_TYPE: Type = getObjectType("java/util/Map") /** * ASM type for [java.lang.Object]. */ - val OBJECT_TYPE: Type by lazy { Type.getObjectType("java/lang/Object") } + val OBJECT_TYPE: Type = getObjectType("java/lang/Object") /** * ASM type for [java.lang.String]. */ - val STRING_TYPE: Type by lazy { Type.getObjectType("java/lang/String") } + val STRING_TYPE: Type = getObjectType("java/lang/String") /** * ASM type for MapIntrinsics. */ - val MAP_INTRINSICS_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/asm/internal/MapIntrinsics") } + val MAP_INTRINSICS_TYPE: Type = getObjectType("space/kscience/kmath/asm/internal/MapIntrinsics") /** * ASM Type for [space.kscience.kmath.expressions.Symbol]. */ - val SYMBOL_TYPE: Type by lazy { Type.getObjectType("space/kscience/kmath/expressions/Symbol") } + val SYMBOL_TYPE: Type = getObjectType("space/kscience/kmath/expressions/Symbol") } } diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/ByteArrayClassLoader.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/ByteArrayClassLoader.kt deleted file mode 100644 index 10eee7c6b..000000000 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/ByteArrayClassLoader.kt +++ /dev/null @@ -1,6 +0,0 @@ -/* - * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.asm.internal diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt index 28f4a226f..5eb739956 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/GenericAsmBuilder.kt @@ -78,7 +78,7 @@ internal class GenericAsmBuilder( ) visitMethod( - ACC_PUBLIC or ACC_FINAL, + ACC_PUBLIC, "invoke", getMethodDescriptor(tType, MAP_TYPE), "(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", @@ -116,7 +116,7 @@ internal class GenericAsmBuilder( } visitMethod( - ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, + ACC_PUBLIC or ACC_BRIDGE or ACC_SYNTHETIC, "invoke", getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), null, @@ -156,7 +156,7 @@ internal class GenericAsmBuilder( ) visitMethod( - ACC_PUBLIC, + ACC_PUBLIC or ACC_SYNTHETIC, "", getMethodDescriptor(VOID_TYPE, *OBJECT_ARRAY_TYPE.wrapToArrayIf { hasConstants }), null, @@ -176,7 +176,7 @@ internal class GenericAsmBuilder( } label() - visitInsn(RETURN) + areturn(VOID_TYPE) val l4 = label() visitLocalVariable("this", classType.descriptor, null, l0, l4, 0) @@ -209,10 +209,10 @@ internal class GenericAsmBuilder( */ fun loadObjectConstant(value: Any, type: Type = tType): Unit = invokeMethodVisitor.run { val idx = if (value in constants) constants.indexOf(value) else constants.also { it += value }.lastIndex - invokeMethodVisitor.load(0, classType) + load(0, classType) getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) iconst(idx) - visitInsn(AALOAD) + aload(OBJECT_TYPE) if (type != OBJECT_TYPE) checkcast(type) } @@ -320,6 +320,6 @@ internal class GenericAsmBuilder( /** * ASM type for array of [java.lang.Object]. */ - val OBJECT_ARRAY_TYPE: Type by lazy { getType("[Ljava/lang/Object;") } + val OBJECT_ARRAY_TYPE: Type = getType("[Ljava/lang/Object;") } } diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt index 3e019589a..bf1f42395 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/internal/PrimitiveAsmBuilder.kt @@ -6,38 +6,49 @@ package space.kscience.kmath.asm.internal import org.objectweb.asm.ClassWriter -import org.objectweb.asm.Opcodes +import org.objectweb.asm.FieldVisitor +import org.objectweb.asm.Opcodes.* import org.objectweb.asm.Type import org.objectweb.asm.Type.* import org.objectweb.asm.commons.InstructionAdapter -import space.kscience.kmath.expressions.Expression -import space.kscience.kmath.expressions.MST -import space.kscience.kmath.expressions.Symbol -import space.kscience.kmath.expressions.invoke +import space.kscience.kmath.expressions.* +import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.* import java.lang.invoke.MethodHandles import java.lang.invoke.MethodType import java.nio.file.Paths import kotlin.io.path.writeBytes -internal sealed class PrimitiveAsmBuilder( - protected val algebra: Algebra, +@UnstableKMathAPI +internal sealed class PrimitiveAsmBuilder>( + protected val algebra: NumericAlgebra, classOfT: Class<*>, protected val classOfTPrimitive: Class<*>, + expressionParent: Class, protected val target: MST, ) : AsmBuilder() { private val className: String = buildName(target) /** - * ASM type for [T]. + * ASM type for [tType]. */ private val tType: Type = classOfT.asm /** - * ASM type for [T]. + * ASM type for [classOfTPrimitive]. */ protected val tTypePrimitive: Type = classOfTPrimitive.asm + /** + * ASM type for array of [classOfTPrimitive]. + */ + protected val tTypePrimitiveArray: Type = getType("[" + classOfTPrimitive.asm.descriptor) + + /** + * ASM type for expression parent. + */ + private val expressionParentType = expressionParent.asm + /** * ASM type for new class. */ @@ -49,58 +60,91 @@ internal sealed class PrimitiveAsmBuilder( protected lateinit var invokeMethodVisitor: InstructionAdapter /** - * Local variables indices are indices of symbols in this list. + * Indexer for arguments in [target]. */ - private val argumentsLocals = mutableListOf() + private val argumentsIndexer = mutableListOf() /** * Subclasses, loads and instantiates [Expression] for given parameters. * * The built instance is cached. */ - @Suppress("UNCHECKED_CAST", "UNUSED_VARIABLE") - val instance: Expression by lazy { + @Suppress("UNCHECKED_CAST") + val instance: E by lazy { val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( - Opcodes.V1_8, - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, + V1_8, + ACC_PUBLIC or ACC_FINAL or ACC_SUPER, classType.internalName, - "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", + "${OBJECT_TYPE.descriptor}${expressionParentType.descriptor}", OBJECT_TYPE.internalName, - arrayOf(EXPRESSION_TYPE.internalName), + arrayOf(expressionParentType.internalName), ) + visitField( + access = ACC_PRIVATE or ACC_FINAL, + name = "indexer", + descriptor = SYMBOL_INDEXER_TYPE.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd, + ) visitMethod( - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, - "invoke", - getMethodDescriptor(tType, MAP_TYPE), - "(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", + ACC_PUBLIC, + "getIndexer", + getMethodDescriptor(SYMBOL_INDEXER_TYPE), + null, null, ).instructionAdapter { - invokeMethodVisitor = this visitCode() - val preparingVariables = label() - visitVariables(target) - val expressionResult = label() - visitExpression(target) - box() - areturn(tType) + val start = label() + load(0, classType) + getfield(classType.internalName, "indexer", SYMBOL_INDEXER_TYPE.descriptor) + areturn(SYMBOL_INDEXER_TYPE) val end = label() visitLocalVariable( "this", classType.descriptor, null, - preparingVariables, + start, + end, + 0, + ) + + visitMaxs(0, 0) + visitEnd() + } + + visitMethod( + ACC_PUBLIC, + "invoke", + getMethodDescriptor(tTypePrimitive, tTypePrimitiveArray), + null, + null, + ).instructionAdapter { + invokeMethodVisitor = this + visitCode() + val start = label() + visitVariables(target, arrayMode = true) + visitExpression(target) + areturn(tTypePrimitive) + val end = label() + + visitLocalVariable( + "this", + classType.descriptor, + null, + start, end, 0, ) visitLocalVariable( "arguments", - MAP_TYPE.descriptor, - "L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;", - preparingVariables, + tTypePrimitiveArray.descriptor, + null, + start, end, 1, ) @@ -110,7 +154,45 @@ internal sealed class PrimitiveAsmBuilder( } visitMethod( - Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_BRIDGE or Opcodes.ACC_SYNTHETIC, + ACC_PUBLIC or ACC_FINAL, + "invoke", + getMethodDescriptor(tType, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", + null, + ).instructionAdapter { + invokeMethodVisitor = this + visitCode() + val start = label() + visitVariables(target, arrayMode = false) + visitExpression(target) + box() + areturn(tType) + val end = label() + + visitLocalVariable( + "this", + classType.descriptor, + null, + start, + end, + 0, + ) + + visitLocalVariable( + "arguments", + MAP_TYPE.descriptor, + "L${MAP_TYPE.internalName}<${SYMBOL_TYPE.descriptor}+${tType.descriptor}>;", + start, + end, + 1, + ) + + visitMaxs(0, 0) + visitEnd() + } + + visitMethod( + ACC_PUBLIC or ACC_FINAL or ACC_BRIDGE or ACC_SYNTHETIC, "invoke", getMethodDescriptor(OBJECT_TYPE, MAP_TYPE), null, @@ -138,21 +220,22 @@ internal sealed class PrimitiveAsmBuilder( } visitMethod( - Opcodes.ACC_PUBLIC, + ACC_PUBLIC or ACC_SYNTHETIC, "", - getMethodDescriptor(VOID_TYPE), + getMethodDescriptor(VOID_TYPE, SYMBOL_INDEXER_TYPE), null, null, ).instructionAdapter { val start = label() load(0, classType) invokespecial(OBJECT_TYPE.internalName, "", getMethodDescriptor(VOID_TYPE), false) - label() load(0, classType) - label() - visitInsn(Opcodes.RETURN) + load(1, SYMBOL_INDEXER_TYPE) + putfield(classType.internalName, "indexer", SYMBOL_INDEXER_TYPE.descriptor) + areturn(VOID_TYPE) val end = label() visitLocalVariable("this", classType.descriptor, null, start, end, 0) + visitLocalVariable("indexer", SYMBOL_INDEXER_TYPE.descriptor, null, start, end, 1) visitMaxs(0, 0) visitEnd() } @@ -166,14 +249,16 @@ internal sealed class PrimitiveAsmBuilder( if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1") Paths.get("${className.split('.').last()}.class").writeBytes(binary) - MethodHandles.publicLookup().findConstructor(cls, MethodType.methodType(Void.TYPE))() as Expression + MethodHandles + .publicLookup() + .findConstructor(cls, MethodType.methodType(Void.TYPE, SymbolIndexer::class.java)) + .invoke(SimpleSymbolIndexer(argumentsIndexer)) as E } /** - * Either loads a numeric constant [value] from the class's constants field or boxes a primitive - * constant from the constant pool. + * Loads a numeric constant [value] from the class's constants. */ - fun loadNumberConstant(value: Number) { + protected fun loadNumberConstant(value: Number) { when (tTypePrimitive) { BYTE_TYPE -> invokeMethodVisitor.iconst(value.toInt()) DOUBLE_TYPE -> invokeMethodVisitor.dconst(value.toDouble()) @@ -185,38 +270,50 @@ internal sealed class PrimitiveAsmBuilder( } /** - * Stores value variable [name] into a local. Should be called within [variablesPrepareCallback] before using - * [loadVariable]. + * Stores value variable [name] into a local. Should be called before using [loadVariable]. Should be called only + * once for a variable. */ - fun prepareVariable(name: String): Unit = invokeMethodVisitor.run { - if (name in argumentsLocals) return@run - load(1, MAP_TYPE) - aconst(name) + protected fun prepareVariable(name: Symbol, arrayMode: Boolean): Unit = invokeMethodVisitor.run { + var argumentIndex = argumentsIndexer.indexOf(name) - invokestatic( - MAP_INTRINSICS_TYPE.internalName, - "getOrFail", - getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), - false, - ) - - checkcast(tType) - var idx = argumentsLocals.indexOf(name) - - if (idx == -1) { - argumentsLocals += name - idx = argumentsLocals.lastIndex + if (argumentIndex == -1) { + argumentsIndexer += name + argumentIndex = argumentsIndexer.lastIndex } - unbox() - store(2 + idx, tTypePrimitive) + val localIndex = 2 + argumentIndex * tTypePrimitive.size + + if (arrayMode) { + load(1, tTypePrimitiveArray) + iconst(argumentIndex) + aload(tTypePrimitive) + store(localIndex, tTypePrimitive) + } else { + load(1, MAP_TYPE) + aconst(name.identity) + + invokestatic( + MAP_INTRINSICS_TYPE.internalName, + "getOrFail", + getMethodDescriptor(OBJECT_TYPE, MAP_TYPE, STRING_TYPE), + false, + ) + + checkcast(tType) + unbox() + store(localIndex, tTypePrimitive) + } } /** * Loads a variable [name] from arguments [Map] parameter of [Expression.invoke]. The variable should be stored * with [prepareVariable] first. */ - fun loadVariable(name: String): Unit = invokeMethodVisitor.load(2 + argumentsLocals.indexOf(name), tTypePrimitive) + protected fun loadVariable(name: Symbol) { + val argumentIndex = argumentsIndexer.indexOf(name) + val localIndex = 2 + argumentIndex * tTypePrimitive.size + invokeMethodVisitor.load(localIndex, tTypePrimitive) + } private fun unbox() = invokeMethodVisitor.run { invokevirtual( @@ -231,102 +328,117 @@ internal sealed class PrimitiveAsmBuilder( invokestatic(tType.internalName, "valueOf", getMethodDescriptor(tType, tTypePrimitive), false) } - protected fun visitVariables(node: MST): Unit = when (node) { - is Symbol -> prepareVariable(node.identity) - is MST.Unary -> visitVariables(node.value) + private fun visitVariables( + node: MST, + arrayMode: Boolean, + alreadyLoaded: MutableList = mutableListOf() + ): Unit = when (node) { + is Symbol -> when (node) { + !in alreadyLoaded -> { + alreadyLoaded += node + prepareVariable(node, arrayMode) + } + else -> { + } + } + + is MST.Unary -> visitVariables(node.value, arrayMode, alreadyLoaded) is MST.Binary -> { - visitVariables(node.left) - visitVariables(node.right) + visitVariables(node.left, arrayMode, alreadyLoaded) + visitVariables(node.right, arrayMode, alreadyLoaded) } else -> Unit } - protected fun visitExpression(mst: MST): Unit = when (mst) { - is Symbol -> loadVariable(mst.identity) - is MST.Numeric -> loadNumberConstant(mst.value) + private fun visitExpression(node: MST): Unit = when (node) { + is Symbol -> { + val symbol = algebra.bindSymbolOrNull(node) - is MST.Unary -> when { - algebra is NumericAlgebra && mst.value is MST.Numeric -> { - loadNumberConstant( - MST.Numeric( - algebra.unaryOperationFunction(mst.operation)(algebra.number((mst.value as MST.Numeric).value)), - ).value, - ) - } - - else -> visitUnary(mst) + if (symbol != null) + loadNumberConstant(symbol) + else + loadVariable(node) } + is MST.Numeric -> loadNumberConstant(algebra.number(node.value)) + + is MST.Unary -> if (node.value is MST.Numeric) + loadNumberConstant( + algebra.unaryOperationFunction(node.operation)(algebra.number((node.value as MST.Numeric).value)), + ) + else + visitUnary(node) + is MST.Binary -> when { - algebra is NumericAlgebra && mst.left is MST.Numeric && mst.right is MST.Numeric -> { - loadNumberConstant( - MST.Numeric( - algebra.binaryOperationFunction(mst.operation)( - algebra.number((mst.left as MST.Numeric).value), - algebra.number((mst.right as MST.Numeric).value), - ), - ).value, - ) - } + node.left is MST.Numeric && node.right is MST.Numeric -> loadNumberConstant( + algebra.binaryOperationFunction(node.operation)( + algebra.number((node.left as MST.Numeric).value), + algebra.number((node.right as MST.Numeric).value), + ), + ) - else -> visitBinary(mst) + else -> visitBinary(node) } } - protected open fun visitUnary(mst: MST.Unary) { - visitExpression(mst.value) - } + protected open fun visitUnary(node: MST.Unary) = visitExpression(node.value) - protected open fun visitBinary(mst: MST.Binary) { - visitExpression(mst.left) - visitExpression(mst.right) + protected open fun visitBinary(node: MST.Binary) { + visitExpression(node.left) + visitExpression(node.right) } protected companion object { /** * ASM type for [java.lang.Number]. */ - val NUMBER_TYPE: Type by lazy { getObjectType("java/lang/Number") } + val NUMBER_TYPE: Type = getObjectType("java/lang/Number") + + /** + * ASM type for [SymbolIndexer]. + */ + val SYMBOL_INDEXER_TYPE: Type = getObjectType("space/kscience/kmath/expressions/SymbolIndexer") } } -internal class DoubleAsmBuilder(target: MST) : - PrimitiveAsmBuilder(DoubleField, java.lang.Double::class.java, java.lang.Double.TYPE, target) { +@UnstableKMathAPI +internal class DoubleAsmBuilder(target: MST) : PrimitiveAsmBuilder( + DoubleField, + java.lang.Double::class.java, + java.lang.Double.TYPE, + DoubleExpression::class.java, + target, +) { - private fun buildUnaryJavaMathCall(name: String) { - invokeMethodVisitor.invokestatic( - MATH_TYPE.internalName, - name, - getMethodDescriptor(tTypePrimitive, tTypePrimitive), - false, - ) - } + private fun buildUnaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic( + MATH_TYPE.internalName, + name, + getMethodDescriptor(tTypePrimitive, tTypePrimitive), + false, + ) - private fun buildBinaryJavaMathCall(name: String) { - invokeMethodVisitor.invokestatic( - MATH_TYPE.internalName, - name, - getMethodDescriptor(tTypePrimitive, tTypePrimitive, tTypePrimitive), - false, - ) - } + @Suppress("SameParameterValue") + private fun buildBinaryJavaMathCall(name: String) = invokeMethodVisitor.invokestatic( + MATH_TYPE.internalName, + name, + getMethodDescriptor(tTypePrimitive, tTypePrimitive, tTypePrimitive), + false, + ) - private fun buildUnaryKotlinMathCall(name: String) { - invokeMethodVisitor.invokestatic( - MATH_KT_TYPE.internalName, - name, - getMethodDescriptor(tTypePrimitive, tTypePrimitive), - false, - ) - } + private fun buildUnaryKotlinMathCall(name: String) = invokeMethodVisitor.invokestatic( + MATH_KT_TYPE.internalName, + name, + getMethodDescriptor(tTypePrimitive, tTypePrimitive), + false, + ) - override fun visitUnary(mst: MST.Unary) { - super.visitUnary(mst) + override fun visitUnary(node: MST.Unary) { + super.visitUnary(node) - when (mst.operation) { - GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DNEG) + when (node.operation) { + GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(DNEG) GroupOps.PLUS_OPERATION -> Unit PowerOperations.SQRT_OPERATION -> buildUnaryJavaMathCall("sqrt") TrigonometricOperations.SIN_OPERATION -> buildUnaryJavaMathCall("sin") @@ -343,74 +455,86 @@ internal class DoubleAsmBuilder(target: MST) : ExponentialOperations.ATANH_OPERATION -> buildUnaryKotlinMathCall("atanh") ExponentialOperations.EXP_OPERATION -> buildUnaryJavaMathCall("exp") ExponentialOperations.LN_OPERATION -> buildUnaryJavaMathCall("log") - else -> super.visitUnary(mst) + else -> super.visitUnary(node) } } - override fun visitBinary(mst: MST.Binary) { - super.visitBinary(mst) + override fun visitBinary(node: MST.Binary) { + super.visitBinary(node) - when (mst.operation) { - GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DADD) - GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DSUB) - RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DMUL) - FieldOps.DIV_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.DDIV) + when (node.operation) { + GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(DADD) + GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(DSUB) + RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(DMUL) + FieldOps.DIV_OPERATION -> invokeMethodVisitor.visitInsn(DDIV) PowerOperations.POW_OPERATION -> buildBinaryJavaMathCall("pow") - else -> super.visitBinary(mst) + else -> super.visitBinary(node) } } - companion object { - val MATH_TYPE: Type by lazy { getObjectType("java/lang/Math") } - val MATH_KT_TYPE: Type by lazy { getObjectType("kotlin/math/MathKt") } + private companion object { + val MATH_TYPE: Type = getObjectType("java/lang/Math") + val MATH_KT_TYPE: Type = getObjectType("kotlin/math/MathKt") } } +@UnstableKMathAPI internal class IntAsmBuilder(target: MST) : - PrimitiveAsmBuilder(IntRing, Integer::class.java, Integer.TYPE, target) { - override fun visitUnary(mst: MST.Unary) { - super.visitUnary(mst) + PrimitiveAsmBuilder( + IntRing, + Integer::class.java, + Integer.TYPE, + IntExpression::class.java, + target + ) { + override fun visitUnary(node: MST.Unary) { + super.visitUnary(node) - when (mst.operation) { - GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.INEG) + when (node.operation) { + GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(INEG) GroupOps.PLUS_OPERATION -> Unit - else -> super.visitUnary(mst) + else -> super.visitUnary(node) } } - override fun visitBinary(mst: MST.Binary) { - super.visitBinary(mst) + override fun visitBinary(node: MST.Binary) { + super.visitBinary(node) - when (mst.operation) { - GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.IADD) - GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.ISUB) - RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.IMUL) - else -> super.visitBinary(mst) + when (node.operation) { + GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(IADD) + GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(ISUB) + RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(IMUL) + else -> super.visitBinary(node) } } } -internal class LongAsmBuilder(target: MST) : - PrimitiveAsmBuilder(LongRing, java.lang.Long::class.java, java.lang.Long.TYPE, target) { - override fun visitUnary(mst: MST.Unary) { - super.visitUnary(mst) +@UnstableKMathAPI +internal class LongAsmBuilder(target: MST) : PrimitiveAsmBuilder( + LongRing, + java.lang.Long::class.java, + java.lang.Long.TYPE, + LongExpression::class.java, + target, +) { + override fun visitUnary(node: MST.Unary) { + super.visitUnary(node) - when (mst.operation) { - GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LNEG) + when (node.operation) { + GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(LNEG) GroupOps.PLUS_OPERATION -> Unit - else -> super.visitUnary(mst) + else -> super.visitUnary(node) } } - override fun visitBinary(mst: MST.Binary) { - super.visitBinary(mst) + override fun visitBinary(node: MST.Binary) { + super.visitBinary(node) - when (mst.operation) { - GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LADD) - GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LSUB) - RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(Opcodes.LMUL) - else -> super.visitBinary(mst) + when (node.operation) { + GroupOps.PLUS_OPERATION -> invokeMethodVisitor.visitInsn(LADD) + GroupOps.MINUS_OPERATION -> invokeMethodVisitor.visitInsn(LSUB) + RingOps.TIMES_OPERATION -> invokeMethodVisitor.visitInsn(LMUL) + else -> super.visitBinary(node) } } } - diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index 490868a0b..4a35a54fb 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -8,7 +8,10 @@ plugins { kotlin.sourceSets { filter { it.name.contains("test", true) } .map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings) - .forEach { it.optIn("space.kscience.kmath.misc.PerformancePitfall") } + .forEach { + it.optIn("space.kscience.kmath.misc.PerformancePitfall") + it.optIn("space.kscience.kmath.misc.UnstableKMathAPI") + } commonMain { dependencies { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt index 33f53202c..5ba32f190 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.expressions +import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.Algebra import kotlin.jvm.JvmName import kotlin.properties.ReadOnlyProperty @@ -24,6 +25,81 @@ public fun interface Expression { public operator fun invoke(arguments: Map): T } +/** + * Specialization of [Expression] for [Double] allowing better performance because of using array. + */ +@UnstableKMathAPI +public interface DoubleExpression : Expression { + /** + * The indexer of this expression's arguments that should be used to build array for [invoke]. + * + * Implementations must fulfil the following requirement: for any argument symbol `x` and its value `y`, + * `indexer.indexOf(x) == arguments.indexOf(y)` if `arguments` is the array passed to [invoke]. + */ + public val indexer: SymbolIndexer + + public override operator fun invoke(arguments: Map): Double = + this(DoubleArray(indexer.symbols.size) { arguments.getValue(indexer.symbols[it]) }) + + /** + * Calls this expression from arguments. + * + * @param arguments the array of arguments. + * @return the value. + */ + public operator fun invoke(arguments: DoubleArray): Double +} + +/** + * Specialization of [Expression] for [Int] allowing better performance because of using array. + */ +@UnstableKMathAPI +public interface IntExpression : Expression { + /** + * The indexer of this expression's arguments that should be used to build array for [invoke]. + * + * Implementations must fulfil the following requirement: for any argument symbol `x` and its value `y`, + * `indexer.indexOf(x) == arguments.indexOf(y)` if `arguments` is the array passed to [invoke]. + */ + public val indexer: SymbolIndexer + + public override operator fun invoke(arguments: Map): Int = + this(IntArray(indexer.symbols.size) { arguments.getValue(indexer.symbols[it]) }) + + /** + * Calls this expression from arguments. + * + * @param arguments the array of arguments. + * @return the value. + */ + public operator fun invoke(arguments: IntArray): Int +} + +/** + * Specialization of [Expression] for [Long] allowing better performance because of using array. + */ +@UnstableKMathAPI +public interface LongExpression : Expression { + /** + * The indexer of this expression's arguments that should be used to build array for [invoke]. + * + * Implementations must fulfil the following requirement: for any argument symbol `x` and its value `y`, + * `indexer.indexOf(x) == arguments.indexOf(y)` if `arguments` is the array passed to [invoke]. + */ + public val indexer: SymbolIndexer + + public override operator fun invoke(arguments: Map): Long = + this(LongArray(indexer.symbols.size) { arguments.getValue(indexer.symbols[it]) }) + + /** + * Calls this expression from arguments. + * + * @param arguments the array of arguments. + * @return the value. + */ + public operator fun invoke(arguments: LongArray): Long +} + /** * Calls this expression without providing any arguments. * @@ -69,6 +145,62 @@ public operator fun Expression.invoke(vararg pairs: Pair): T = } ) +private val EMPTY_DOUBLE_ARRAY = DoubleArray(0) + +/** + * Calls this expression without providing any arguments. + * + * @return a value. + */ +@UnstableKMathAPI +public operator fun DoubleExpression.invoke(): Double = this(EMPTY_DOUBLE_ARRAY) + +/** + * Calls this expression from arguments. + * + * @param pairs the pairs of arguments to values. + * @return a value. + */ +@UnstableKMathAPI +public operator fun DoubleExpression.invoke(vararg arguments: Double): Double = this(arguments) + +private val EMPTY_INT_ARRAY = IntArray(0) + +/** + * Calls this expression without providing any arguments. + * + * @return a value. + */ +@UnstableKMathAPI +public operator fun IntExpression.invoke(): Int = this(EMPTY_INT_ARRAY) + +/** + * Calls this expression from arguments. + * + * @param pairs the pairs of arguments to values. + * @return a value. + */ +@UnstableKMathAPI +public operator fun IntExpression.invoke(vararg arguments: Int): Int = this(arguments) + +private val EMPTY_LONG_ARRAY = LongArray(0) + +/** + * Calls this expression without providing any arguments. + * + * @return a value. + */ +@UnstableKMathAPI +public operator fun LongExpression.invoke(): Long = this(EMPTY_LONG_ARRAY) + +/** + * Calls this expression from arguments. + * + * @param pairs the pairs of arguments to values. + * @return a value. + */ +@UnstableKMathAPI +public operator fun LongExpression.invoke(vararg arguments: Long): Long = this(arguments) /** * A context for expression construction diff --git a/kmath-kotlingrad/build.gradle.kts b/kmath-kotlingrad/build.gradle.kts index d222ed7d6..da7380e41 100644 --- a/kmath-kotlingrad/build.gradle.kts +++ b/kmath-kotlingrad/build.gradle.kts @@ -3,6 +3,11 @@ plugins { id("ru.mipt.npm.gradle.common") } +kotlin.sourceSets + .filter { it.name.contains("test", true) } + .map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings) + .forEach { it.optIn("space.kscience.kmath.misc.UnstableKMathAPI") } + description = "Kotlin∇ integration module" dependencies { diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalGamma.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalGamma.kt index 6e7eb039d..63db1c56f 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalGamma.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalGamma.kt @@ -145,7 +145,7 @@ internal object InternalGamma { } when { - n >= maxIterations -> throw error("Maximal iterations is exceeded $maxIterations") + n >= maxIterations -> error("Maximal iterations is exceeded $maxIterations") sum.isInfinite() -> 1.0 else -> exp(-x + a * ln(x) - logGamma(a)) * sum }