From 9ee506b1d2996b3d46ea569e03d3a1862e4c1dff Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 25 Mar 2021 23:57:47 +0700 Subject: [PATCH 1/9] Some experiments with MST rendering --- README.md | 4 +- .../space/kscience/kmath/ast/astRendering.kt | 21 ++ kmath-ast/README.md | 42 ++- kmath-ast/docs/README-TEMPLATE.md | 36 ++ .../ast/rendering/LatexSyntaxRenderer.kt | 128 +++++++ .../ast/rendering/MathMLSyntaxRenderer.kt | 133 +++++++ .../kmath/ast/rendering/MathRenderer.kt | 102 ++++++ .../kmath/ast/rendering/MathSyntax.kt | 326 ++++++++++++++++++ .../kmath/ast/rendering/SyntaxRenderer.kt | 25 ++ .../kscience/kmath/ast/rendering/features.kt | 312 +++++++++++++++++ .../kscience/kmath/ast/rendering/stages.kt | 197 +++++++++++ .../space/kscience/kmath/estree/estree.kt | 4 +- .../kotlin/space/kscience/kmath/asm/asm.kt | 4 +- .../kotlin/space/kscience/kmath/ast/parser.kt | 3 +- .../kmath/ast/rendering/TestFeatures.kt | 90 +++++ .../kscience/kmath/ast/rendering/TestLatex.kt | 65 ++++ .../kmath/ast/rendering/TestMathML.kt | 84 +++++ .../kmath/ast/rendering/TestStages.kt | 28 ++ .../kscience/kmath/ast/rendering/TestUtils.kt | 41 +++ kmath-complex/README.md | 6 +- kmath-core/README.md | 6 +- .../space/kscience/kmath/operations/BigInt.kt | 3 +- kmath-ejml/README.md | 6 +- kmath-for-real/README.md | 6 +- kmath-functions/README.md | 6 +- kmath-nd4j/README.md | 6 +- 26 files changed, 1655 insertions(+), 29 deletions(-) create mode 100644 examples/src/main/kotlin/space/kscience/kmath/ast/astRendering.kt create mode 100644 kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/LatexSyntaxRenderer.kt create mode 100644 kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathMLSyntaxRenderer.kt create mode 100644 kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt create mode 100644 kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathSyntax.kt create mode 100644 kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/SyntaxRenderer.kt create mode 100644 kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt create mode 100644 kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/stages.kt create mode 100644 kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestFeatures.kt create mode 100644 kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt create mode 100644 kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt create mode 100644 kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestStages.kt create mode 100644 kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestUtils.kt diff --git a/README.md b/README.md index 7080c757e..7b78d4531 100644 --- a/README.md +++ b/README.md @@ -259,8 +259,8 @@ repositories { } dependencies { - api("space.kscience:kmath-core:0.3.0-dev-3") - // api("kscience.kmath:kmath-core-jvm:0.3.0-dev-3") for jvm-specific version + api("space.kscience:kmath-core:0.3.0-dev-4") + // api("kscience.kmath:kmath-core-jvm:0.3.0-dev-4") for jvm-specific version } ``` diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/astRendering.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/astRendering.kt new file mode 100644 index 000000000..a250ad800 --- /dev/null +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/astRendering.kt @@ -0,0 +1,21 @@ +package space.kscience.kmath.ast + +import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess +import space.kscience.kmath.ast.rendering.LatexSyntaxRenderer +import space.kscience.kmath.ast.rendering.MathMLSyntaxRenderer +import space.kscience.kmath.ast.rendering.renderWithStringBuilder + +public fun main() { + val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath() + val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst) + println("MathSyntax:") + println(syntax) + println() + val latex = LatexSyntaxRenderer.renderWithStringBuilder(syntax) + println("LaTeX:") + println(latex) + println() + val mathML = MathMLSyntaxRenderer.renderWithStringBuilder(syntax) + println("MathML:") + println(mathML) +} diff --git a/kmath-ast/README.md b/kmath-ast/README.md index ff954b914..44faa5cd5 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -12,7 +12,7 @@ Abstract syntax tree expression representation and related optimizations. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-3`. +The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-4`. **Gradle:** ```gradle @@ -23,7 +23,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-ast:0.3.0-dev-3' + implementation 'space.kscience:kmath-ast:0.3.0-dev-4' } ``` **Gradle Kotlin DSL:** @@ -35,7 +35,7 @@ repositories { } dependencies { - implementation("space.kscience:kmath-ast:0.3.0-dev-3") + implementation("space.kscience:kmath-ast:0.3.0-dev-4") } ``` @@ -111,3 +111,39 @@ var executable = function (constants, arguments) { #### Known issues - This feature uses `eval` which can be unavailable in several environments. + +## Rendering expressions + +kmath-ast also includes an extensible engine to display expressions in LaTeX or MathML syntax. + +Example usage: + +```kotlin +import space.kscience.kmath.ast.* +import space.kscience.kmath.ast.rendering.* + +public fun main() { + val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath() + val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst) + val latex = LatexSyntaxRenderer.renderWithStringBuilder(syntax) + println("LaTeX:") + println(latex) + println() + val mathML = MathMLSyntaxRenderer.renderWithStringBuilder(syntax) + println("MathML:") + println(mathML) +} +``` + +Result LaTeX: + +![](http://chart.googleapis.com/chart?cht=tx&chl=e%5E%7B%5Csqrt%7Bx%7D%7D-%5Cfrac%7B%5Cfrac%7B%5Coperatorname%7Bsin%7D%5E%7B-1%7D%5C,%5Cleft(2%5C,x%5Cright)%7D%7B2%5Ctimes10%5E%7B10%7D%2Bx%5E%7B3%7D%7D%7D%7B-12%7D) + +Result MathML (embedding MathML is not allowed by GitHub Markdown): + +```html +ex-sin-12x2×1010+x3-12 +``` + +It is also possible to create custom algorithms of render, and even add support of other markup languages +(see API reference). diff --git a/kmath-ast/docs/README-TEMPLATE.md b/kmath-ast/docs/README-TEMPLATE.md index db071adb4..9ed44d584 100644 --- a/kmath-ast/docs/README-TEMPLATE.md +++ b/kmath-ast/docs/README-TEMPLATE.md @@ -78,3 +78,39 @@ var executable = function (constants, arguments) { #### Known issues - This feature uses `eval` which can be unavailable in several environments. + +## Rendering expressions + +kmath-ast also includes an extensible engine to display expressions in LaTeX or MathML syntax. + +Example usage: + +```kotlin +import space.kscience.kmath.ast.* +import space.kscience.kmath.ast.rendering.* + +public fun main() { + val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath() + val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst) + val latex = LatexSyntaxRenderer.renderWithStringBuilder(syntax) + println("LaTeX:") + println(latex) + println() + val mathML = MathMLSyntaxRenderer.renderWithStringBuilder(syntax) + println("MathML:") + println(mathML) +} +``` + +Result LaTeX: + +![](http://chart.googleapis.com/chart?cht=tx&chl=e%5E%7B%5Csqrt%7Bx%7D%7D-%5Cfrac%7B%5Cfrac%7B%5Coperatorname%7Bsin%7D%5E%7B-1%7D%5C,%5Cleft(2%5C,x%5Cright)%7D%7B2%5Ctimes10%5E%7B10%7D%2Bx%5E%7B3%7D%7D%7D%7B-12%7D) + +Result MathML (embedding MathML is not allowed by GitHub Markdown): + +```html +ex-sin-12x2×1010+x3-12 +``` + +It is also possible to create custom algorithms of render, and even add support of other markup languages +(see API reference). diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/LatexSyntaxRenderer.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/LatexSyntaxRenderer.kt new file mode 100644 index 000000000..914da6d9f --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/LatexSyntaxRenderer.kt @@ -0,0 +1,128 @@ +package space.kscience.kmath.ast.rendering + +/** + * [SyntaxRenderer] implementation for LaTeX. + * + * The generated string is a valid LaTeX fragment to be used in the Math Mode. + * + * Example usage: + * + * ``` + * \documentclass{article} + * \begin{document} + * \begin{equation} + * %code generated by the syntax renderer + * \end{equation} + * \end{document} + * ``` + * + * @author Iaroslav Postovalov + */ +public object LatexSyntaxRenderer : SyntaxRenderer { + public override fun render(node: MathSyntax, output: Appendable): Unit = output.run { + fun render(syntax: MathSyntax) = render(syntax, output) + + when (node) { + is NumberSyntax -> append(node.string) + is SymbolSyntax -> append(node.string) + + is OperatorNameSyntax -> { + append("\\operatorname{") + append(node.name) + append('}') + } + + is SpecialSymbolSyntax -> when (node.kind) { + SpecialSymbolSyntax.Kind.INFINITY -> append("\\infty") + } + + is OperandSyntax -> { + if (node.parentheses) append("\\left(") + render(node.operand) + if (node.parentheses) append("\\right)") + } + + is UnaryOperatorSyntax -> { + render(node.prefix) + append("\\,") + render(node.operand) + } + + is UnaryPlusSyntax -> { + append('+') + render(node.operand) + } + + is UnaryMinusSyntax -> { + append('-') + render(node.operand) + } + + is RadicalSyntax -> { + append("\\sqrt") + append('{') + render(node.operand) + append('}') + } + + is SuperscriptSyntax -> { + render(node.left) + append("^{") + render(node.right) + append('}') + } + + is SubscriptSyntax -> { + render(node.left) + append("_{") + render(node.right) + append('}') + } + + is BinaryOperatorSyntax -> { + render(node.prefix) + append("\\left(") + render(node.left) + append(',') + render(node.right) + append("\\right)") + } + + is BinaryPlusSyntax -> { + render(node.left) + append('+') + render(node.right) + } + + is BinaryMinusSyntax -> { + render(node.left) + append('-') + render(node.right) + } + + is FractionSyntax -> { + append("\\frac{") + render(node.left) + append("}{") + render(node.right) + append('}') + } + + is RadicalWithIndexSyntax -> { + append("\\sqrt") + append('[') + render(node.left) + append(']') + append('{') + render(node.right) + append('}') + } + + is MultiplicationSyntax -> { + render(node.left) + append(if (node.times) "\\times" else "\\,") + render(node.right) + } + } + } +} diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathMLSyntaxRenderer.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathMLSyntaxRenderer.kt new file mode 100644 index 000000000..6f194be86 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathMLSyntaxRenderer.kt @@ -0,0 +1,133 @@ +package space.kscience.kmath.ast.rendering + +/** + * [SyntaxRenderer] implementation for MathML. + * + * The generated XML string is a valid MathML instance. + * + * @author Iaroslav Postovalov + */ +public object MathMLSyntaxRenderer : SyntaxRenderer { + public override fun render(node: MathSyntax, output: Appendable) { + output.append("") + render0(node, output) + output.append("") + } + + private fun render0(node: MathSyntax, output: Appendable): Unit = output.run { + fun tag(tagName: String, vararg attr: Pair, block: () -> Unit = {}) { + append('<') + append(tagName) + + if (attr.isNotEmpty()) { + append(' ') + var count = 0 + + for ((name, value) in attr) { + if (++count > 1) append(' ') + append(name) + append("=\"") + append(value) + append('"') + } + } + + append('>') + block() + append("') + } + + fun render(syntax: MathSyntax) = render0(syntax, output) + + when (node) { + is NumberSyntax -> tag("mn") { append(node.string) } + is SymbolSyntax -> tag("mi") { append(node.string) } + is OperatorNameSyntax -> tag("mo") { append(node.name) } + + is SpecialSymbolSyntax -> when (node.kind) { + SpecialSymbolSyntax.Kind.INFINITY -> tag("mo") { append("∞") } + } + + is OperandSyntax -> if (node.parentheses) { + tag("mfenced", "open" to "(", "close" to ")", "separators" to "") { + render(node.operand) + } + } else { + render(node.operand) + } + + is UnaryOperatorSyntax -> { + render(node.prefix) + tag("mspace", "width" to "0.167em") + render(node.operand) + } + + is UnaryPlusSyntax -> { + tag("mo") { append('+') } + render(node.operand) + } + + is UnaryMinusSyntax -> { + tag("mo") { append("-") } + render(node.operand) + } + + is RadicalSyntax -> tag("msqrt") { render(node.operand) } + + is SuperscriptSyntax -> tag("msup") { + tag("mrow") { render(node.left) } + tag("mrow") { render(node.right) } + } + + is SubscriptSyntax -> tag("msub") { + tag("mrow") { render(node.left) } + tag("mrow") { render(node.right) } + } + + is BinaryOperatorSyntax -> { + render(node.prefix) + + tag("mfenced", "open" to "(", "close" to ")", "separators" to "") { + render(node.left) + tag("mo") { append(',') } + render(node.right) + } + } + + is BinaryPlusSyntax -> { + render(node.left) + tag("mo") { append('+') } + render(node.right) + } + + is BinaryMinusSyntax -> { + render(node.left) + tag("mo") { append('-') } + render(node.right) + } + + is FractionSyntax -> tag("mfrac") { + tag("mrow") { + render(node.left) + } + + tag("mrow") { + render(node.right) + } + } + + is RadicalWithIndexSyntax -> tag("mroot") { + tag("mrow") { render(node.right) } + tag("mrow") { render(node.left) } + } + + is MultiplicationSyntax -> { + render(node.left) + if (node.times) tag("mo") { append("×") } else tag("mspace", "width" to "0.167em") + render(node.right) + } + } + } +} diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt new file mode 100644 index 000000000..afdf12b04 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt @@ -0,0 +1,102 @@ +package space.kscience.kmath.ast.rendering + +import space.kscience.kmath.ast.MST + +/** + * Renders [MST] to [MathSyntax]. + * + * @author Iaroslav Postovalov + */ +public fun interface MathRenderer { + /** + * Renders [MST] to [MathSyntax]. + */ + public fun render(mst: MST): MathSyntax +} + +/** + * Implements [MST] render process with sequence of features. + * + * @property features The applied features. + * @author Iaroslav Postovalov + */ +public open class FeaturedMathRenderer(public val features: List) : MathRenderer { + public override fun render(mst: MST): MathSyntax { + for (feature in features) feature.render(this, mst)?.let { return it } + throw UnsupportedOperationException("Renderer $this has no appropriate feature to render node $mst.") + } + + /** + * Logical unit of [MST] rendering. + */ + public fun interface RenderFeature { + /** + * Renders [MST] to [MathSyntax] in the context of owning renderer. + */ + public fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? + } +} + +/** + * Extends [FeaturedMathRenderer] by adding post-processing stages. + * + * @property stages The applied stages. + * @author Iaroslav Postovalov + */ +public open class FeaturedMathRendererWithPostProcess( + features: List, + public val stages: List, +) : FeaturedMathRenderer(features) { + public override fun render(mst: MST): MathSyntax { + val res = super.render(mst) + for (stage in stages) stage.perform(res) + return res + } + + /** + * Logical unit of [MathSyntax] post-processing. + */ + public fun interface PostProcessStage { + /** + * Performs the specified action over [MathSyntax]. + */ + public fun perform(node: MathSyntax) + } + + public companion object { + /** + * The default setup of [FeaturedMathRendererWithPostProcess]. + */ + public val Default: FeaturedMathRendererWithPostProcess = FeaturedMathRendererWithPostProcess( + listOf( + // Printing known operations + BinaryPlus.Default, + BinaryMinus.Default, + UnaryPlus.Default, + UnaryMinus.Default, + Multiplication.Default, + Fraction.Default, + Power.Default, + SquareRoot.Default, + Exponential.Default, + InverseTrigonometricOperations.Default, + + // Fallback option for unknown operations - printing them as operator + BinaryOperator.Default, + UnaryOperator.Default, + + // Pretty printing for numerics + PrettyPrintFloats.Default, + PrettyPrintIntegers.Default, + + // Printing terminal nodes as string + PrintNumeric, + PrintSymbolic, + ), + listOf( + SimplifyParentheses.Default, + BetterMultiplication, + ) + ) + } +} diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathSyntax.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathSyntax.kt new file mode 100644 index 000000000..4c85adcfc --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathSyntax.kt @@ -0,0 +1,326 @@ +package space.kscience.kmath.ast.rendering + +/** + * Mathematical typography syntax node. + * + * @author Iaroslav Postovalov + */ +public sealed class MathSyntax { + /** + * The parent node of this syntax node. + */ + public var parent: MathSyntax? = null +} + +/** + * Terminal node, which should not have any children nodes. + * + * @author Iaroslav Postovalov + */ +public sealed class TerminalSyntax : MathSyntax() + +/** + * Node containing a certain operation. + * + * @author Iaroslav Postovalov + */ +public sealed class OperationSyntax : MathSyntax() { + /** + * The operation token. + */ + public abstract val operation: String +} + +/** + * Unary node, which has only one child. + * + * @author Iaroslav Postovalov + */ +public sealed class UnarySyntax : OperationSyntax() { + /** + * The operand of this node. + */ + public abstract val operand: MathSyntax +} + +/** + * Binary node, which has only two children. + * + * @author Iaroslav Postovalov + */ +public sealed class BinarySyntax : OperationSyntax() { + /** + * The left-hand side operand. + */ + public abstract val left: MathSyntax + + /** + * The right-hand side operand. + */ + public abstract val right: MathSyntax +} + +/** + * Represents a number. + * + * @property string The digits of number. + * @author Iaroslav Postovalov + */ +public data class NumberSyntax(public var string: String) : TerminalSyntax() + +/** + * Represents a symbol. + * + * @property string The symbol. + * @author Iaroslav Postovalov + */ +public data class SymbolSyntax(public var string: String) : TerminalSyntax() + +/** + * Represents special typing for operator name. + * + * @property name The operator name. + * @see BinaryOperatorSyntax + * @see UnaryOperatorSyntax + * @author Iaroslav Postovalov + */ +public data class OperatorNameSyntax(public var name: String) : TerminalSyntax() + +/** + * Represents a usage of special symbols. + * + * @property kind The kind of symbol. + * @author Iaroslav Postovalov + */ +public data class SpecialSymbolSyntax(public var kind: Kind) : TerminalSyntax() { + /** + * The kind of symbol. + */ + public enum class Kind { + /** + * The infinity (∞) symbol. + */ + INFINITY, + } +} + +/** + * Represents operand of a certain operator wrapped with parentheses or not. + * + * @property operand The operand. + * @property parentheses Whether the operand should be wrapped with parentheses. + * @author Iaroslav Postovalov + */ +public data class OperandSyntax( + public val operand: MathSyntax, + public var parentheses: Boolean, +) : MathSyntax() { + init { + operand.parent = this + } +} + +/** + * Represents unary, prefix operator syntax (like f x). + * + * @property prefix The prefix. + * @author Iaroslav Postovalov + */ +public data class UnaryOperatorSyntax( + public override val operation: String, + public var prefix: MathSyntax, + public override val operand: OperandSyntax, +) : UnarySyntax() { + init { + operand.parent = this + } +} + +/** + * Represents prefix, unary plus operator. + * + * @author Iaroslav Postovalov + */ +public data class UnaryPlusSyntax( + public override val operation: String, + public override val operand: OperandSyntax, +) : UnarySyntax() { + init { + operand.parent = this + } +} + +/** + * Represents prefix, unary minus operator. + * + * @author Iaroslav Postovalov + */ +public data class UnaryMinusSyntax( + public override val operation: String, + public override val operand: OperandSyntax, +) : UnarySyntax() { + init { + operand.parent = this + } +} + +/** + * Represents radical with a node inside it. + * + * @property operand The radicand. + * @author Iaroslav Postovalov + */ +public data class RadicalSyntax( + public override val operation: String, + public override val operand: MathSyntax, +) : UnarySyntax() { + init { + operand.parent = this + } +} + +/** + * Represents a syntax node with superscript (usually, for exponentiation). + * + * @property left The node. + * @property right The superscript. + * @author Iaroslav Postovalov + */ +public data class SuperscriptSyntax( + public override val operation: String, + public override val left: MathSyntax, + public override val right: MathSyntax, +) : BinarySyntax() { + init { + left.parent = this + right.parent = this + } +} + +/** + * Represents a syntax node with subscript. + * + * @property left The node. + * @property right The subscript. + * @author Iaroslav Postovalov + */ +public data class SubscriptSyntax( + public override val operation: String, + public override val left: MathSyntax, + public override val right: MathSyntax, +) : BinarySyntax() { + init { + left.parent = this + right.parent = this + } +} + +/** + * Represents binary, prefix operator syntax (like f(a, b)). + * + * @property prefix The prefix. + * @author Iaroslav Postovalov + */ +public data class BinaryOperatorSyntax( + public override val operation: String, + public var prefix: MathSyntax, + public override val left: MathSyntax, + public override val right: MathSyntax, +) : BinarySyntax() { + init { + left.parent = this + right.parent = this + } +} + +/** + * Represents binary, infix addition. + * + * @param left The augend. + * @param right The addend. + * @author Iaroslav Postovalov + */ +public data class BinaryPlusSyntax( + public override val operation: String, + public override val left: OperandSyntax, + public override val right: OperandSyntax, +) : BinarySyntax() { + init { + left.parent = this + right.parent = this + } +} + +/** + * Represents binary, infix subtraction. + * + * @param left The minuend. + * @param right The subtrahend. + * @author Iaroslav Postovalov + */ +public data class BinaryMinusSyntax( + public override val operation: String, + public override val left: OperandSyntax, + public override val right: OperandSyntax, +) : BinarySyntax() { + init { + left.parent = this + right.parent = this + } +} + +/** + * Represents fraction with numerator and denominator. + * + * @property left The numerator. + * @property right The denominator. + * @author Iaroslav Postovalov + */ +public data class FractionSyntax( + public override val operation: String, + public override val left: MathSyntax, + public override val right: MathSyntax, +) : BinarySyntax() { + init { + left.parent = this + right.parent = this + } +} + +/** + * Represents radical syntax with index. + * + * @property left The index. + * @property right The radicand. + * @author Iaroslav Postovalov + */ +public data class RadicalWithIndexSyntax( + public override val operation: String, + public override val left: MathSyntax, + public override val right: MathSyntax, +) : BinarySyntax() { + init { + left.parent = this + right.parent = this + } +} + +/** + * Represents binary, infix multiplication in the form of coefficient (2 x) or with operator (x×2). + * + * @property left The multiplicand. + * @property right The multiplier. + * @property times whether the times (×) symbol should be used. + * @author Iaroslav Postovalov + */ +public data class MultiplicationSyntax( + public override val operation: String, + public override val left: OperandSyntax, + public override val right: OperandSyntax, + public var times: Boolean, +) : BinarySyntax() { + init { + left.parent = this + right.parent = this + } +} diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/SyntaxRenderer.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/SyntaxRenderer.kt new file mode 100644 index 000000000..fcc79f76b --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/SyntaxRenderer.kt @@ -0,0 +1,25 @@ +package space.kscience.kmath.ast.rendering + +/** + * Abstraction of writing [MathSyntax] as a string of an actual markup language. Typical implementation should + * involve traversal of MathSyntax with handling each its subtype. + * + * @author Iaroslav Postovalov + */ +public fun interface SyntaxRenderer { + /** + * Renders the [MathSyntax] to [output]. + */ + public fun render(node: MathSyntax, output: Appendable) +} + +/** + * Calls [SyntaxRenderer.render] with given [node] and a new [StringBuilder] instance, and returns its content. + * + * @author Iaroslav Postovalov + */ +public fun SyntaxRenderer.renderWithStringBuilder(node: MathSyntax): String { + val sb = StringBuilder() + render(node, sb) + return sb.toString() +} diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt new file mode 100644 index 000000000..6e66d3ca3 --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt @@ -0,0 +1,312 @@ +package space.kscience.kmath.ast.rendering + +import space.kscience.kmath.ast.MST +import space.kscience.kmath.ast.rendering.FeaturedMathRenderer.RenderFeature +import space.kscience.kmath.operations.* +import kotlin.reflect.KClass + +/** + * Prints any [MST.Symbolic] as a [SymbolSyntax] containing the [MST.Symbolic.value] of it. + * + * @author Iaroslav Postovalov + */ +public object PrintSymbolic : RenderFeature { + public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? { + if (node !is MST.Symbolic) return null + return SymbolSyntax(string = node.value) + } +} + +/** + * Prints any [MST.Numeric] as a [NumberSyntax] containing the [Any.toString] result of it. + * + * @author Iaroslav Postovalov + */ +public object PrintNumeric : RenderFeature { + public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? { + if (node !is MST.Numeric) return null + return NumberSyntax(string = node.value.toString()) + } +} + +private fun printSignedNumberString(s: String): MathSyntax { + if (s.startsWith('-')) + return UnaryMinusSyntax( + operation = GroupOperations.MINUS_OPERATION, + operand = OperandSyntax( + operand = NumberSyntax(string = s.removePrefix("-")), + parentheses = true, + ), + ) + + return NumberSyntax(string = s) +} + +/** + * Special printing for numeric types which are printed in form of + * *('-'? (DIGIT+ ('.' DIGIT+)? ('E' '-'? DIGIT+)? | 'Infinity')) | 'NaN'*. + * + * @property types The suitable types. + */ +public class PrettyPrintFloats(public val types: Set>) : RenderFeature { + public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? { + if (node !is MST.Numeric || node.value::class !in types) return null + val toString = node.value.toString().removeSuffix(".0") + + if ('E' in toString) { + val (beforeE, afterE) = toString.split('E') + val significand = beforeE.toDouble().toString().removeSuffix(".0") + val exponent = afterE.toDouble().toString().removeSuffix(".0") + + return MultiplicationSyntax( + operation = RingOperations.TIMES_OPERATION, + left = OperandSyntax(operand = NumberSyntax(significand), parentheses = true), + right = OperandSyntax( + operand = SuperscriptSyntax( + operation = PowerOperations.POW_OPERATION, + left = NumberSyntax(string = "10"), + right = printSignedNumberString(exponent), + ), + parentheses = true, + ), + times = true, + ) + } + + if (toString.endsWith("Infinity")) { + val infty = SpecialSymbolSyntax(SpecialSymbolSyntax.Kind.INFINITY) + + if (toString.startsWith('-')) + return UnaryMinusSyntax( + operation = GroupOperations.MINUS_OPERATION, + operand = OperandSyntax(operand = infty, parentheses = true), + ) + + return infty + } + + return printSignedNumberString(toString) + } + + public companion object { + /** + * The default instance containing [Float], and [Double]. + */ + public val Default: PrettyPrintFloats = PrettyPrintFloats(setOf(Float::class, Double::class)) + } +} + +/** + * Special printing for numeric types which are printed in form of *'-'? DIGIT+*. + * + * @property types The suitable types. + */ +public class PrettyPrintIntegers(public val types: Set>) : RenderFeature { + public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? { + if (node !is MST.Numeric || node.value::class !in types) + return null + + return printSignedNumberString(node.value.toString()) + } + + public companion object { + /** + * The default instance containing [Byte], [Short], [Int], and [Long]. + */ + public val Default: PrettyPrintIntegers = + PrettyPrintIntegers(setOf(Byte::class, Short::class, Int::class, Long::class)) + } +} + +/** + * Abstract printing of unary operations which discards [MST] if their operation is not in [operations] or its type is + * not [MST.Unary]. + * + * @param operations the allowed operations. If `null`, any operation is accepted. + */ +public abstract class Unary(public val operations: Collection?) : RenderFeature { + /** + * The actual render function. + */ + protected abstract fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax? + + public final override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? { + if (node !is MST.Unary || operations != null && node.operation !in operations) return null + return render0(renderer, node) + } +} + +/** + * Abstract printing of unary operations which discards [MST] if their operation is not in [operations] or its type is + * not [MST.Binary]. + * + * @property operations the allowed operations. If `null`, any operation is accepted. + */ +public abstract class Binary(public val operations: Collection?) : RenderFeature { + /** + * The actual render function. + */ + protected abstract fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax? + + public final override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? { + if (node !is MST.Binary || operations != null && node.operation !in operations) return null + return render0(renderer, node) + } +} + +public class BinaryPlus(operations: Collection?) : Binary(operations) { + public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryPlusSyntax( + operation = node.operation, + left = OperandSyntax(parent.render(node.left), true), + right = OperandSyntax(parent.render(node.right), true), + ) + + public companion object { + public val Default: BinaryPlus = BinaryPlus(setOf(GroupOperations.PLUS_OPERATION)) + } +} + +public class BinaryMinus(operations: Collection?) : Binary(operations) { + public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryMinusSyntax( + operation = node.operation, + left = OperandSyntax(operand = parent.render(node.left), parentheses = true), + right = OperandSyntax(operand = parent.render(node.right), parentheses = true), + ) + + public companion object { + public val Default: BinaryMinus = BinaryMinus(setOf(GroupOperations.MINUS_OPERATION)) + } +} + +public class UnaryPlus(operations: Collection?) : Unary(operations) { + public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryPlusSyntax( + operation = node.operation, + operand = OperandSyntax(operand = parent.render(node.value), parentheses = true), + ) + + public companion object { + public val Default: UnaryPlus = UnaryPlus(setOf(GroupOperations.PLUS_OPERATION)) + } +} + +public class UnaryMinus(operations: Collection?) : Unary(operations) { + public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryMinusSyntax( + operation = node.operation, + operand = OperandSyntax(operand = parent.render(node.value), parentheses = true), + ) + + public companion object { + public val Default: UnaryMinus = UnaryMinus(setOf(GroupOperations.MINUS_OPERATION)) + } +} + +public class Fraction(operations: Collection?) : Binary(operations) { + public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = FractionSyntax( + operation = node.operation, + left = parent.render(node.left), + right = parent.render(node.right), + ) + + public companion object { + public val Default: Fraction = Fraction(setOf(FieldOperations.DIV_OPERATION)) + } +} + +public class BinaryOperator(operations: Collection?) : Binary(operations) { + public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryOperatorSyntax( + operation = node.operation, + prefix = OperatorNameSyntax(name = node.operation), + left = parent.render(node.left), + right = parent.render(node.right), + ) + + public companion object { + public val Default: BinaryOperator = BinaryOperator(null) + } +} + +public class UnaryOperator(operations: Collection?) : Unary(operations) { + public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryOperatorSyntax( + operation = node.operation, + prefix = OperatorNameSyntax(node.operation), + operand = OperandSyntax(parent.render(node.value), true), + ) + + public companion object { + public val Default: UnaryOperator = UnaryOperator(null) + } +} + +public class Power(operations: Collection?) : Binary(operations) { + public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = SuperscriptSyntax( + operation = node.operation, + left = OperandSyntax(parent.render(node.left), true), + right = OperandSyntax(parent.render(node.right), true), + ) + + public companion object { + public val Default: Power = Power(setOf(PowerOperations.POW_OPERATION)) + } +} + +public class SquareRoot(operations: Collection?) : Unary(operations) { + public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = + RadicalSyntax(operation = node.operation, operand = parent.render(node.value)) + + public companion object { + public val Default: SquareRoot = SquareRoot(setOf(PowerOperations.SQRT_OPERATION)) + } +} + +public class Exponential(operations: Collection?) : Unary(operations) { + public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = SuperscriptSyntax( + operation = node.operation, + left = SymbolSyntax(string = "e"), + right = parent.render(node.value), + ) + + public companion object { + public val Default: Exponential = Exponential(setOf(ExponentialOperations.EXP_OPERATION)) + } +} + +public class Multiplication(operations: Collection?) : Binary(operations) { + public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = MultiplicationSyntax( + operation = node.operation, + left = OperandSyntax(operand = parent.render(node.left), parentheses = true), + right = OperandSyntax(operand = parent.render(node.right), parentheses = true), + times = true, + ) + + public companion object { + public val Default: Multiplication = Multiplication(setOf( + RingOperations.TIMES_OPERATION, + )) + } +} + +public class InverseTrigonometricOperations(operations: Collection?) : Unary(operations) { + public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryOperatorSyntax( + operation = node.operation, + prefix = SuperscriptSyntax( + operation = PowerOperations.POW_OPERATION, + left = OperatorNameSyntax(name = node.operation.removePrefix("a")), + right = UnaryMinusSyntax( + operation = GroupOperations.MINUS_OPERATION, + operand = OperandSyntax(operand = NumberSyntax(string = "1"), parentheses = true), + ), + ), + operand = OperandSyntax(operand = parent.render(node.value), parentheses = true), + ) + + public companion object { + public val Default: InverseTrigonometricOperations = InverseTrigonometricOperations(setOf( + TrigonometricOperations.ACOS_OPERATION, + TrigonometricOperations.ASIN_OPERATION, + TrigonometricOperations.ATAN_OPERATION, + ExponentialOperations.ACOSH_OPERATION, + ExponentialOperations.ASINH_OPERATION, + ExponentialOperations.ATANH_OPERATION, + )) + } +} diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/stages.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/stages.kt new file mode 100644 index 000000000..c183f6ace --- /dev/null +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/stages.kt @@ -0,0 +1,197 @@ +package space.kscience.kmath.ast.rendering + +import space.kscience.kmath.operations.FieldOperations +import space.kscience.kmath.operations.GroupOperations +import space.kscience.kmath.operations.PowerOperations +import space.kscience.kmath.operations.RingOperations + +/** + * Removes unnecessary times (×) symbols from [MultiplicationSyntax]. + * + * @author Iaroslav Postovalov + */ +public object BetterMultiplication : FeaturedMathRendererWithPostProcess.PostProcessStage { + public override fun perform(node: MathSyntax) { + when (node) { + is NumberSyntax -> Unit + is SymbolSyntax -> Unit + is OperatorNameSyntax -> Unit + is SpecialSymbolSyntax -> Unit + is OperandSyntax -> perform(node.operand) + + is UnaryOperatorSyntax -> { + perform(node.prefix) + perform(node.operand) + } + + is UnaryPlusSyntax -> perform(node.operand) + is UnaryMinusSyntax -> perform(node.operand) + is RadicalSyntax -> perform(node.operand) + + is SuperscriptSyntax -> { + perform(node.left) + perform(node.right) + } + + is SubscriptSyntax -> { + perform(node.left) + perform(node.right) + } + + is BinaryOperatorSyntax -> { + perform(node.prefix) + perform(node.left) + perform(node.right) + } + + is BinaryPlusSyntax -> { + perform(node.left) + perform(node.right) + } + + is BinaryMinusSyntax -> { + perform(node.left) + perform(node.right) + } + + is FractionSyntax -> { + perform(node.left) + perform(node.right) + } + + is RadicalWithIndexSyntax -> { + perform(node.left) + perform(node.right) + } + + is MultiplicationSyntax -> { + node.times = node.right.operand is NumberSyntax && !node.right.parentheses + || node.left.operand is NumberSyntax && node.right.operand is FractionSyntax + || node.left.operand is NumberSyntax && node.right.operand is NumberSyntax + || node.left.operand is NumberSyntax && node.right.operand is SuperscriptSyntax && node.right.operand.left is NumberSyntax + + perform(node.left) + perform(node.right) + } + } + } +} + +/** + * Removes unnecessary parentheses from [OperandSyntax]. + * + * @property precedenceFunction Returns the precedence number for syntax node. Higher number is lower priority. + * @author Iaroslav Postovalov + */ +public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) -> Int) : + FeaturedMathRendererWithPostProcess.PostProcessStage { + public override fun perform(node: MathSyntax) { + when (node) { + is NumberSyntax -> Unit + is SymbolSyntax -> Unit + is OperatorNameSyntax -> Unit + is SpecialSymbolSyntax -> Unit + + is OperandSyntax -> { + val isRightOfSuperscript = + (node.parent is SuperscriptSyntax) && (node.parent as SuperscriptSyntax).right === node + + val precedence = precedenceFunction(node.operand) + + val needParenthesesByPrecedence = when (val parent = node.parent) { + null -> false + + is BinarySyntax -> { + val parentPrecedence = precedenceFunction(parent) + + parentPrecedence < precedence || + parentPrecedence == precedence && parentPrecedence != 0 && node === parent.right + } + + else -> precedence > precedenceFunction(parent) + } + + node.parentheses = !isRightOfSuperscript + && (needParenthesesByPrecedence || node.parent is UnaryOperatorSyntax) + + perform(node.operand) + } + + is UnaryOperatorSyntax -> { + perform(node.prefix) + perform(node.operand) + } + + is UnaryPlusSyntax -> perform(node.operand) + is UnaryMinusSyntax -> { + perform(node.operand) + } + is RadicalSyntax -> perform(node.operand) + + is SuperscriptSyntax -> { + perform(node.left) + perform(node.right) + } + + is SubscriptSyntax -> { + perform(node.left) + perform(node.right) + } + + is BinaryOperatorSyntax -> { + perform(node.prefix) + perform(node.left) + perform(node.right) + } + + is BinaryPlusSyntax -> { + perform(node.left) + perform(node.right) + } + + is BinaryMinusSyntax -> { + perform(node.left) + perform(node.right) + } + + is FractionSyntax -> { + perform(node.left) + perform(node.right) + } + + is MultiplicationSyntax -> { + perform(node.left) + perform(node.right) + } + + is RadicalWithIndexSyntax -> { + perform(node.left) + perform(node.right) + } + } + } + + public companion object { + /** + * The default configuration of [SimplifyParentheses] where power is 1, multiplicative operations are 2, + * additive operations are 3. + */ + public val Default: SimplifyParentheses = SimplifyParentheses { + when (it) { + is TerminalSyntax -> 0 + is UnarySyntax -> 2 + + is BinarySyntax -> when (it.operation) { + PowerOperations.POW_OPERATION -> 1 + RingOperations.TIMES_OPERATION -> 3 + FieldOperations.DIV_OPERATION -> 3 + GroupOperations.MINUS_OPERATION -> 4 + GroupOperations.PLUS_OPERATION -> 4 + else -> 0 + } + + else -> 0 + } + } + } +} diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt index 0bd9a386d..456a2ba07 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt @@ -68,7 +68,7 @@ internal fun MST.compileWith(algebra: Algebra): Expression { /** * Compiles an [MST] to ESTree generated expression using given algebra. * - * @author Alexander Nozik. + * @author Iaroslav Postovalov */ public fun Algebra.expression(mst: MST): Expression = mst.compileWith(this) @@ -76,7 +76,7 @@ public fun Algebra.expression(mst: MST): Expression = /** * Optimizes performance of an [MstExpression] by compiling it into ESTree generated expression. * - * @author Alexander Nozik. + * @author Iaroslav Postovalov */ public fun MstExpression>.compile(): Expression = mst.compileWith(algebra) diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt index 8875bd715..369fe136b 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt @@ -73,7 +73,7 @@ internal fun MST.compileWith(type: Class, algebra: Algebra): Exp /** * Compiles an [MST] to ASM using given algebra. * - * @author Alexander Nozik. + * @author Alexander Nozik */ public inline fun Algebra.expression(mst: MST): Expression = mst.compileWith(T::class.java, this) @@ -81,7 +81,7 @@ public inline fun Algebra.expression(mst: MST): Expression< /** * Optimizes performance of an [MstExpression] using ASM codegen. * - * @author Alexander Nozik. + * @author Alexander Nozik */ public inline fun MstExpression>.compile(): Expression = mst.compileWith(T::class.java, algebra) diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/ast/parser.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/ast/parser.kt index 9a38ce81a..8ecb0adda 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/ast/parser.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/ast/parser.kt @@ -21,7 +21,8 @@ import space.kscience.kmath.operations.RingOperations /** * better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4. * - * @author Alexander Nozik and Iaroslav Postovalov + * @author Alexander Nozik + * @author Iaroslav Postovalov */ public object ArithmeticsEvaluator : Grammar() { // TODO replace with "...".toRegex() when better-parse 0.4.1 is released diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestFeatures.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestFeatures.kt new file mode 100644 index 000000000..b10f7ed4e --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestFeatures.kt @@ -0,0 +1,90 @@ +package space.kscience.kmath.ast.rendering + +import space.kscience.kmath.ast.MST.Numeric +import space.kscience.kmath.ast.rendering.TestUtils.testLatex +import kotlin.test.Test + +internal class TestFeatures { + @Test + fun printSymbolic() = testLatex("x", "x") + + @Test + fun printNumeric() { + val num = object : Number() { + override fun toByte(): Byte = throw UnsupportedOperationException() + override fun toChar(): Char = throw UnsupportedOperationException() + override fun toDouble(): Double = throw UnsupportedOperationException() + override fun toFloat(): Float = throw UnsupportedOperationException() + override fun toInt(): Int = throw UnsupportedOperationException() + override fun toLong(): Long = throw UnsupportedOperationException() + override fun toShort(): Short = throw UnsupportedOperationException() + override fun toString(): String = "foo" + } + + testLatex(Numeric(num), "foo") + } + + @Test + fun prettyPrintFloats() { + testLatex(Numeric(Double.NaN), "NaN") + testLatex(Numeric(Double.POSITIVE_INFINITY), "\\infty") + testLatex(Numeric(Double.NEGATIVE_INFINITY), "-\\infty") + testLatex(Numeric(1.0), "1") + testLatex(Numeric(-1.0), "-1") + testLatex(Numeric(1.42), "1.42") + testLatex(Numeric(-1.42), "-1.42") + testLatex(Numeric(1.1e10), "1.1\\times10^{10}") + testLatex(Numeric(1.1e-10), "1.1\\times10^{-10}") + testLatex(Numeric(-1.1e-10), "-1.1\\times10^{-10}") + testLatex(Numeric(-1.1e10), "-1.1\\times10^{10}") + } + + @Test + fun prettyPrintIntegers() { + testLatex(Numeric(42), "42") + testLatex(Numeric(-42), "-42") + } + + @Test + fun binaryPlus() = testLatex("2+2", "2+2") + + @Test + fun binaryMinus() = testLatex("2-2", "2-2") + + @Test + fun fraction() = testLatex("2/2", "\\frac{2}{2}") + + @Test + fun binaryOperator() = testLatex("f(x, y)", "\\operatorname{f}\\left(x,y\\right)") + + @Test + fun unaryOperator() = testLatex("f(x)", "\\operatorname{f}\\,\\left(x\\right)") + + @Test + fun power() = testLatex("x^y", "x^{y}") + + @Test + fun squareRoot() = testLatex("sqrt(x)", "\\sqrt{x}") + + @Test + fun exponential() = testLatex("exp(x)", "e^{x}") + + @Test + fun multiplication() = testLatex("x*1", "x\\times1") + + @Test + fun inverseTrigonometry() { + testLatex("asin(x)", "\\operatorname{sin}^{-1}\\,\\left(x\\right)") + testLatex("asinh(x)", "\\operatorname{sinh}^{-1}\\,\\left(x\\right)") + testLatex("acos(x)", "\\operatorname{cos}^{-1}\\,\\left(x\\right)") + testLatex("acosh(x)", "\\operatorname{cosh}^{-1}\\,\\left(x\\right)") + testLatex("atan(x)", "\\operatorname{tan}^{-1}\\,\\left(x\\right)") + testLatex("atanh(x)", "\\operatorname{tanh}^{-1}\\,\\left(x\\right)") + } + +// @Test +// fun unaryPlus() { +// testLatex("+1", "+1") +// testLatex("+1", "++1") +// } +} diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt new file mode 100644 index 000000000..9c1009042 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt @@ -0,0 +1,65 @@ +package space.kscience.kmath.ast.rendering + +import space.kscience.kmath.ast.MST +import space.kscience.kmath.ast.rendering.TestUtils.testLatex +import space.kscience.kmath.operations.GroupOperations +import kotlin.test.Test + +internal class TestLatex { + @Test + fun number() = testLatex("42", "42") + + @Test + fun symbol() = testLatex("x", "x") + + @Test + fun operatorName() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)") + + @Test + fun specialSymbol() = testLatex(MST.Numeric(Double.POSITIVE_INFINITY), "\\infty") + + @Test + fun operand() { + testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)") + testLatex("1+1", "1+1") + } + + @Test + fun unaryOperator() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)") + + @Test + fun unaryPlus() = testLatex(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "+1") + + @Test + fun unaryMinus() = testLatex("-x", "-x") + + @Test + fun radical() = testLatex("sqrt(x)", "\\sqrt{x}") + + @Test + fun superscript() = testLatex("x^y", "x^{y}") + + @Test + fun subscript() = testLatex(SubscriptSyntax("", SymbolSyntax("x"), NumberSyntax("123")), "x_{123}") + + @Test + fun binaryOperator() = testLatex("f(x, y)", "\\operatorname{f}\\left(x,y\\right)") + + @Test + fun binaryPlus() = testLatex("x+x", "x+x") + + @Test + fun binaryMinus() = testLatex("x-x", "x-x") + + @Test + fun fraction() = testLatex("x/x", "\\frac{x}{x}") + + @Test + fun radicalWithIndex() = testLatex(RadicalWithIndexSyntax("", SymbolSyntax("x"), SymbolSyntax("y")), "\\sqrt[x]{y}") + + @Test + fun multiplication() { + testLatex("x*1", "x\\times1") + testLatex("1*x", "1\\,x") + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt new file mode 100644 index 000000000..c9a462840 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt @@ -0,0 +1,84 @@ +package space.kscience.kmath.ast.rendering + +import space.kscience.kmath.ast.MST +import space.kscience.kmath.ast.rendering.TestUtils.testMathML +import space.kscience.kmath.operations.GroupOperations +import kotlin.test.Test + +internal class TestMathML { + @Test + fun number() = testMathML("42", "42") + + @Test + fun symbol() = testMathML("x", "x") + + @Test + fun operatorName() = testMathML( + "sin(1)", + "sin1", + ) + + @Test + fun specialSymbol() = testMathML(MST.Numeric(Double.POSITIVE_INFINITY), "") + + @Test + fun operand() { + testMathML( + "sin(1)", + "sin1", + ) + + testMathML("1+1", "1+1") + } + + @Test + fun unaryOperator() = testMathML( + "sin(1)", + "sin1", + ) + + @Test + fun unaryPlus() = + testMathML(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "+1") + + @Test + fun unaryMinus() = testMathML("-x", "-x") + + @Test + fun radical() = testMathML("sqrt(x)", "x") + + @Test + fun superscript() = testMathML("x^y", "xy") + + @Test + fun subscript() = testMathML( + SubscriptSyntax("", SymbolSyntax("x"), NumberSyntax("123")), + "x123", + ) + + @Test + fun binaryOperator() = testMathML( + "f(x, y)", + "fx,y", + ) + + @Test + fun binaryPlus() = testMathML("x+x", "x+x") + + @Test + fun binaryMinus() = testMathML("x-x", "x-x") + + @Test + fun fraction() = testMathML("x/x", "xx") + + @Test + fun radicalWithIndex() = + testMathML(RadicalWithIndexSyntax("", SymbolSyntax("x"), SymbolSyntax("y")), + "yx") + + @Test + fun multiplication() { + testMathML("x*1", "x×1") + testMathML("1*x", "1x") + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestStages.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestStages.kt new file mode 100644 index 000000000..56a799c2c --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestStages.kt @@ -0,0 +1,28 @@ +package space.kscience.kmath.ast.rendering + +import space.kscience.kmath.ast.rendering.TestUtils.testLatex +import kotlin.test.Test + +internal class TestStages { + @Test + fun betterMultiplication() { + testLatex("a*1", "a\\times1") + testLatex("1*(2/3)", "1\\times\\left(\\frac{2}{3}\\right)") + testLatex("1*1", "1\\times1") + testLatex("2e10", "2\\times10^{10}") + testLatex("2*x", "2\\,x") + testLatex("2*(x+1)", "2\\,\\left(x+1\\right)") + testLatex("x*y", "x\\,y") + } + + @Test + fun parentheses() { + testLatex("(x+1)", "x+1") + testLatex("x*x*x", "x\\,x\\,x") + testLatex("(x+x)*x", "\\left(x+x\\right)\\,x") + testLatex("x+x*x", "x+x\\,x") + testLatex("x+x^x*x+x", "x+x^{x}\\,x+x") + testLatex("(x+x)^x+x*x", "\\left(x+x\\right)^{x}+x\\,x") + testLatex("x^(x+x)", "x^{x+x}") + } +} diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestUtils.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestUtils.kt new file mode 100644 index 000000000..e6359bcc9 --- /dev/null +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestUtils.kt @@ -0,0 +1,41 @@ +package space.kscience.kmath.ast.rendering + +import space.kscience.kmath.ast.MST +import space.kscience.kmath.ast.parseMath +import kotlin.test.assertEquals + +internal object TestUtils { + private fun mathSyntax(mst: MST) = FeaturedMathRendererWithPostProcess.Default.render(mst) + private fun latex(mst: MST) = LatexSyntaxRenderer.renderWithStringBuilder(mathSyntax(mst)) + private fun mathML(mst: MST) = MathMLSyntaxRenderer.renderWithStringBuilder(mathSyntax(mst)) + + internal fun testLatex(mst: MST, expectedLatex: String) = assertEquals( + expected = expectedLatex, + actual = latex(mst), + ) + + internal fun testLatex(expression: String, expectedLatex: String) = assertEquals( + expected = expectedLatex, + actual = latex(expression.parseMath()), + ) + + internal fun testLatex(expression: MathSyntax, expectedLatex: String) = assertEquals( + expected = expectedLatex, + actual = LatexSyntaxRenderer.renderWithStringBuilder(expression), + ) + + internal fun testMathML(mst: MST, expectedMathML: String) = assertEquals( + expected = "$expectedMathML", + actual = mathML(mst), + ) + + internal fun testMathML(expression: String, expectedMathML: String) = assertEquals( + expected = "$expectedMathML", + actual = mathML(expression.parseMath()), + ) + + internal fun testMathML(expression: MathSyntax, expectedMathML: String) = assertEquals( + expected = "$expectedMathML", + actual = MathMLSyntaxRenderer.renderWithStringBuilder(expression), + ) +} diff --git a/kmath-complex/README.md b/kmath-complex/README.md index d7b2937fd..ec5bf289f 100644 --- a/kmath-complex/README.md +++ b/kmath-complex/README.md @@ -8,7 +8,7 @@ Complex and hypercomplex number systems in KMath. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-dev-3`. +The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-dev-4`. **Gradle:** ```gradle @@ -19,7 +19,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-complex:0.3.0-dev-3' + implementation 'space.kscience:kmath-complex:0.3.0-dev-4' } ``` **Gradle Kotlin DSL:** @@ -31,6 +31,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-complex:0.3.0-dev-3") + implementation("space.kscience:kmath-complex:0.3.0-dev-4") } ``` diff --git a/kmath-core/README.md b/kmath-core/README.md index 096c7d833..5e4f1765d 100644 --- a/kmath-core/README.md +++ b/kmath-core/README.md @@ -15,7 +15,7 @@ performance calculations to code generation. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-3`. +The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-4`. **Gradle:** ```gradle @@ -26,7 +26,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-core:0.3.0-dev-3' + implementation 'space.kscience:kmath-core:0.3.0-dev-4' } ``` **Gradle Kotlin DSL:** @@ -38,6 +38,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-core:0.3.0-dev-3") + implementation("space.kscience:kmath-core:0.3.0-dev-4") } ``` diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt index 18fbf0fdd..817bc9f9c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/BigInt.kt @@ -18,7 +18,8 @@ private typealias TBase = ULong /** * Kotlin Multiplatform implementation of Big Integer numbers (KBigInteger). * - * @author Robert Drynkin (https://github.com/robdrynkin) and Peter Klimai (https://github.com/pklimai) + * @author Robert Drynkin + * @author Peter Klimai */ @OptIn(UnstableKMathAPI::class) public object BigIntField : Field, NumbersAddOperations, ScaleOperations { diff --git a/kmath-ejml/README.md b/kmath-ejml/README.md index 2551703a4..1f13a03c5 100644 --- a/kmath-ejml/README.md +++ b/kmath-ejml/README.md @@ -9,7 +9,7 @@ EJML based linear algebra implementation. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-3`. +The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-4`. **Gradle:** ```gradle @@ -20,7 +20,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-ejml:0.3.0-dev-3' + implementation 'space.kscience:kmath-ejml:0.3.0-dev-4' } ``` **Gradle Kotlin DSL:** @@ -32,6 +32,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-ejml:0.3.0-dev-3") + implementation("space.kscience:kmath-ejml:0.3.0-dev-4") } ``` diff --git a/kmath-for-real/README.md b/kmath-for-real/README.md index ad3d33062..f9c6ed3a0 100644 --- a/kmath-for-real/README.md +++ b/kmath-for-real/README.md @@ -9,7 +9,7 @@ Specialization of KMath APIs for Double numbers. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-dev-3`. +The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-dev-4`. **Gradle:** ```gradle @@ -20,7 +20,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-for-real:0.3.0-dev-3' + implementation 'space.kscience:kmath-for-real:0.3.0-dev-4' } ``` **Gradle Kotlin DSL:** @@ -32,6 +32,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-for-real:0.3.0-dev-3") + implementation("space.kscience:kmath-for-real:0.3.0-dev-4") } ``` diff --git a/kmath-functions/README.md b/kmath-functions/README.md index 531e97a44..1e4b06e0f 100644 --- a/kmath-functions/README.md +++ b/kmath-functions/README.md @@ -10,7 +10,7 @@ Functions and interpolations. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-dev-3`. +The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-dev-4`. **Gradle:** ```gradle @@ -21,7 +21,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-functions:0.3.0-dev-3' + implementation 'space.kscience:kmath-functions:0.3.0-dev-4' } ``` **Gradle Kotlin DSL:** @@ -33,6 +33,6 @@ repositories { } dependencies { - implementation("space.kscience:kmath-functions:0.3.0-dev-3") + implementation("space.kscience:kmath-functions:0.3.0-dev-4") } ``` diff --git a/kmath-nd4j/README.md b/kmath-nd4j/README.md index 938d05c33..c8944f1ab 100644 --- a/kmath-nd4j/README.md +++ b/kmath-nd4j/README.md @@ -9,7 +9,7 @@ ND4J based implementations of KMath abstractions. ## Artifact: -The Maven coordinates of this project are `space.kscience:kmath-nd4j:0.3.0-dev-3`. +The Maven coordinates of this project are `space.kscience:kmath-nd4j:0.3.0-dev-4`. **Gradle:** ```gradle @@ -20,7 +20,7 @@ repositories { } dependencies { - implementation 'space.kscience:kmath-nd4j:0.3.0-dev-3' + implementation 'space.kscience:kmath-nd4j:0.3.0-dev-4' } ``` **Gradle Kotlin DSL:** @@ -32,7 +32,7 @@ repositories { } dependencies { - implementation("space.kscience:kmath-nd4j:0.3.0-dev-3") + implementation("space.kscience:kmath-nd4j:0.3.0-dev-4") } ``` From c2bab5d13857ac9d6a5a00f9667d1e21ce8bee95 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 1 Apr 2021 18:18:54 +0300 Subject: [PATCH 2/9] Fix Samplers and distribution API --- .../kmath/commons/fit/fitWithAutoDiff.kt | 2 +- .../kmath/stat/DistributionBenchmark.kt | 6 +- .../kscience/kmath/stat/DistributionDemo.kt | 2 +- .../commons/optimization/OptimizeTest.kt | 2 +- .../kscience/kmath/chains/BlockingChain.kt | 50 ++++ .../kmath/chains/BlockingDoubleChain.kt | 22 +- .../kscience/kmath/chains/BlockingIntChain.kt | 11 +- .../space/kscience/kmath/chains/Chain.kt | 28 +-- .../kscience/kmath/streaming/BufferFlow.kt | 3 +- kmath-stat/build.gradle.kts | 4 + .../kmath/distributions/Distribution.kt | 38 +++ .../FactorizedDistribution.kt | 3 +- .../distributions/NormalDistribution.kt | 15 +- .../kmath/{stat => }/internal/InternalErf.kt | 2 +- .../{stat => }/internal/InternalGamma.kt | 2 +- .../{stat => }/internal/InternalUtils.kt | 2 +- .../AhrensDieterExponentialSampler.kt | 72 ++++++ .../AhrensDieterMarsagliaTsangGammaSampler.kt | 4 +- .../samplers/AliasMethodDiscreteSampler.kt | 220 +++++++++--------- .../kmath/samplers/BoxMullerSampler.kt | 52 +++++ .../kmath/samplers/ConstantSampler.kt | 13 ++ .../kmath/samplers/GaussianSampler.kt | 34 +++ .../samplers/KempSmallMeanPoissonSampler.kt | 68 ++++++ .../MarsagliaNormalizedGaussianSampler.kt | 61 +++++ .../samplers/NormalizedGaussianSampler.kt | 18 ++ .../kscience/kmath/samplers/PoissonSampler.kt | 203 ++++++++++++++++ .../ZigguratNormalizedGaussianSampler.kt | 88 +++++++ .../space/kscience/kmath/stat/RandomChain.kt | 24 +- .../kscience/kmath/stat/RandomGenerator.kt | 6 + .../stat/{Distribution.kt => Sampler.kt} | 46 +--- .../space/kscience/kmath/stat/Statistic.kt | 2 +- .../kmath/stat/UniformDistribution.kt | 2 + .../AhrensDieterExponentialSampler.kt | 73 ------ .../BoxMullerNormalizedGaussianSampler.kt | 48 ---- .../kmath/stat/samplers/GaussianSampler.kt | 43 ---- .../samplers/KempSmallMeanPoissonSampler.kt | 63 ----- .../stat/samplers/LargeMeanPoissonSampler.kt | 130 ----------- .../MarsagliaNormalizedGaussianSampler.kt | 61 ----- .../samplers/NormalizedGaussianSampler.kt | 9 - .../kmath/stat/samplers/PoissonSampler.kt | 30 --- .../stat/samplers/SmallMeanPoissonSampler.kt | 50 ---- .../ZigguratNormalizedGaussianSampler.kt | 88 ------- .../kmath/stat/CommonsDistributionsTest.kt | 19 +- .../kscience/kmath/stat/StatisticTest.kt | 2 +- 44 files changed, 915 insertions(+), 806 deletions(-) create mode 100644 kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingChain.kt create mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/Distribution.kt rename kmath-stat/src/commonMain/kotlin/space/kscience/kmath/{stat => distributions}/FactorizedDistribution.kt (94%) rename kmath-stat/src/commonMain/kotlin/space/kscience/kmath/{stat => }/distributions/NormalDistribution.kt (71%) rename kmath-stat/src/commonMain/kotlin/space/kscience/kmath/{stat => }/internal/InternalErf.kt (90%) rename kmath-stat/src/commonMain/kotlin/space/kscience/kmath/{stat => }/internal/InternalGamma.kt (99%) rename kmath-stat/src/commonMain/kotlin/space/kscience/kmath/{stat => }/internal/InternalUtils.kt (98%) create mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterExponentialSampler.kt rename kmath-stat/src/commonMain/kotlin/space/kscience/kmath/{stat => }/samplers/AhrensDieterMarsagliaTsangGammaSampler.kt (97%) rename kmath-stat/src/commonMain/kotlin/space/kscience/kmath/{stat => }/samplers/AliasMethodDiscreteSampler.kt (58%) create mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/BoxMullerSampler.kt create mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ConstantSampler.kt create mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/GaussianSampler.kt create mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/KempSmallMeanPoissonSampler.kt create mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MarsagliaNormalizedGaussianSampler.kt create mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/NormalizedGaussianSampler.kt create mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/PoissonSampler.kt create mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ZigguratNormalizedGaussianSampler.kt rename kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/{Distribution.kt => Sampler.kt} (54%) delete mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/AhrensDieterExponentialSampler.kt delete mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/BoxMullerNormalizedGaussianSampler.kt delete mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/GaussianSampler.kt delete mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/KempSmallMeanPoissonSampler.kt delete mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/LargeMeanPoissonSampler.kt delete mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/MarsagliaNormalizedGaussianSampler.kt delete mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/NormalizedGaussianSampler.kt delete mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/PoissonSampler.kt delete mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/SmallMeanPoissonSampler.kt delete mode 100644 kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/ZigguratNormalizedGaussianSampler.kt diff --git a/examples/src/main/kotlin/space/kscience/kmath/commons/fit/fitWithAutoDiff.kt b/examples/src/main/kotlin/space/kscience/kmath/commons/fit/fitWithAutoDiff.kt index 04c55b34c..02534ac98 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/commons/fit/fitWithAutoDiff.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/commons/fit/fitWithAutoDiff.kt @@ -7,6 +7,7 @@ import kscience.plotly.models.ScatterMode import kscience.plotly.models.TraceValues import space.kscience.kmath.commons.optimization.chiSquared import space.kscience.kmath.commons.optimization.minimize +import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.misc.symbol import space.kscience.kmath.optimization.FunctionOptimization import space.kscience.kmath.optimization.OptimizationResult @@ -14,7 +15,6 @@ import space.kscience.kmath.real.DoubleVector import space.kscience.kmath.real.map import space.kscience.kmath.real.step import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.distributions.NormalDistribution import space.kscience.kmath.structures.asIterable import space.kscience.kmath.structures.toList import kotlin.math.pow diff --git a/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionBenchmark.kt b/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionBenchmark.kt index bfd138502..5cf96adaa 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionBenchmark.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionBenchmark.kt @@ -3,8 +3,8 @@ package space.kscience.kmath.stat import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.runBlocking -import space.kscience.kmath.stat.samplers.GaussianSampler import org.apache.commons.rng.simple.RandomSource +import space.kscience.kmath.samplers.GaussianSampler import java.time.Duration import java.time.Instant import org.apache.commons.rng.sampling.distribution.GaussianSampler as CMGaussianSampler @@ -12,8 +12,8 @@ import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSa private suspend fun runKMathChained(): Duration { val generator = RandomGenerator.fromSource(RandomSource.MT, 123L) - val normal = GaussianSampler.of(7.0, 2.0) - val chain = normal.sample(generator).blocking() + val normal = GaussianSampler(7.0, 2.0) + val chain = normal.sample(generator) val startTime = Instant.now() var sum = 0.0 diff --git a/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionDemo.kt b/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionDemo.kt index aac7d51d4..1e8542bd8 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionDemo.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/stat/DistributionDemo.kt @@ -3,7 +3,7 @@ package space.kscience.kmath.stat import kotlinx.coroutines.runBlocking import space.kscience.kmath.chains.Chain import space.kscience.kmath.chains.collectWithState -import space.kscience.kmath.stat.distributions.NormalDistribution +import space.kscience.kmath.distributions.NormalDistribution /** * The state of distribution averager. diff --git a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt index 36f2639f4..a51c407c2 100644 --- a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -2,10 +2,10 @@ package space.kscience.kmath.commons.optimization import kotlinx.coroutines.runBlocking import space.kscience.kmath.commons.expressions.DerivativeStructureExpression +import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.misc.symbol import space.kscience.kmath.optimization.FunctionOptimization import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.distributions.NormalDistribution import kotlin.math.pow import kotlin.test.Test diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingChain.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingChain.kt new file mode 100644 index 000000000..429175126 --- /dev/null +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingChain.kt @@ -0,0 +1,50 @@ +package space.kscience.kmath.chains + +import space.kscience.kmath.structures.Buffer + + +public interface BufferChain : Chain { + public suspend fun nextBuffer(size: Int): Buffer + override suspend fun fork(): BufferChain +} + +/** + * A chain with blocking generator that could be used without suspension + */ +public interface BlockingChain : Chain { + /** + * Get the next value without concurrency support. Not guaranteed to be thread safe. + */ + public fun nextBlocking(): T + + override suspend fun next(): T = nextBlocking() + + override suspend fun fork(): BlockingChain +} + + +public interface BlockingBufferChain : BlockingChain, BufferChain { + + public fun nextBufferBlocking(size: Int): Buffer + + public override fun nextBlocking(): T = nextBufferBlocking(1)[0] + + public override suspend fun nextBuffer(size: Int): Buffer = nextBufferBlocking(size) + + override suspend fun fork(): BlockingBufferChain +} + + +public suspend inline fun Chain.nextBuffer(size: Int): Buffer = if (this is BufferChain) { + nextBuffer(size) +} else { + Buffer.auto(size) { next() } +} + +public inline fun BlockingChain.nextBufferBlocking( + size: Int, +): Buffer = if (this is BlockingBufferChain) { + nextBufferBlocking(size) +} else { + Buffer.auto(size) { nextBlocking() } +} \ No newline at end of file diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingDoubleChain.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingDoubleChain.kt index d024147b4..c2153ff6a 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingDoubleChain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingDoubleChain.kt @@ -1,13 +1,27 @@ package space.kscience.kmath.chains +import space.kscience.kmath.structures.DoubleBuffer + /** - * Chunked, specialized chain for real values. + * Chunked, specialized chain for double values, which supports blocking [nextBlocking] operation */ -public interface BlockingDoubleChain : Chain { - public override suspend fun next(): Double +public interface BlockingDoubleChain : BlockingBufferChain { /** * Returns an [DoubleArray] chunk of [size] values of [next]. */ - public suspend fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { next() } + public override fun nextBufferBlocking(size: Int): DoubleBuffer + + override suspend fun fork(): BlockingDoubleChain + + public companion object } + +public fun BlockingDoubleChain.map(transform: (Double) -> Double): BlockingDoubleChain = object : BlockingDoubleChain { + override fun nextBufferBlocking(size: Int): DoubleBuffer { + val block = this@map.nextBufferBlocking(size) + return DoubleBuffer(size) { transform(block[it]) } + } + + override suspend fun fork(): BlockingDoubleChain = this@map.fork().map(transform) +} \ No newline at end of file diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingIntChain.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingIntChain.kt index fb2e453ad..21a498646 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingIntChain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/BlockingIntChain.kt @@ -1,9 +1,12 @@ package space.kscience.kmath.chains +import space.kscience.kmath.structures.IntBuffer + /** * Performance optimized chain for integer values */ -public interface BlockingIntChain : Chain { - public override suspend fun next(): Int - public suspend fun nextBlock(size: Int): IntArray = IntArray(size) { next() } -} +public interface BlockingIntChain : BlockingBufferChain { + override fun nextBufferBlocking(size: Int): IntBuffer + + override suspend fun fork(): BlockingIntChain +} \ No newline at end of file diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/Chain.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/Chain.kt index a961f2e09..adeaea5a7 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/Chain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/chains/Chain.kt @@ -24,20 +24,20 @@ import kotlinx.coroutines.sync.withLock /** * A not-necessary-Markov chain of some type - * @param R - the chain element type + * @param T - the chain element type */ -public interface Chain : Flow { +public interface Chain : Flow { /** * Generate next value, changing state if needed */ - public suspend fun next(): R + public suspend fun next(): T /** * Create a copy of current chain state. Consuming resulting chain does not affect initial chain */ - public fun fork(): Chain + public suspend fun fork(): Chain - override suspend fun collect(collector: FlowCollector): Unit = + override suspend fun collect(collector: FlowCollector): Unit = flow { while (true) emit(next()) }.collect(collector) public companion object @@ -51,7 +51,7 @@ public fun Sequence.asChain(): Chain = iterator().asChain() */ public class SimpleChain(private val gen: suspend () -> R) : Chain { public override suspend fun next(): R = gen() - public override fun fork(): Chain = this + public override suspend fun fork(): Chain = this } /** @@ -69,7 +69,7 @@ public class MarkovChain(private val seed: suspend () -> R, private newValue } - public override fun fork(): Chain = MarkovChain(seed = { value ?: seed() }, gen = gen) + public override suspend fun fork(): Chain = MarkovChain(seed = { value ?: seed() }, gen = gen) } /** @@ -94,7 +94,7 @@ public class StatefulChain( newValue } - public override fun fork(): Chain = StatefulChain(forkState(state), seed, forkState, gen) + public override suspend fun fork(): Chain = StatefulChain(forkState(state), seed, forkState, gen) } /** @@ -102,7 +102,7 @@ public class StatefulChain( */ public class ConstantChain(public val value: T) : Chain { public override suspend fun next(): T = value - public override fun fork(): Chain = this + public override suspend fun fork(): Chain = this } /** @@ -111,7 +111,7 @@ public class ConstantChain(public val value: T) : Chain { */ public fun Chain.map(func: suspend (T) -> R): Chain = object : Chain { override suspend fun next(): R = func(this@map.next()) - override fun fork(): Chain = this@map.fork().map(func) + override suspend fun fork(): Chain = this@map.fork().map(func) } /** @@ -127,7 +127,7 @@ public fun Chain.filter(block: (T) -> Boolean): Chain = object : Chain return next } - override fun fork(): Chain = this@filter.fork().filter(block) + override suspend fun fork(): Chain = this@filter.fork().filter(block) } /** @@ -135,7 +135,7 @@ public fun Chain.filter(block: (T) -> Boolean): Chain = object : Chain */ public fun Chain.collect(mapper: suspend (Chain) -> R): Chain = object : Chain { override suspend fun next(): R = mapper(this@collect) - override fun fork(): Chain = this@collect.fork().collect(mapper) + override suspend fun fork(): Chain = this@collect.fork().collect(mapper) } public fun Chain.collectWithState( @@ -145,7 +145,7 @@ public fun Chain.collectWithState( ): Chain = object : Chain { override suspend fun next(): R = state.mapper(this@collectWithState) - override fun fork(): Chain = + override suspend fun fork(): Chain = this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper) } @@ -154,5 +154,5 @@ public fun Chain.collectWithState( */ public fun Chain.zip(other: Chain, block: suspend (T, U) -> R): Chain = object : Chain { override suspend fun next(): R = block(this@zip.next(), other.next()) - override fun fork(): Chain = this@zip.fork().zip(other.fork(), block) + override suspend fun fork(): Chain = this@zip.fork().zip(other.fork(), block) } diff --git a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/streaming/BufferFlow.kt b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/streaming/BufferFlow.kt index dc1dd4757..655f94cdf 100644 --- a/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/streaming/BufferFlow.kt +++ b/kmath-coroutines/src/commonMain/kotlin/space/kscience/kmath/streaming/BufferFlow.kt @@ -6,7 +6,6 @@ import space.kscience.kmath.chains.BlockingDoubleChain import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.DoubleBuffer -import space.kscience.kmath.structures.asBuffer /** * Create a [Flow] from buffer @@ -50,7 +49,7 @@ public fun Flow.chunked(bufferSize: Int): Flow = flow { if (this@chunked is BlockingDoubleChain) { // performance optimization for blocking primitive chain - while (true) emit(nextBlock(bufferSize).asBuffer()) + while (true) emit(nextBufferBlocking(bufferSize)) } else { val array = DoubleArray(bufferSize) var counter = 0 diff --git a/kmath-stat/build.gradle.kts b/kmath-stat/build.gradle.kts index bc3890b1e..c2ebb7ea1 100644 --- a/kmath-stat/build.gradle.kts +++ b/kmath-stat/build.gradle.kts @@ -2,6 +2,10 @@ plugins { id("ru.mipt.npm.gradle.mpp") } +kscience{ + useAtomic() +} + kotlin.sourceSets { commonMain { dependencies { diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/Distribution.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/Distribution.kt new file mode 100644 index 000000000..fcad8ef99 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/Distribution.kt @@ -0,0 +1,38 @@ +package space.kscience.kmath.distributions + +import space.kscience.kmath.chains.Chain +import space.kscience.kmath.stat.RandomGenerator +import space.kscience.kmath.stat.Sampler + +/** + * A distribution of typed objects. + */ +public interface Distribution : Sampler { + /** + * A probability value for given argument [arg]. + * For continuous distributions returns PDF + */ + public fun probability(arg: T): Double + + public override fun sample(generator: RandomGenerator): Chain + + /** + * An empty companion. Distribution factories should be written as its extensions + */ + public companion object +} + +public interface UnivariateDistribution> : Distribution { + /** + * Cumulative distribution for ordered parameter (CDF) + */ + public fun cumulative(arg: T): Double +} + +/** + * Compute probability integral in an interval + */ +public fun > UnivariateDistribution.integral(from: T, to: T): Double { + require(to > from) + return cumulative(to) - cumulative(from) +} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/FactorizedDistribution.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/FactorizedDistribution.kt similarity index 94% rename from kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/FactorizedDistribution.kt rename to kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/FactorizedDistribution.kt index 3dd506b67..e69086af4 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/FactorizedDistribution.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/FactorizedDistribution.kt @@ -1,7 +1,8 @@ -package space.kscience.kmath.stat +package space.kscience.kmath.distributions import space.kscience.kmath.chains.Chain import space.kscience.kmath.chains.SimpleChain +import space.kscience.kmath.stat.RandomGenerator /** * A multivariate distribution which takes a map of parameters diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/distributions/NormalDistribution.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/NormalDistribution.kt similarity index 71% rename from kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/distributions/NormalDistribution.kt rename to kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/NormalDistribution.kt index 6515cbaa7..15593aed5 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/distributions/NormalDistribution.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/distributions/NormalDistribution.kt @@ -1,12 +1,11 @@ -package space.kscience.kmath.stat.distributions +package space.kscience.kmath.distributions import space.kscience.kmath.chains.Chain +import space.kscience.kmath.internal.InternalErf +import space.kscience.kmath.samplers.GaussianSampler +import space.kscience.kmath.samplers.NormalizedGaussianSampler +import space.kscience.kmath.samplers.ZigguratNormalizedGaussianSampler import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.UnivariateDistribution -import space.kscience.kmath.stat.internal.InternalErf -import space.kscience.kmath.stat.samplers.GaussianSampler -import space.kscience.kmath.stat.samplers.NormalizedGaussianSampler -import space.kscience.kmath.stat.samplers.ZigguratNormalizedGaussianSampler import kotlin.math.* /** @@ -16,8 +15,8 @@ public inline class NormalDistribution(public val sampler: GaussianSampler) : Un public constructor( mean: Double, standardDeviation: Double, - normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler.of(), - ) : this(GaussianSampler.of(mean, standardDeviation, normalized)) + normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler, + ) : this(GaussianSampler(mean, standardDeviation, normalized)) public override fun probability(arg: Double): Double { val x1 = (arg - sampler.mean) / sampler.standardDeviation diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/internal/InternalErf.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalErf.kt similarity index 90% rename from kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/internal/InternalErf.kt rename to kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalErf.kt index 4e1623867..3b9110c1a 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/internal/InternalErf.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalErf.kt @@ -1,4 +1,4 @@ -package space.kscience.kmath.stat.internal +package space.kscience.kmath.internal import kotlin.math.abs diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/internal/InternalGamma.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalGamma.kt similarity index 99% rename from kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/internal/InternalGamma.kt rename to kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalGamma.kt index 4f5adbe97..96f5c66db 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/internal/InternalGamma.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalGamma.kt @@ -1,4 +1,4 @@ -package space.kscience.kmath.stat.internal +package space.kscience.kmath.internal import kotlin.math.* diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/internal/InternalUtils.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalUtils.kt similarity index 98% rename from kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/internal/InternalUtils.kt rename to kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalUtils.kt index 722eee946..832689b27 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/internal/InternalUtils.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/internal/InternalUtils.kt @@ -1,4 +1,4 @@ -package space.kscience.kmath.stat.internal +package space.kscience.kmath.internal import kotlin.math.ln import kotlin.math.min diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterExponentialSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterExponentialSampler.kt new file mode 100644 index 000000000..0b8ecfb31 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterExponentialSampler.kt @@ -0,0 +1,72 @@ +package space.kscience.kmath.samplers + +import space.kscience.kmath.chains.BlockingDoubleChain +import space.kscience.kmath.stat.RandomGenerator +import space.kscience.kmath.stat.Sampler +import space.kscience.kmath.structures.DoubleBuffer +import kotlin.math.ln +import kotlin.math.pow + +/** + * Sampling from an [exponential distribution](http://mathworld.wolfram.com/ExponentialDistribution.html). + * + * Based on Commons RNG implementation. + * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.html]. + */ +public class AhrensDieterExponentialSampler(public val mean: Double) : Sampler { + + init { + require(mean > 0) { "mean is not strictly positive: $mean" } + } + + public override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { + override fun nextBlocking(): Double { + // Step 1: + var a = 0.0 + var u = generator.nextDouble() + + // Step 2 and 3: + while (u < 0.5) { + a += EXPONENTIAL_SA_QI[0] + u *= 2.0 + } + + // Step 4 (now u >= 0.5): + u += u - 1 + // Step 5: + if (u <= EXPONENTIAL_SA_QI[0]) return mean * (a + u) + // Step 6: + var i = 0 // Should be 1, be we iterate before it in while using 0. + var u2 = generator.nextDouble() + var umin = u2 + + // Step 7 and 8: + do { + ++i + u2 = generator.nextDouble() + if (u2 < umin) umin = u2 + // Step 8: + } while (u > EXPONENTIAL_SA_QI[i]) // Ensured to exit since EXPONENTIAL_SA_QI[MAX] = 1. + + return mean * (a + umin * EXPONENTIAL_SA_QI[0]) + } + + override fun nextBufferBlocking(size: Int): DoubleBuffer = DoubleBuffer(size) { nextBlocking() } + + override suspend fun fork(): BlockingDoubleChain = sample(generator.fork()) + } + + public companion object { + private val EXPONENTIAL_SA_QI by lazy { + val ln2 = ln(2.0) + var qi = 0.0 + + DoubleArray(16) { i -> + qi += ln2.pow(i + 1.0) / space.kscience.kmath.internal.InternalUtils.factorial(i + 1) + qi + } + } + } + +} + diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/AhrensDieterMarsagliaTsangGammaSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterMarsagliaTsangGammaSampler.kt similarity index 97% rename from kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/AhrensDieterMarsagliaTsangGammaSampler.kt rename to kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterMarsagliaTsangGammaSampler.kt index 81182f6cd..c8a49106b 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/AhrensDieterMarsagliaTsangGammaSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AhrensDieterMarsagliaTsangGammaSampler.kt @@ -1,4 +1,4 @@ -package space.kscience.kmath.stat.samplers +package space.kscience.kmath.samplers import space.kscience.kmath.chains.Chain import space.kscience.kmath.stat.RandomGenerator @@ -80,7 +80,7 @@ public class AhrensDieterMarsagliaTsangGammaSampler private constructor( private val gaussian: NormalizedGaussianSampler init { - gaussian = ZigguratNormalizedGaussianSampler.of() + gaussian = ZigguratNormalizedGaussianSampler dOptim = alpha - ONE_THIRD cOptim = ONE_THIRD / sqrt(dOptim) } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/AliasMethodDiscreteSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AliasMethodDiscreteSampler.kt similarity index 58% rename from kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/AliasMethodDiscreteSampler.kt rename to kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AliasMethodDiscreteSampler.kt index cae97db65..fe670a4e4 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/AliasMethodDiscreteSampler.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/AliasMethodDiscreteSampler.kt @@ -1,10 +1,10 @@ -package space.kscience.kmath.stat.samplers +package space.kscience.kmath.samplers import space.kscience.kmath.chains.Chain +import space.kscience.kmath.internal.InternalUtils import space.kscience.kmath.stat.RandomGenerator import space.kscience.kmath.stat.Sampler import space.kscience.kmath.stat.chain -import space.kscience.kmath.stat.internal.InternalUtils import kotlin.math.ceil import kotlin.math.max import kotlin.math.min @@ -39,12 +39,12 @@ import kotlin.math.min public open class AliasMethodDiscreteSampler private constructor( // Deliberate direct storage of input arrays protected val probability: LongArray, - protected val alias: IntArray + protected val alias: IntArray, ) : Sampler { private class SmallTableAliasMethodDiscreteSampler( probability: LongArray, - alias: IntArray + alias: IntArray, ) : AliasMethodDiscreteSampler(probability, alias) { // Assume the table size is a power of 2 and create the mask private val mask: Int = alias.size - 1 @@ -111,110 +111,6 @@ public open class AliasMethodDiscreteSampler private constructor( private const val CONVERT_TO_NUMERATOR: Double = ONE_AS_NUMERATOR.toDouble() private const val MAX_SMALL_POWER_2_SIZE = 1 shl 11 - public fun of( - probabilities: DoubleArray, - alpha: Int = DEFAULT_ALPHA - ): Sampler { - // The Alias method balances N categories with counts around the mean into N sections, - // each allocated 'mean' observations. - // - // Consider 4 categories with counts 6,3,2,1. The histogram can be balanced into a - // 2D array as 4 sections with a height of the mean: - // - // 6 - // 6 - // 6 - // 63 => 6366 -- - // 632 6326 |-- mean - // 6321 6321 -- - // - // section abcd - // - // Each section is divided as: - // a: 6=1/1 - // b: 3=1/1 - // c: 2=2/3; 6=1/3 (6 is the alias) - // d: 1=1/3; 6=2/3 (6 is the alias) - // - // The sample is obtained by randomly selecting a section, then choosing which category - // from the pair based on a uniform random deviate. - val sumProb = InternalUtils.validateProbabilities(probabilities) - // Allow zero-padding - val n = computeSize(probabilities.size, alpha) - // Partition into small and large by splitting on the average. - val mean = sumProb / n - // The cardinality of smallSize + largeSize = n. - // So fill the same array from either end. - val indices = IntArray(n) - var large = n - var small = 0 - - probabilities.indices.forEach { i -> - if (probabilities[i] >= mean) indices[--large] = i else indices[small++] = i - } - - small = fillRemainingIndices(probabilities.size, indices, small) - // This may be smaller than the input length if the probabilities were already padded. - val nonZeroIndex = findLastNonZeroIndex(probabilities) - // The probabilities are modified so use a copy. - // Note: probabilities are required only up to last nonZeroIndex - val remainingProbabilities = probabilities.copyOf(nonZeroIndex + 1) - // Allocate the final tables. - // Probability table may be truncated (when zero padded). - // The alias table is full length. - val probability = LongArray(remainingProbabilities.size) - val alias = IntArray(n) - - // This loop uses each large in turn to fill the alias table for small probabilities that - // do not reach the requirement to fill an entire section alone (i.e. p < mean). - // Since the sum of the small should be less than the sum of the large it should use up - // all the small first. However floating point round-off can result in - // misclassification of items as small or large. The Vose algorithm handles this using - // a while loop conditioned on the size of both sets and a subsequent loop to use - // unpaired items. - while (large != n && small != 0) { - // Index of the small and the large probabilities. - val j = indices[--small] - val k = indices[large++] - - // Optimisation for zero-padded input: - // p(j) = 0 above the last nonZeroIndex - if (j > nonZeroIndex) - // The entire amount for the section is taken from the alias. - remainingProbabilities[k] -= mean - else { - val pj = remainingProbabilities[j] - // Item j is a small probability that is below the mean. - // Compute the weight of the section for item j: pj / mean. - // This is scaled by 2^53 and the ceiling function used to round-up - // the probability to a numerator of a fraction in the range [1,2^53]. - // Ceiling ensures non-zero values. - probability[j] = ceil(CONVERT_TO_NUMERATOR * (pj / mean)).toLong() - // The remaining amount for the section is taken from the alias. - // Effectively: probabilities[k] -= (mean - pj) - remainingProbabilities[k] += pj - mean - } - - // If not j then the alias is k - alias[j] = k - - // Add the remaining probability from large to the appropriate list. - if (remainingProbabilities[k] >= mean) indices[--large] = k else indices[small++] = k - } - - // Final loop conditions to consume unpaired items. - // Note: The large set should never be non-empty but this can occur due to round-off - // error so consume from both. - fillTable(probability, alias, indices, 0, small) - fillTable(probability, alias, indices, large, n) - - // Change the algorithm for small power of 2 sized tables - return if (isSmallPowerOf2(n)) - SmallTableAliasMethodDiscreteSampler(probability, alias) - else - AliasMethodDiscreteSampler(probability, alias) - } - private fun fillRemainingIndices(length: Int, indices: IntArray, small: Int): Int { var updatedSmall = small (length until indices.size).forEach { i -> indices[updatedSmall++] = i } @@ -246,7 +142,7 @@ public open class AliasMethodDiscreteSampler private constructor( alias: IntArray, indices: IntArray, start: Int, - end: Int + end: Int, ) = (start until end).forEach { i -> val index = indices[i] probability[index] = ONE_AS_NUMERATOR @@ -283,4 +179,110 @@ public open class AliasMethodDiscreteSampler private constructor( return n - (mutI ushr 1) } } + + @Suppress("FunctionName") + public fun AliasMethodDiscreteSampler( + probabilities: DoubleArray, + alpha: Int = DEFAULT_ALPHA, + ): Sampler { + // The Alias method balances N categories with counts around the mean into N sections, + // each allocated 'mean' observations. + // + // Consider 4 categories with counts 6,3,2,1. The histogram can be balanced into a + // 2D array as 4 sections with a height of the mean: + // + // 6 + // 6 + // 6 + // 63 => 6366 -- + // 632 6326 |-- mean + // 6321 6321 -- + // + // section abcd + // + // Each section is divided as: + // a: 6=1/1 + // b: 3=1/1 + // c: 2=2/3; 6=1/3 (6 is the alias) + // d: 1=1/3; 6=2/3 (6 is the alias) + // + // The sample is obtained by randomly selecting a section, then choosing which category + // from the pair based on a uniform random deviate. + val sumProb = InternalUtils.validateProbabilities(probabilities) + // Allow zero-padding + val n = computeSize(probabilities.size, alpha) + // Partition into small and large by splitting on the average. + val mean = sumProb / n + // The cardinality of smallSize + largeSize = n. + // So fill the same array from either end. + val indices = IntArray(n) + var large = n + var small = 0 + + probabilities.indices.forEach { i -> + if (probabilities[i] >= mean) indices[--large] = i else indices[small++] = i + } + + small = fillRemainingIndices(probabilities.size, indices, small) + // This may be smaller than the input length if the probabilities were already padded. + val nonZeroIndex = findLastNonZeroIndex(probabilities) + // The probabilities are modified so use a copy. + // Note: probabilities are required only up to last nonZeroIndex + val remainingProbabilities = probabilities.copyOf(nonZeroIndex + 1) + // Allocate the final tables. + // Probability table may be truncated (when zero padded). + // The alias table is full length. + val probability = LongArray(remainingProbabilities.size) + val alias = IntArray(n) + + // This loop uses each large in turn to fill the alias table for small probabilities that + // do not reach the requirement to fill an entire section alone (i.e. p < mean). + // Since the sum of the small should be less than the sum of the large it should use up + // all the small first. However floating point round-off can result in + // misclassification of items as small or large. The Vose algorithm handles this using + // a while loop conditioned on the size of both sets and a subsequent loop to use + // unpaired items. + while (large != n && small != 0) { + // Index of the small and the large probabilities. + val j = indices[--small] + val k = indices[large++] + + // Optimisation for zero-padded input: + // p(j) = 0 above the last nonZeroIndex + if (j > nonZeroIndex) + // The entire amount for the section is taken from the alias. + remainingProbabilities[k] -= mean + else { + val pj = remainingProbabilities[j] + // Item j is a small probability that is below the mean. + // Compute the weight of the section for item j: pj / mean. + // This is scaled by 2^53 and the ceiling function used to round-up + // the probability to a numerator of a fraction in the range [1,2^53]. + // Ceiling ensures non-zero values. + probability[j] = ceil(CONVERT_TO_NUMERATOR * (pj / mean)).toLong() + // The remaining amount for the section is taken from the alias. + // Effectively: probabilities[k] -= (mean - pj) + remainingProbabilities[k] += pj - mean + } + + // If not j then the alias is k + alias[j] = k + + // Add the remaining probability from large to the appropriate list. + if (remainingProbabilities[k] >= mean) indices[--large] = k else indices[small++] = k + } + + // Final loop conditions to consume unpaired items. + // Note: The large set should never be non-empty but this can occur due to round-off + // error so consume from both. + fillTable(probability, alias, indices, 0, small) + fillTable(probability, alias, indices, large, n) + + // Change the algorithm for small power of 2 sized tables + return if (isSmallPowerOf2(n)) { + SmallTableAliasMethodDiscreteSampler(probability, alias) + } else { + AliasMethodDiscreteSampler(probability, alias) + } + } } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/BoxMullerSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/BoxMullerSampler.kt new file mode 100644 index 000000000..1f1871cbb --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/BoxMullerSampler.kt @@ -0,0 +1,52 @@ +package space.kscience.kmath.samplers + +import space.kscience.kmath.chains.BlockingDoubleChain +import space.kscience.kmath.stat.RandomGenerator +import space.kscience.kmath.structures.DoubleBuffer +import kotlin.math.* + +/** + * [Box-Muller algorithm](https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform) for sampling from a Gaussian + * distribution. + * + * Based on Commons RNG implementation. + * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.html]. + */ + +public object BoxMullerSampler : NormalizedGaussianSampler { + override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { + var state = Double.NaN + + override fun nextBufferBlocking(size: Int): DoubleBuffer { + val xs = generator.nextDoubleBuffer(size) + val ys = generator.nextDoubleBuffer(size) + + return DoubleBuffer(size) { index -> + if (state.isNaN()) { + // Generate a pair of Gaussian numbers. + val x = xs[index] + val y = ys[index] + val alpha = 2 * PI * x + val r = sqrt(-2 * ln(y)) + + // Keep second element of the pair for next invocation. + state = r * sin(alpha) + + // Return the first element of the generated pair. + r * cos(alpha) + } else { + // Use the second element of the pair (generated at the + // previous invocation). + state.also { + // Both elements of the pair have been used. + state = Double.NaN + } + } + } + } + + + override suspend fun fork(): BlockingDoubleChain = sample(generator.fork()) + } + +} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ConstantSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ConstantSampler.kt new file mode 100644 index 000000000..0f8d13305 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ConstantSampler.kt @@ -0,0 +1,13 @@ +package space.kscience.kmath.samplers + +import space.kscience.kmath.chains.BlockingBufferChain +import space.kscience.kmath.stat.RandomGenerator +import space.kscience.kmath.stat.Sampler +import space.kscience.kmath.structures.Buffer + +public class ConstantSampler(public val const: T) : Sampler { + override fun sample(generator: RandomGenerator): BlockingBufferChain = object : BlockingBufferChain { + override fun nextBufferBlocking(size: Int): Buffer = Buffer.boxing(size) { const } + override suspend fun fork(): BlockingBufferChain = this + } +} \ No newline at end of file diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/GaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/GaussianSampler.kt new file mode 100644 index 000000000..26047830c --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/GaussianSampler.kt @@ -0,0 +1,34 @@ +package space.kscience.kmath.samplers + +import space.kscience.kmath.chains.BlockingDoubleChain +import space.kscience.kmath.chains.map +import space.kscience.kmath.stat.RandomGenerator +import space.kscience.kmath.stat.Sampler + +/** + * Sampling from a Gaussian distribution with given mean and standard deviation. + * + * Based on Commons RNG implementation. + * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/GaussianSampler.html]. + * + * @property mean the mean of the distribution. + * @property standardDeviation the variance of the distribution. + */ +public class GaussianSampler( + public val mean: Double, + public val standardDeviation: Double, + private val normalized: NormalizedGaussianSampler = BoxMullerSampler +) : Sampler { + + init { + require(standardDeviation > 0.0) { "standard deviation is not strictly positive: $standardDeviation" } + } + + public override fun sample(generator: RandomGenerator): BlockingDoubleChain = normalized + .sample(generator) + .map { standardDeviation * it + mean } + + override fun toString(): String = "N($mean, $standardDeviation)" + + public companion object +} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/KempSmallMeanPoissonSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/KempSmallMeanPoissonSampler.kt new file mode 100644 index 000000000..0f8e6b089 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/KempSmallMeanPoissonSampler.kt @@ -0,0 +1,68 @@ +package space.kscience.kmath.samplers + +import space.kscience.kmath.chains.BlockingIntChain +import space.kscience.kmath.stat.RandomGenerator +import space.kscience.kmath.stat.Sampler +import space.kscience.kmath.structures.IntBuffer +import kotlin.math.exp + +/** + * Sampler for the Poisson distribution. + * - Kemp, A, W, (1981) Efficient Generation of Logarithmically Distributed Pseudo-Random Variables. Journal of the Royal Statistical Society. Vol. 30, No. 3, pp. 249-253. + * This sampler is suitable for mean < 40. For large means, LargeMeanPoissonSampler should be used instead. + * + * Note: The algorithm uses a recurrence relation to compute the Poisson probability and a rolling summation for the cumulative probability. When the mean is large the initial probability (Math.exp(-mean)) is zero and an exception is raised by the constructor. + * + * Sampling uses 1 call to UniformRandomProvider.nextDouble(). This method provides an alternative to the SmallMeanPoissonSampler for slow generators of double. + * + * Based on Commons RNG implementation. + * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.html]. + */ +public class KempSmallMeanPoissonSampler internal constructor( + private val p0: Double, + private val mean: Double, +) : Sampler { + public override fun sample(generator: RandomGenerator): BlockingIntChain = object : BlockingIntChain { + override fun nextBlocking(): Int { + //TODO move to nextBufferBlocking + // Note on the algorithm: + // - X is the unknown sample deviate (the output of the algorithm) + // - x is the current value from the distribution + // - p is the probability of the current value x, p(X=x) + // - u is effectively the cumulative probability that the sample X + // is equal or above the current value x, p(X>=x) + // So if p(X>=x) > p(X=x) the sample must be above x, otherwise it is x + var u = generator.nextDouble() + var x = 0 + var p = p0 + + while (u > p) { + u -= p + // Compute the next probability using a recurrence relation. + // p(x+1) = p(x) * mean / (x+1) + p *= mean / ++x + // The algorithm listed in Kemp (1981) does not check that the rolling probability + // is positive. This check is added to ensure no errors when the limit of the summation + // 1 - sum(p(x)) is above 0 due to cumulative error in floating point arithmetic. + if (p == 0.0) return x + } + + return x + } + + override fun nextBufferBlocking(size: Int): IntBuffer = IntBuffer(size) { nextBlocking() } + + override suspend fun fork(): BlockingIntChain = sample(generator.fork()) + } + + public override fun toString(): String = "Kemp Small Mean Poisson deviate" +} + +public fun KempSmallMeanPoissonSampler(mean: Double): KempSmallMeanPoissonSampler { + require(mean > 0) { "Mean is not strictly positive: $mean" } + val p0 = exp(-mean) + // Probability must be positive. As mean increases then p(0) decreases. + require(p0 > 0) { "No probability for mean: $mean" } + return KempSmallMeanPoissonSampler(p0, mean) +} + diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MarsagliaNormalizedGaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MarsagliaNormalizedGaussianSampler.kt new file mode 100644 index 000000000..b93bcc106 --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/MarsagliaNormalizedGaussianSampler.kt @@ -0,0 +1,61 @@ +package space.kscience.kmath.samplers + +import space.kscience.kmath.chains.BlockingDoubleChain +import space.kscience.kmath.stat.RandomGenerator +import space.kscience.kmath.structures.DoubleBuffer +import kotlin.math.ln +import kotlin.math.sqrt + +/** + * [Marsaglia polar method](https://en.wikipedia.org/wiki/Marsaglia_polar_method) for sampling from a Gaussian + * distribution with mean 0 and standard deviation 1. This is a variation of the algorithm implemented in + * [BoxMullerNormalizedGaussianSampler]. + * + * Based on Commons RNG implementation. + * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.html] + */ +public object MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler { + + override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { + var nextGaussian = Double.NaN + + override fun nextBlocking(): Double { + return if (nextGaussian.isNaN()) { + val alpha: Double + var x: Double + + // Rejection scheme for selecting a pair that lies within the unit circle. + while (true) { + // Generate a pair of numbers within [-1 , 1). + x = 2.0 * generator.nextDouble() - 1.0 + val y = 2.0 * generator.nextDouble() - 1.0 + val r2 = x * x + y * y + + if (r2 < 1 && r2 > 0) { + // Pair (x, y) is within unit circle. + alpha = sqrt(-2 * ln(r2) / r2) + // Keep second element of the pair for next invocation. + nextGaussian = alpha * y + // Return the first element of the generated pair. + break + } + // Pair is not within the unit circle: Generate another one. + } + + // Return the first element of the generated pair. + alpha * x + } else { + // Use the second element of the pair (generated at the + // previous invocation). + val r = nextGaussian + // Both elements of the pair have been used. + nextGaussian = Double.NaN + r + } + } + + override fun nextBufferBlocking(size: Int): DoubleBuffer = DoubleBuffer(size) { nextBlocking() } + + override suspend fun fork(): BlockingDoubleChain = sample(generator.fork()) + } +} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/NormalizedGaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/NormalizedGaussianSampler.kt new file mode 100644 index 000000000..6d3daadab --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/NormalizedGaussianSampler.kt @@ -0,0 +1,18 @@ +package space.kscience.kmath.samplers + +import space.kscience.kmath.chains.BlockingDoubleChain +import space.kscience.kmath.stat.RandomGenerator +import space.kscience.kmath.stat.Sampler + +public interface BlockingDoubleSampler: Sampler{ + override fun sample(generator: RandomGenerator): BlockingDoubleChain +} + + +/** + * Marker interface for a sampler that generates values from an N(0,1) + * [Gaussian distribution](https://en.wikipedia.org/wiki/Normal_distribution). + */ +public fun interface NormalizedGaussianSampler : BlockingDoubleSampler{ + public companion object +} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/PoissonSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/PoissonSampler.kt new file mode 100644 index 000000000..c2e8e2c1c --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/PoissonSampler.kt @@ -0,0 +1,203 @@ +package space.kscience.kmath.samplers + +import space.kscience.kmath.chains.BlockingIntChain +import space.kscience.kmath.internal.InternalUtils +import space.kscience.kmath.stat.RandomGenerator +import space.kscience.kmath.stat.Sampler +import space.kscience.kmath.structures.IntBuffer +import kotlin.math.* + + +private const val PIVOT = 40.0 + +/** + * Sampler for the Poisson distribution. + * - For small means, a Poisson process is simulated using uniform deviates, as described in + * Knuth (1969). Seminumerical Algorithms. The Art of Computer Programming, Volume 2. Chapter 3.4.1.F.3 + * Important integer-valued distributions: The Poisson distribution. Addison Wesley. + * The Poisson process (and hence, the returned value) is bounded by 1000 * mean. + * - For large means, we use the rejection algorithm described in + * Devroye, Luc. (1981). The Computer Generation of Poisson Random Variables Computing vol. 26 pp. 197-207. + * + * Based on Commons RNG implementation. + * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/PoissonSampler.html]. + */ +@Suppress("FunctionName") +public fun PoissonSampler(mean: Double): Sampler { + return if (mean < PIVOT) SmallMeanPoissonSampler(mean) else LargeMeanPoissonSampler(mean) +} + +/** + * Sampler for the Poisson distribution. + * - For small means, a Poisson process is simulated using uniform deviates, as described in + * Knuth (1969). Seminumerical Algorithms. The Art of Computer Programming, Volume 2. Chapter 3.4.1.F.3 Important + * integer-valued distributions: The Poisson distribution. Addison Wesley. + * - The Poisson process (and hence, the returned value) is bounded by 1000 * mean. + * This sampler is suitable for mean < 40. For large means, [LargeMeanPoissonSampler] should be used instead. + * + * Based on Commons RNG implementation. + * + * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.html]. + */ +public class SmallMeanPoissonSampler(public val mean: Double) : Sampler { + + init { + require(mean > 0) { "mean is not strictly positive: $mean" } + } + + private val p0: Double = exp(-mean) + + private val limit: Int = if (p0 > 0) { + ceil(1000 * mean) + } else { + throw IllegalArgumentException("No p(x=0) probability for mean: $mean") + }.toInt() + + public override fun sample(generator: RandomGenerator): BlockingIntChain = object : BlockingIntChain { + override fun nextBlocking(): Int { + var n = 0 + var r = 1.0 + + while (n < limit) { + r *= generator.nextDouble() + if (r >= p0) n++ else break + } + + return n + } + + override fun nextBufferBlocking(size: Int): IntBuffer = IntBuffer(size) { nextBlocking() } + + override suspend fun fork(): BlockingIntChain = sample(generator.fork()) + } + + public override fun toString(): String = "Small Mean Poisson deviate" +} + + +/** + * Sampler for the Poisson distribution. + * - For large means, we use the rejection algorithm described in + * Devroye, Luc. (1981).The Computer Generation of Poisson Random Variables + * Computing vol. 26 pp. 197-207. + * + * This sampler is suitable for mean >= 40. + * + * Based on Commons RNG implementation. + * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.html]. + */ +public class LargeMeanPoissonSampler(public val mean: Double) : Sampler { + + init { + require(mean >= 1) { "mean is not >= 1: $mean" } + // The algorithm is not valid if Math.floor(mean) is not an integer. + require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" } + } + + private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG + private val lambda: Double = floor(mean) + private val logLambda: Double = ln(lambda) + private val logLambdaFactorial: Double = getFactorialLog(lambda.toInt()) + private val delta: Double = sqrt(lambda * ln(32 * lambda / PI + 1)) + private val halfDelta: Double = delta / 2 + private val twolpd: Double = 2 * lambda + delta + private val c1: Double = 1 / (8 * lambda) + private val a1: Double = sqrt(PI * twolpd) * exp(c1) + private val a2: Double = twolpd / delta * exp(-delta * (1 + delta) / twolpd) + private val aSum: Double = a1 + a2 + 1 + private val p1: Double = a1 / aSum + private val p2: Double = a2 / aSum + + public override fun sample(generator: RandomGenerator): BlockingIntChain = object : BlockingIntChain { + override fun nextBlocking(): Int { + val exponential = AhrensDieterExponentialSampler(1.0).sample(generator) + val gaussian = ZigguratNormalizedGaussianSampler.sample(generator) + + val smallMeanPoissonSampler = if (mean - lambda < Double.MIN_VALUE) { + null + } else { + KempSmallMeanPoissonSampler(mean - lambda).sample(generator) + } + + val y2 = smallMeanPoissonSampler?.nextBlocking() ?: 0 + var x: Double + var y: Double + var v: Double + var a: Int + var t: Double + var qr: Double + var qa: Double + + while (true) { + // Step 1: + val u = generator.nextDouble() + + if (u <= p1) { + // Step 2: + val n = gaussian.nextBlocking() + x = n * sqrt(lambda + halfDelta) - 0.5 + if (x > delta || x < -lambda) continue + y = if (x < 0) floor(x) else ceil(x) + val e = exponential.nextBlocking() + v = -e - 0.5 * n * n + c1 + } else { + // Step 3: + if (u > p1 + p2) { + y = lambda + break + } + + x = delta + twolpd / delta * exponential.nextBlocking() + y = ceil(x) + v = -exponential.nextBlocking() - delta * (x + 1) / twolpd + } + + // The Squeeze Principle + // Step 4.1: + a = if (x < 0) 1 else 0 + t = y * (y + 1) / (2 * lambda) + + // Step 4.2 + if (v < -t && a == 0) { + y += lambda + break + } + + // Step 4.3: + qr = t * ((2 * y + 1) / (6 * lambda) - 1) + qa = qr - t * t / (3 * (lambda + a * (y + 1))) + + // Step 4.4: + if (v < qa) { + y += lambda + break + } + + // Step 4.5: + if (v > qr) continue + + // Step 4.6: + if (v < y * logLambda - getFactorialLog((y + lambda).toInt()) + logLambdaFactorial) { + y += lambda + break + } + } + + return min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toInt() + } + + override fun nextBufferBlocking(size: Int): IntBuffer = IntBuffer(size) { nextBlocking() } + + override suspend fun fork(): BlockingIntChain = sample(generator.fork()) + } + + private fun getFactorialLog(n: Int): Double = factorialLog.value(n) + public override fun toString(): String = "Large Mean Poisson deviate" + + public companion object { + private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE + private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create() + } +} + + diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ZigguratNormalizedGaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ZigguratNormalizedGaussianSampler.kt new file mode 100644 index 000000000..70f5c248d --- /dev/null +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/samplers/ZigguratNormalizedGaussianSampler.kt @@ -0,0 +1,88 @@ +package space.kscience.kmath.samplers + +import space.kscience.kmath.chains.BlockingDoubleChain +import space.kscience.kmath.stat.RandomGenerator +import space.kscience.kmath.structures.DoubleBuffer +import kotlin.math.* + +/** + * [Marsaglia and Tsang "Ziggurat"](https://en.wikipedia.org/wiki/Ziggurat_algorithm) method for sampling from a + * Gaussian distribution with mean 0 and standard deviation 1. The algorithm is explained in this paper and this + * implementation has been adapted from the C code provided therein. + * + * Based on Commons RNG implementation. + * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.html]. + */ +public object ZigguratNormalizedGaussianSampler : NormalizedGaussianSampler { + + private const val R: Double = 3.442619855899 + private const val ONE_OVER_R: Double = 1 / R + private const val V: Double = 9.91256303526217e-3 + private val MAX: Double = 2.0.pow(63.0) + private val ONE_OVER_MAX: Double = 1.0 / MAX + private const val LEN: Int = 128 + private const val LAST: Int = LEN - 1 + private val K: LongArray = LongArray(LEN) + private val W: DoubleArray = DoubleArray(LEN) + private val F: DoubleArray = DoubleArray(LEN) + + init { + // Filling the tables. + var d = R + var t = d + var fd = gauss(d) + val q = V / fd + K[0] = (d / q * MAX).toLong() + K[1] = 0 + W[0] = q * ONE_OVER_MAX + W[LAST] = d * ONE_OVER_MAX + F[0] = 1.0 + F[LAST] = fd + + (LAST - 1 downTo 1).forEach { i -> + d = sqrt(-2 * ln(V / d + fd)) + fd = gauss(d) + K[i + 1] = (d / t * MAX).toLong() + t = d + F[i] = fd + W[i] = d * ONE_OVER_MAX + } + } + + private fun gauss(x: Double): Double = exp(-0.5 * x * x) + + private fun sampleOne(generator: RandomGenerator): Double { + val j = generator.nextLong() + val i = (j and LAST.toLong()).toInt() + return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i) + } + + override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain { + override fun nextBufferBlocking(size: Int): DoubleBuffer = DoubleBuffer(size) { sampleOne(generator) } + + override suspend fun fork(): BlockingDoubleChain = sample(generator.fork()) + } + + + private fun fix(generator: RandomGenerator, hz: Long, iz: Int): Double { + var x = hz * W[iz] + + return when { + iz == 0 -> { + var y: Double + + do { + y = -ln(generator.nextDouble()) + x = -ln(generator.nextDouble()) * ONE_OVER_R + } while (y + y < x * x) + + val out = R + x + if (hz > 0) out else -out + } + + F[iz] + generator.nextDouble() * (F[iz - 1] - F[iz]) < gauss(x) -> x + else -> sampleOne(generator) + } + } + +} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomChain.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomChain.kt index 2f117a035..9e3e265dc 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomChain.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomChain.kt @@ -1,8 +1,8 @@ package space.kscience.kmath.stat import space.kscience.kmath.chains.BlockingDoubleChain -import space.kscience.kmath.chains.BlockingIntChain import space.kscience.kmath.chains.Chain +import space.kscience.kmath.structures.DoubleBuffer /** * A possibly stateful chain producing random values. @@ -11,12 +11,24 @@ import space.kscience.kmath.chains.Chain */ public class RandomChain( public val generator: RandomGenerator, - private val gen: suspend RandomGenerator.() -> R + private val gen: suspend RandomGenerator.() -> R, ) : Chain { override suspend fun next(): R = generator.gen() - override fun fork(): Chain = RandomChain(generator.fork(), gen) + override suspend fun fork(): Chain = RandomChain(generator.fork(), gen) +} + +/** + * Create a generic random chain with provided [generator] + */ +public fun RandomGenerator.chain(generator: suspend RandomGenerator.() -> R): RandomChain = RandomChain(this, generator) + +/** + * A type-specific double chunk random chain + */ +public class UniformDoubleChain(public val generator: RandomGenerator) : BlockingDoubleChain { + public override fun nextBufferBlocking(size: Int): DoubleBuffer = generator.nextDoubleBuffer(size) + override suspend fun nextBuffer(size: Int): DoubleBuffer = nextBufferBlocking(size) + + override suspend fun fork(): UniformDoubleChain = UniformDoubleChain(generator.fork()) } -public fun RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain = RandomChain(this, gen) -public fun Chain.blocking(): BlockingDoubleChain = object : Chain by this, BlockingDoubleChain {} -public fun Chain.blocking(): BlockingIntChain = object : Chain by this, BlockingIntChain {} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomGenerator.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomGenerator.kt index bad2334e9..c40513efc 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomGenerator.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/RandomGenerator.kt @@ -1,5 +1,6 @@ package space.kscience.kmath.stat +import space.kscience.kmath.structures.DoubleBuffer import kotlin.random.Random /** @@ -16,6 +17,11 @@ public interface RandomGenerator { */ public fun nextDouble(): Double + /** + * A chunk of doubles of given [size] + */ + public fun nextDoubleBuffer(size: Int): DoubleBuffer = DoubleBuffer(size) { nextDouble() } + /** * Gets the next random `Int` from the random number generator. * diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Distribution.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Sampler.kt similarity index 54% rename from kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Distribution.kt rename to kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Sampler.kt index 095182160..8d024b2b9 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Distribution.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Sampler.kt @@ -3,16 +3,13 @@ package space.kscience.kmath.stat import kotlinx.coroutines.flow.first import space.kscience.kmath.chains.Chain import space.kscience.kmath.chains.collect -import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.BufferFactory -import space.kscience.kmath.structures.IntBuffer -import space.kscience.kmath.structures.MutableBuffer +import space.kscience.kmath.structures.* import kotlin.jvm.JvmName /** - * Sampler that generates chains of values of type [T]. + * Sampler that generates chains of values of type [T] in a chain of type [C]. */ -public fun interface Sampler { +public fun interface Sampler { /** * Generates a chain of samples. * @@ -22,39 +19,6 @@ public fun interface Sampler { public fun sample(generator: RandomGenerator): Chain } -/** - * A distribution of typed objects. - */ -public interface Distribution : Sampler { - /** - * A probability value for given argument [arg]. - * For continuous distributions returns PDF - */ - public fun probability(arg: T): Double - - public override fun sample(generator: RandomGenerator): Chain - - /** - * An empty companion. Distribution factories should be written as its extensions - */ - public companion object -} - -public interface UnivariateDistribution> : Distribution { - /** - * Cumulative distribution for ordered parameter (CDF) - */ - public fun cumulative(arg: T): Double -} - -/** - * Compute probability integral in an interval - */ -public fun > UnivariateDistribution.integral(from: T, to: T): Double { - require(to > from) - return cumulative(to) - cumulative(from) -} - /** * Sample a bunch of values */ @@ -71,7 +35,7 @@ public fun Sampler.sampleBuffer( //clear list from previous run tmp.clear() //Fill list - repeat(size) { tmp += chain.next() } + repeat(size) { tmp.add(chain.next()) } //return new buffer with elements from tmp bufferFactory(size) { tmp[it] } } @@ -87,7 +51,7 @@ public suspend fun Sampler.next(generator: RandomGenerator): T = sa */ @JvmName("sampleRealBuffer") public fun Sampler.sampleBuffer(generator: RandomGenerator, size: Int): Chain> = - sampleBuffer(generator, size, MutableBuffer.Companion::double) + sampleBuffer(generator, size, ::DoubleBuffer) /** * Generates [size] integer samples and chunks them into some buffers. diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Statistic.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Statistic.kt index 689182115..67f55aea6 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Statistic.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/Statistic.kt @@ -81,7 +81,7 @@ public class Mean( public companion object { //TODO replace with optimized version which respects overflow - public val real: Mean = Mean(DoubleField) { sum, count -> sum / count } + public val double: Mean = Mean(DoubleField) { sum, count -> sum / count } public val int: Mean = Mean(IntRing) { sum, count -> sum / count } public val long: Mean = Mean(LongRing) { sum, count -> sum / count } } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/UniformDistribution.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/UniformDistribution.kt index 4fc0905b8..4fc759e0c 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/UniformDistribution.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/UniformDistribution.kt @@ -2,6 +2,8 @@ package space.kscience.kmath.stat import space.kscience.kmath.chains.Chain import space.kscience.kmath.chains.SimpleChain +import space.kscience.kmath.distributions.Distribution +import space.kscience.kmath.distributions.UnivariateDistribution public class UniformDistribution(public val range: ClosedFloatingPointRange) : UnivariateDistribution { private val length: Double = range.endInclusive - range.start diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/AhrensDieterExponentialSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/AhrensDieterExponentialSampler.kt deleted file mode 100644 index 504c6b881..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/AhrensDieterExponentialSampler.kt +++ /dev/null @@ -1,73 +0,0 @@ -package space.kscience.kmath.stat.samplers - -import space.kscience.kmath.chains.Chain -import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.Sampler -import space.kscience.kmath.stat.chain -import space.kscience.kmath.stat.internal.InternalUtils -import kotlin.math.ln -import kotlin.math.pow - -/** - * Sampling from an [exponential distribution](http://mathworld.wolfram.com/ExponentialDistribution.html). - * - * Based on Commons RNG implementation. - * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.html]. - */ -public class AhrensDieterExponentialSampler private constructor(public val mean: Double) : Sampler { - public override fun sample(generator: RandomGenerator): Chain = generator.chain { - // Step 1: - var a = 0.0 - var u = nextDouble() - - // Step 2 and 3: - while (u < 0.5) { - a += EXPONENTIAL_SA_QI[0] - u *= 2.0 - } - - // Step 4 (now u >= 0.5): - u += u - 1 - // Step 5: - if (u <= EXPONENTIAL_SA_QI[0]) return@chain mean * (a + u) - // Step 6: - var i = 0 // Should be 1, be we iterate before it in while using 0. - var u2 = nextDouble() - var umin = u2 - - // Step 7 and 8: - do { - ++i - u2 = nextDouble() - if (u2 < umin) umin = u2 - // Step 8: - } while (u > EXPONENTIAL_SA_QI[i]) // Ensured to exit since EXPONENTIAL_SA_QI[MAX] = 1. - - mean * (a + umin * EXPONENTIAL_SA_QI[0]) - } - - override fun toString(): String = "Ahrens-Dieter Exponential deviate" - - public companion object { - private val EXPONENTIAL_SA_QI by lazy { DoubleArray(16) } - - init { - /** - * Filling EXPONENTIAL_SA_QI table. - * Note that we don't want qi = 0 in the table. - */ - val ln2 = ln(2.0) - var qi = 0.0 - - EXPONENTIAL_SA_QI.indices.forEach { i -> - qi += ln2.pow(i + 1.0) / InternalUtils.factorial(i + 1) - EXPONENTIAL_SA_QI[i] = qi - } - } - - public fun of(mean: Double): AhrensDieterExponentialSampler { - require(mean > 0) { "mean is not strictly positive: $mean" } - return AhrensDieterExponentialSampler(mean) - } - } -} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/BoxMullerNormalizedGaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/BoxMullerNormalizedGaussianSampler.kt deleted file mode 100644 index 04beb448d..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/BoxMullerNormalizedGaussianSampler.kt +++ /dev/null @@ -1,48 +0,0 @@ -package space.kscience.kmath.stat.samplers - -import space.kscience.kmath.chains.Chain -import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.Sampler -import space.kscience.kmath.stat.chain -import kotlin.math.* - -/** - * [Box-Muller algorithm](https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform) for sampling from a Gaussian - * distribution. - * - * Based on Commons RNG implementation. - * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.html]. - */ -public class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler { - private var nextGaussian: Double = Double.NaN - - public override fun sample(generator: RandomGenerator): Chain = generator.chain { - val random: Double - - if (nextGaussian.isNaN()) { - // Generate a pair of Gaussian numbers. - val x = nextDouble() - val y = nextDouble() - val alpha = 2 * PI * x - val r = sqrt(-2 * ln(y)) - // Return the first element of the generated pair. - random = r * cos(alpha) - // Keep second element of the pair for next invocation. - nextGaussian = r * sin(alpha) - } else { - // Use the second element of the pair (generated at the - // previous invocation). - random = nextGaussian - // Both elements of the pair have been used. - nextGaussian = Double.NaN - } - - random - } - - public override fun toString(): String = "Box-Muller normalized Gaussian deviate" - - public companion object { - public fun of(): BoxMullerNormalizedGaussianSampler = BoxMullerNormalizedGaussianSampler() - } -} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/GaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/GaussianSampler.kt deleted file mode 100644 index eba26cfb5..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/GaussianSampler.kt +++ /dev/null @@ -1,43 +0,0 @@ -package space.kscience.kmath.stat.samplers - -import space.kscience.kmath.chains.Chain -import space.kscience.kmath.chains.map -import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.Sampler - -/** - * Sampling from a Gaussian distribution with given mean and standard deviation. - * - * Based on Commons RNG implementation. - * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/GaussianSampler.html]. - * - * @property mean the mean of the distribution. - * @property standardDeviation the variance of the distribution. - */ -public class GaussianSampler private constructor( - public val mean: Double, - public val standardDeviation: Double, - private val normalized: NormalizedGaussianSampler -) : Sampler { - public override fun sample(generator: RandomGenerator): Chain = normalized - .sample(generator) - .map { standardDeviation * it + mean } - - override fun toString(): String = "Gaussian deviate [$normalized]" - - public companion object { - public fun of( - mean: Double, - standardDeviation: Double, - normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler.of() - ): GaussianSampler { - require(standardDeviation > 0.0) { "standard deviation is not strictly positive: $standardDeviation" } - - return GaussianSampler( - mean, - standardDeviation, - normalized - ) - } - } -} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/KempSmallMeanPoissonSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/KempSmallMeanPoissonSampler.kt deleted file mode 100644 index 1d7f90023..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/KempSmallMeanPoissonSampler.kt +++ /dev/null @@ -1,63 +0,0 @@ -package space.kscience.kmath.stat.samplers - -import space.kscience.kmath.chains.Chain -import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.Sampler -import space.kscience.kmath.stat.chain -import kotlin.math.exp - -/** - * Sampler for the Poisson distribution. - * - Kemp, A, W, (1981) Efficient Generation of Logarithmically Distributed Pseudo-Random Variables. Journal of the Royal Statistical Society. Vol. 30, No. 3, pp. 249-253. - * This sampler is suitable for mean < 40. For large means, LargeMeanPoissonSampler should be used instead. - * - * Note: The algorithm uses a recurrence relation to compute the Poisson probability and a rolling summation for the cumulative probability. When the mean is large the initial probability (Math.exp(-mean)) is zero and an exception is raised by the constructor. - * - * Sampling uses 1 call to UniformRandomProvider.nextDouble(). This method provides an alternative to the SmallMeanPoissonSampler for slow generators of double. - * - * Based on Commons RNG implementation. - * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.html]. - */ -public class KempSmallMeanPoissonSampler private constructor( - private val p0: Double, - private val mean: Double -) : Sampler { - public override fun sample(generator: RandomGenerator): Chain = generator.chain { - // Note on the algorithm: - // - X is the unknown sample deviate (the output of the algorithm) - // - x is the current value from the distribution - // - p is the probability of the current value x, p(X=x) - // - u is effectively the cumulative probability that the sample X - // is equal or above the current value x, p(X>=x) - // So if p(X>=x) > p(X=x) the sample must be above x, otherwise it is x - var u = nextDouble() - var x = 0 - var p = p0 - - while (u > p) { - u -= p - // Compute the next probability using a recurrence relation. - // p(x+1) = p(x) * mean / (x+1) - p *= mean / ++x - // The algorithm listed in Kemp (1981) does not check that the rolling probability - // is positive. This check is added to ensure no errors when the limit of the summation - // 1 - sum(p(x)) is above 0 due to cumulative error in floating point arithmetic. - if (p == 0.0) return@chain x - } - - x - } - - public override fun toString(): String = "Kemp Small Mean Poisson deviate" - - public companion object { - public fun of(mean: Double): KempSmallMeanPoissonSampler { - require(mean > 0) { "Mean is not strictly positive: $mean" } - val p0 = exp(-mean) - // Probability must be positive. As mean increases then p(0) decreases. - require(p0 > 0) { "No probability for mean: $mean" } - return KempSmallMeanPoissonSampler(p0, mean) - } - } -} - diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/LargeMeanPoissonSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/LargeMeanPoissonSampler.kt deleted file mode 100644 index de1e7cc89..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/LargeMeanPoissonSampler.kt +++ /dev/null @@ -1,130 +0,0 @@ -package space.kscience.kmath.stat.samplers - -import space.kscience.kmath.chains.Chain -import space.kscience.kmath.chains.ConstantChain -import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.Sampler -import space.kscience.kmath.stat.chain -import space.kscience.kmath.stat.internal.InternalUtils -import space.kscience.kmath.stat.next -import kotlin.math.* - -/** - * Sampler for the Poisson distribution. - * - For large means, we use the rejection algorithm described in - * Devroye, Luc. (1981).The Computer Generation of Poisson Random Variables - * Computing vol. 26 pp. 197-207. - * - * This sampler is suitable for mean >= 40. - * - * Based on Commons RNG implementation. - * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.html]. - */ -public class LargeMeanPoissonSampler private constructor(public val mean: Double) : Sampler { - private val exponential: Sampler = AhrensDieterExponentialSampler.of(1.0) - private val gaussian: Sampler = ZigguratNormalizedGaussianSampler.of() - private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG - private val lambda: Double = floor(mean) - private val logLambda: Double = ln(lambda) - private val logLambdaFactorial: Double = getFactorialLog(lambda.toInt()) - private val delta: Double = sqrt(lambda * ln(32 * lambda / PI + 1)) - private val halfDelta: Double = delta / 2 - private val twolpd: Double = 2 * lambda + delta - private val c1: Double = 1 / (8 * lambda) - private val a1: Double = sqrt(PI * twolpd) * exp(c1) - private val a2: Double = twolpd / delta * exp(-delta * (1 + delta) / twolpd) - private val aSum: Double = a1 + a2 + 1 - private val p1: Double = a1 / aSum - private val p2: Double = a2 / aSum - - private val smallMeanPoissonSampler: Sampler = if (mean - lambda < Double.MIN_VALUE) - NO_SMALL_MEAN_POISSON_SAMPLER - else // Not used. - KempSmallMeanPoissonSampler.of(mean - lambda) - - public override fun sample(generator: RandomGenerator): Chain = generator.chain { - // This will never be null. It may be a no-op delegate that returns zero. - val y2 = smallMeanPoissonSampler.next(generator) - var x: Double - var y: Double - var v: Double - var a: Int - var t: Double - var qr: Double - var qa: Double - - while (true) { - // Step 1: - val u = generator.nextDouble() - - if (u <= p1) { - // Step 2: - val n = gaussian.next(generator) - x = n * sqrt(lambda + halfDelta) - 0.5 - if (x > delta || x < -lambda) continue - y = if (x < 0) floor(x) else ceil(x) - val e = exponential.next(generator) - v = -e - 0.5 * n * n + c1 - } else { - // Step 3: - if (u > p1 + p2) { - y = lambda - break - } - - x = delta + twolpd / delta * exponential.next(generator) - y = ceil(x) - v = -exponential.next(generator) - delta * (x + 1) / twolpd - } - - // The Squeeze Principle - // Step 4.1: - a = if (x < 0) 1 else 0 - t = y * (y + 1) / (2 * lambda) - - // Step 4.2 - if (v < -t && a == 0) { - y += lambda - break - } - - // Step 4.3: - qr = t * ((2 * y + 1) / (6 * lambda) - 1) - qa = qr - t * t / (3 * (lambda + a * (y + 1))) - - // Step 4.4: - if (v < qa) { - y += lambda - break - } - - // Step 4.5: - if (v > qr) continue - - // Step 4.6: - if (v < y * logLambda - getFactorialLog((y + lambda).toInt()) + logLambdaFactorial) { - y += lambda - break - } - } - - min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toInt() - } - - private fun getFactorialLog(n: Int): Double = factorialLog.value(n) - public override fun toString(): String = "Large Mean Poisson deviate" - - public companion object { - private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE - private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create() - - private val NO_SMALL_MEAN_POISSON_SAMPLER: Sampler = Sampler { ConstantChain(0) } - - public fun of(mean: Double): LargeMeanPoissonSampler { - require(mean >= 1) { "mean is not >= 1: $mean" } - // The algorithm is not valid if Math.floor(mean) is not an integer. - require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" } - return LargeMeanPoissonSampler(mean) - } - } -} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/MarsagliaNormalizedGaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/MarsagliaNormalizedGaussianSampler.kt deleted file mode 100644 index 8a659642f..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/MarsagliaNormalizedGaussianSampler.kt +++ /dev/null @@ -1,61 +0,0 @@ -package space.kscience.kmath.stat.samplers - -import space.kscience.kmath.chains.Chain -import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.Sampler -import space.kscience.kmath.stat.chain -import kotlin.math.ln -import kotlin.math.sqrt - -/** - * [Marsaglia polar method](https://en.wikipedia.org/wiki/Marsaglia_polar_method) for sampling from a Gaussian - * distribution with mean 0 and standard deviation 1. This is a variation of the algorithm implemented in - * [BoxMullerNormalizedGaussianSampler]. - * - * Based on Commons RNG implementation. - * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.html] - */ -public class MarsagliaNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler { - private var nextGaussian = Double.NaN - - public override fun sample(generator: RandomGenerator): Chain = generator.chain { - if (nextGaussian.isNaN()) { - val alpha: Double - var x: Double - - // Rejection scheme for selecting a pair that lies within the unit circle. - while (true) { - // Generate a pair of numbers within [-1 , 1). - x = 2.0 * generator.nextDouble() - 1.0 - val y = 2.0 * generator.nextDouble() - 1.0 - val r2 = x * x + y * y - - if (r2 < 1 && r2 > 0) { - // Pair (x, y) is within unit circle. - alpha = sqrt(-2 * ln(r2) / r2) - // Keep second element of the pair for next invocation. - nextGaussian = alpha * y - // Return the first element of the generated pair. - break - } - // Pair is not within the unit circle: Generate another one. - } - - // Return the first element of the generated pair. - alpha * x - } else { - // Use the second element of the pair (generated at the - // previous invocation). - val r = nextGaussian - // Both elements of the pair have been used. - nextGaussian = Double.NaN - r - } - } - - public override fun toString(): String = "Box-Muller (with rejection) normalized Gaussian deviate" - - public companion object { - public fun of(): MarsagliaNormalizedGaussianSampler = MarsagliaNormalizedGaussianSampler() - } -} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/NormalizedGaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/NormalizedGaussianSampler.kt deleted file mode 100644 index 4eb3d60e0..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/NormalizedGaussianSampler.kt +++ /dev/null @@ -1,9 +0,0 @@ -package space.kscience.kmath.stat.samplers - -import space.kscience.kmath.stat.Sampler - -/** - * Marker interface for a sampler that generates values from an N(0,1) - * [Gaussian distribution](https://en.wikipedia.org/wiki/Normal_distribution). - */ -public interface NormalizedGaussianSampler : Sampler diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/PoissonSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/PoissonSampler.kt deleted file mode 100644 index 0c0234892..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/PoissonSampler.kt +++ /dev/null @@ -1,30 +0,0 @@ -package space.kscience.kmath.stat.samplers - -import space.kscience.kmath.chains.Chain -import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.Sampler - -/** - * Sampler for the Poisson distribution. - * - For small means, a Poisson process is simulated using uniform deviates, as described in - * Knuth (1969). Seminumerical Algorithms. The Art of Computer Programming, Volume 2. Chapter 3.4.1.F.3 - * Important integer-valued distributions: The Poisson distribution. Addison Wesley. - * The Poisson process (and hence, the returned value) is bounded by 1000 * mean. - * - For large means, we use the rejection algorithm described in - * Devroye, Luc. (1981). The Computer Generation of Poisson Random Variables Computing vol. 26 pp. 197-207. - * - * Based on Commons RNG implementation. - * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/PoissonSampler.html]. - */ -public class PoissonSampler private constructor(mean: Double) : Sampler { - private val poissonSamplerDelegate: Sampler = of(mean) - public override fun sample(generator: RandomGenerator): Chain = poissonSamplerDelegate.sample(generator) - public override fun toString(): String = poissonSamplerDelegate.toString() - - public companion object { - private const val PIVOT = 40.0 - - public fun of(mean: Double): Sampler =// Each sampler should check the input arguments. - if (mean < PIVOT) SmallMeanPoissonSampler.of(mean) else LargeMeanPoissonSampler.of(mean) - } -} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/SmallMeanPoissonSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/SmallMeanPoissonSampler.kt deleted file mode 100644 index 0fe7ff161..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/SmallMeanPoissonSampler.kt +++ /dev/null @@ -1,50 +0,0 @@ -package space.kscience.kmath.stat.samplers - -import space.kscience.kmath.chains.Chain -import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.Sampler -import space.kscience.kmath.stat.chain -import kotlin.math.ceil -import kotlin.math.exp - -/** - * Sampler for the Poisson distribution. - * - For small means, a Poisson process is simulated using uniform deviates, as described in - * Knuth (1969). Seminumerical Algorithms. The Art of Computer Programming, Volume 2. Chapter 3.4.1.F.3 Important - * integer-valued distributions: The Poisson distribution. Addison Wesley. - * - The Poisson process (and hence, the returned value) is bounded by 1000 * mean. - * This sampler is suitable for mean < 40. For large means, [LargeMeanPoissonSampler] should be used instead. - * - * Based on Commons RNG implementation. - * - * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.html]. - */ -public class SmallMeanPoissonSampler private constructor(mean: Double) : Sampler { - private val p0: Double = exp(-mean) - - private val limit: Int = (if (p0 > 0) - ceil(1000 * mean) - else - throw IllegalArgumentException("No p(x=0) probability for mean: $mean")).toInt() - - public override fun sample(generator: RandomGenerator): Chain = generator.chain { - var n = 0 - var r = 1.0 - - while (n < limit) { - r *= nextDouble() - if (r >= p0) n++ else break - } - - n - } - - public override fun toString(): String = "Small Mean Poisson deviate" - - public companion object { - public fun of(mean: Double): SmallMeanPoissonSampler { - require(mean > 0) { "mean is not strictly positive: $mean" } - return SmallMeanPoissonSampler(mean) - } - } -} diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/ZigguratNormalizedGaussianSampler.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/ZigguratNormalizedGaussianSampler.kt deleted file mode 100644 index 90815209f..000000000 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/stat/samplers/ZigguratNormalizedGaussianSampler.kt +++ /dev/null @@ -1,88 +0,0 @@ -package space.kscience.kmath.stat.samplers - -import space.kscience.kmath.chains.Chain -import space.kscience.kmath.stat.RandomGenerator -import space.kscience.kmath.stat.Sampler -import space.kscience.kmath.stat.chain -import kotlin.math.* - -/** - * [Marsaglia and Tsang "Ziggurat"](https://en.wikipedia.org/wiki/Ziggurat_algorithm) method for sampling from a - * Gaussian distribution with mean 0 and standard deviation 1. The algorithm is explained in this paper and this - * implementation has been adapted from the C code provided therein. - * - * Based on Commons RNG implementation. - * See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.html]. - */ -public class ZigguratNormalizedGaussianSampler private constructor() : - NormalizedGaussianSampler, Sampler { - - private fun sampleOne(generator: RandomGenerator): Double { - val j = generator.nextLong() - val i = (j and LAST.toLong()).toInt() - return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i) - } - - public override fun sample(generator: RandomGenerator): Chain = generator.chain { sampleOne(this) } - public override fun toString(): String = "Ziggurat normalized Gaussian deviate" - - private fun fix(generator: RandomGenerator, hz: Long, iz: Int): Double { - var x = hz * W[iz] - - return when { - iz == 0 -> { - var y: Double - - do { - y = -ln(generator.nextDouble()) - x = -ln(generator.nextDouble()) * ONE_OVER_R - } while (y + y < x * x) - - val out = R + x - if (hz > 0) out else -out - } - - F[iz] + generator.nextDouble() * (F[iz - 1] - F[iz]) < gauss(x) -> x - else -> sampleOne(generator) - } - } - - public companion object { - private const val R: Double = 3.442619855899 - private const val ONE_OVER_R: Double = 1 / R - private const val V: Double = 9.91256303526217e-3 - private val MAX: Double = 2.0.pow(63.0) - private val ONE_OVER_MAX: Double = 1.0 / MAX - private const val LEN: Int = 128 - private const val LAST: Int = LEN - 1 - private val K: LongArray = LongArray(LEN) - private val W: DoubleArray = DoubleArray(LEN) - private val F: DoubleArray = DoubleArray(LEN) - - init { - // Filling the tables. - var d = R - var t = d - var fd = gauss(d) - val q = V / fd - K[0] = (d / q * MAX).toLong() - K[1] = 0 - W[0] = q * ONE_OVER_MAX - W[LAST] = d * ONE_OVER_MAX - F[0] = 1.0 - F[LAST] = fd - - (LAST - 1 downTo 1).forEach { i -> - d = sqrt(-2 * ln(V / d + fd)) - fd = gauss(d) - K[i + 1] = (d / t * MAX).toLong() - t = d - F[i] = fd - W[i] = d * ONE_OVER_MAX - } - } - - public fun of(): ZigguratNormalizedGaussianSampler = ZigguratNormalizedGaussianSampler() - private fun gauss(x: Double): Double = exp(-0.5 * x * x) - } -} diff --git a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/CommonsDistributionsTest.kt b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/CommonsDistributionsTest.kt index 76aac65c4..c6b9cb17a 100644 --- a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/CommonsDistributionsTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/CommonsDistributionsTest.kt @@ -5,22 +5,23 @@ import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Assertions import org.junit.jupiter.api.Test -import space.kscience.kmath.stat.samplers.GaussianSampler +import space.kscience.kmath.samplers.GaussianSampler +import space.kscience.kmath.structures.asBuffer internal class CommonsDistributionsTest { @Test - fun testNormalDistributionSuspend() { - val distribution = GaussianSampler.of(7.0, 2.0) + fun testNormalDistributionSuspend() = runBlocking { + val distribution = GaussianSampler(7.0, 2.0) val generator = RandomGenerator.default(1) - val sample = runBlocking { distribution.sample(generator).take(1000).toList() } - Assertions.assertEquals(7.0, sample.average(), 0.1) + val sample = distribution.sample(generator).take(1000).toList().asBuffer() + Assertions.assertEquals(7.0, Mean.double(sample), 0.2) } @Test - fun testNormalDistributionBlocking() { - val distribution = GaussianSampler.of(7.0, 2.0) + fun testNormalDistributionBlocking() = runBlocking { + val distribution = GaussianSampler(7.0, 2.0) val generator = RandomGenerator.default(1) - val sample = runBlocking { distribution.sample(generator).blocking().nextBlock(1000) } - Assertions.assertEquals(7.0, sample.average(), 0.1) + val sample = distribution.sample(generator).nextBufferBlocking(1000) + Assertions.assertEquals(7.0, Mean.double(sample), 0.2) } } diff --git a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/StatisticTest.kt b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/StatisticTest.kt index 156e618f9..3c9d6a2e4 100644 --- a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/StatisticTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/StatisticTest.kt @@ -20,7 +20,7 @@ internal class StatisticTest { @Test fun testParallelMean() { runBlocking { - val average = Mean.real + val average = Mean.double .flow(chunked) //create a flow with results .drop(99) // Skip first 99 values and use one with total data .first() //get 1e5 data samples average From af4866e8763e48355e8d764a0c519a79feca78d0 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 1 Apr 2021 20:15:49 +0300 Subject: [PATCH 3/9] Refactor MST --- CHANGELOG.md | 4 + .../space/kscience/kmath/ast/expressions.kt | 10 +- .../kscience/kmath/ast/kotlingradSupport.kt | 12 +- .../kotlin/space/kscience/kmath/ast/MST.kt | 47 +++++- .../space/kscience/kmath/ast/MstExpression.kt | 138 ------------------ .../space/kscience/kmath/estree/estree.kt | 25 ++-- .../TestESTreeConsistencyWithInterpreter.kt | 82 ++++------- .../estree/TestESTreeOperationsSupport.kt | 15 +- .../kmath/estree/TestESTreeSpecialization.kt | 30 ++-- .../kmath/estree/TestESTreeVariables.kt | 7 +- .../kotlin/space/kscience/kmath/asm/asm.kt | 31 ++-- .../asm/TestAsmConsistencyWithInterpreter.kt | 82 ++++------- .../kmath/asm/TestAsmOperationsSupport.kt | 17 ++- .../kmath/asm/TestAsmSpecialization.kt | 30 ++-- .../kscience/kmath/asm/TestAsmVariables.kt | 7 +- .../space/kscience/kmath/ast/ParserTest.kt | 4 +- .../kotlingrad/DifferentiableMstExpression.kt | 49 +++---- .../kmath/kotlingrad/AdaptingTests.kt | 16 +- 18 files changed, 241 insertions(+), 365 deletions(-) delete mode 100644 kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstExpression.kt diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ade9cd9c..c4d3b93e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - ScaleOperations interface - Field extends ScaleOperations - Basic integration API +- Basic MPP distributions and samplers ### Changed - Exponential operations merged with hyperbolic functions @@ -14,6 +15,8 @@ - NDStructure and NDAlgebra to StructureND and AlgebraND respectively - Real -> Double - DataSets are moved from functions to core +- Redesign advanced Chain API +- Redesign MST. Remove MSTExpression. ### Deprecated @@ -21,6 +24,7 @@ - Nearest in Domain. To be implemented in geometry package. - Number multiplication and division in main Algebra chain - `contentEquals` from Buffer. It moved to the companion. +- MSTExpression ### Fixed diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt index 17c85eea5..a4b8b7ca4 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/expressions.kt @@ -1,15 +1,17 @@ package space.kscience.kmath.ast -import space.kscience.kmath.expressions.invoke +import space.kscience.kmath.misc.Symbol.Companion.x import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.bindSymbol +import space.kscience.kmath.operations.invoke fun main() { - val expr = DoubleField.mstInField { - val x = bindSymbol("x") + val expr = MstField { + val x = bindSymbol(x) x * 2.0 + number(2.0) / x - 16.0 } repeat(10000000) { - expr.invoke("x" to 1.0) + expr.interpret(DoubleField, x to 1.0) } } \ No newline at end of file diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/kotlingradSupport.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/kotlingradSupport.kt index 138b3e708..fb69177a2 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/ast/kotlingradSupport.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/kotlingradSupport.kt @@ -1,9 +1,9 @@ package space.kscience.kmath.ast -import space.kscience.kmath.asm.compile +import space.kscience.kmath.asm.compileToExpression import space.kscience.kmath.expressions.derivative import space.kscience.kmath.expressions.invoke -import space.kscience.kmath.kotlingrad.differentiable +import space.kscience.kmath.kotlingrad.toDiffExpression import space.kscience.kmath.misc.symbol import space.kscience.kmath.operations.DoubleField @@ -14,11 +14,11 @@ import space.kscience.kmath.operations.DoubleField fun main() { val x by symbol - val actualDerivative = MstExpression(DoubleField, "x^2-4*x-44".parseMath()) - .differentiable() + val actualDerivative = "x^2-4*x-44".parseMath() + .toDiffExpression(DoubleField) .derivative(x) - .compile() - val expectedDerivative = MstExpression(DoubleField, "2*x-4".parseMath()).compile() + + val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField) assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0)) } diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt index c459d7ff5..b8c2aadf7 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt @@ -1,5 +1,8 @@ package space.kscience.kmath.ast +import space.kscience.kmath.expressions.Expression +import space.kscience.kmath.misc.StringSymbol +import space.kscience.kmath.misc.Symbol import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.NumericAlgebra @@ -76,11 +79,51 @@ public fun Algebra.evaluate(node: MST): T = when (node) { } } +internal class InnerAlgebra(val algebra: Algebra, val arguments: Map) : NumericAlgebra { + override fun bindSymbol(value: String): T = try { + algebra.bindSymbol(value) + } catch (ignored: IllegalStateException) { + null + } ?: arguments.getValue(StringSymbol(value)) + + override fun unaryOperation(operation: String, arg: T): T = + algebra.unaryOperation(operation, arg) + + override fun binaryOperation(operation: String, left: T, right: T): T = + algebra.binaryOperation(operation, left, right) + + override fun unaryOperationFunction(operation: String): (arg: T) -> T = + algebra.unaryOperationFunction(operation) + + override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = + algebra.binaryOperationFunction(operation) + + @Suppress("UNCHECKED_CAST") + override fun number(value: Number): T = if (algebra is NumericAlgebra<*>) + (algebra as NumericAlgebra).number(value) + else + error("Numeric nodes are not supported by $this") +} + /** - * Interprets the [MST] node with this [Algebra]. + * Interprets the [MST] node with this [Algebra] and optional [arguments] + */ +public fun MST.interpret(algebra: Algebra, arguments: Map): T = + InnerAlgebra(algebra, arguments).evaluate(this) + +/** + * Interprets the [MST] node with this [Algebra] and optional [arguments] * * @receiver the node to evaluate. * @param algebra the algebra that provides operations. * @return the value of expression. */ -public fun MST.interpret(algebra: Algebra): T = algebra.evaluate(this) +public fun MST.interpret(algebra: Algebra, vararg arguments: Pair): T = + interpret(algebra, mapOf(*arguments)) + +/** + * Interpret this [MST] as expression. + */ +public fun MST.toExpression(algebra: Algebra): Expression = Expression { arguments -> + interpret(algebra, arguments) +} diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstExpression.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstExpression.kt deleted file mode 100644 index 5c43df068..000000000 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstExpression.kt +++ /dev/null @@ -1,138 +0,0 @@ -package space.kscience.kmath.ast - -import space.kscience.kmath.expressions.* -import space.kscience.kmath.misc.StringSymbol -import space.kscience.kmath.misc.Symbol -import space.kscience.kmath.operations.* -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract - -/** - * The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than - * ASM-generated expressions. - * - * @property algebra the algebra that provides operations. - * @property mst the [MST] node. - * @author Alexander Nozik - */ -public class MstExpression>(public val algebra: A, public val mst: MST) : Expression { - private inner class InnerAlgebra(val arguments: Map) : NumericAlgebra { - override fun bindSymbol(value: String): T = try { - algebra.bindSymbol(value) - } catch (ignored: IllegalStateException) { - null - } ?: arguments.getValue(StringSymbol(value)) - - override fun unaryOperation(operation: String, arg: T): T = - algebra.unaryOperation(operation, arg) - - override fun binaryOperation(operation: String, left: T, right: T): T = - algebra.binaryOperation(operation, left, right) - - override fun unaryOperationFunction(operation: String): (arg: T) -> T = - algebra.unaryOperationFunction(operation) - - override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = - algebra.binaryOperationFunction(operation) - - @Suppress("UNCHECKED_CAST") - override fun number(value: Number): T = if (algebra is NumericAlgebra<*>) - (algebra as NumericAlgebra).number(value) - else - error("Numeric nodes are not supported by $this") - } - - override operator fun invoke(arguments: Map): T = InnerAlgebra(arguments).evaluate(mst) -} - -/** - * Builds [MstExpression] over [Algebra]. - * - * @author Alexander Nozik - */ -public inline fun , E : Algebra> A.mst( - mstAlgebra: E, - block: E.() -> MST, -): MstExpression = MstExpression(this, mstAlgebra.block()) - -/** - * Builds [MstExpression] over [Group]. - * - * @author Alexander Nozik - */ -public inline fun > A.mstInGroup(block: MstGroup.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return MstExpression(this, MstGroup.block()) -} - -/** - * Builds [MstExpression] over [Ring]. - * - * @author Alexander Nozik - */ -public inline fun > A.mstInRing(block: MstRing.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return MstExpression(this, MstRing.block()) -} - -/** - * Builds [MstExpression] over [Field]. - * - * @author Alexander Nozik - */ -public inline fun > A.mstInField(block: MstField.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return MstExpression(this, MstField.block()) -} - -/** - * Builds [MstExpression] over [ExtendedField]. - * - * @author Iaroslav Postovalov - */ -public inline fun > A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return MstExpression(this, MstExtendedField.block()) -} - -/** - * Builds [MstExpression] over [FunctionalExpressionGroup]. - * - * @author Alexander Nozik - */ -public inline fun > FunctionalExpressionGroup.mstInGroup(block: MstGroup.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return algebra.mstInGroup(block) -} - -/** - * Builds [MstExpression] over [FunctionalExpressionRing]. - * - * @author Alexander Nozik - */ -public inline fun > FunctionalExpressionRing.mstInRing(block: MstRing.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return algebra.mstInRing(block) -} - -/** - * Builds [MstExpression] over [FunctionalExpressionField]. - * - * @author Alexander Nozik - */ -public inline fun > FunctionalExpressionField.mstInField(block: MstField.() -> MST): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return algebra.mstInField(block) -} - -/** - * Builds [MstExpression] over [FunctionalExpressionExtendedField]. - * - * @author Iaroslav Postovalov - */ -public inline fun > FunctionalExpressionExtendedField.mstInExtendedField( - block: MstExtendedField.() -> MST, -): MstExpression { - contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } - return algebra.mstInExtendedField(block) -} diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt index 456a2ba07..93b2d54c8 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt @@ -2,10 +2,11 @@ package space.kscience.kmath.estree import space.kscience.kmath.ast.MST import space.kscience.kmath.ast.MST.* -import space.kscience.kmath.ast.MstExpression import space.kscience.kmath.estree.internal.ESTreeBuilder import space.kscience.kmath.estree.internal.estree.BaseExpression import space.kscience.kmath.expressions.Expression +import space.kscience.kmath.expressions.invoke +import space.kscience.kmath.misc.Symbol import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.NumericAlgebra @@ -64,19 +65,21 @@ internal fun MST.compileWith(algebra: Algebra): Expression { return ESTreeBuilder { visit(this@compileWith) }.instance } +/** + * Create a compiled expression with given [MST] and given [algebra]. + */ +public fun MST.compileToExpression(algebra: Algebra): Expression = compileWith(algebra) + /** - * Compiles an [MST] to ESTree generated expression using given algebra. - * - * @author Iaroslav Postovalov + * Compile given MST to expression and evaluate it against [arguments] */ -public fun Algebra.expression(mst: MST): Expression = - mst.compileWith(this) +public inline fun MST.compile(algebra: Algebra, arguments: Map): T = + compileToExpression(algebra).invoke(arguments) + /** - * Optimizes performance of an [MstExpression] by compiling it into ESTree generated expression. - * - * @author Iaroslav Postovalov + * Compile given MST to expression and evaluate it against [arguments] */ -public fun MstExpression>.compile(): Expression = - mst.compileWith(algebra) +public inline fun MST.compile(algebra: Algebra, vararg arguments: Pair): T = + compileToExpression(algebra).invoke(*arguments) diff --git a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeConsistencyWithInterpreter.kt b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeConsistencyWithInterpreter.kt index 683c0337c..fb8d73c0c 100644 --- a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeConsistencyWithInterpreter.kt +++ b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeConsistencyWithInterpreter.kt @@ -3,16 +3,19 @@ package space.kscience.kmath.estree import space.kscience.kmath.ast.* import space.kscience.kmath.complex.ComplexField import space.kscience.kmath.complex.toComplex -import space.kscience.kmath.expressions.invoke +import space.kscience.kmath.misc.Symbol import space.kscience.kmath.operations.ByteRing import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals internal class TestESTreeConsistencyWithInterpreter { + @Test fun mstSpace() { - val res1 = MstGroup.mstInGroup { + + val mst = MstGroup { binaryOperationFunction("+")( unaryOperationFunction("+")( number(3.toByte()) - (number(2.toByte()) + (scale( @@ -23,27 +26,17 @@ internal class TestESTreeConsistencyWithInterpreter { number(1) ) + bindSymbol("x") + zero - }("x" to MST.Numeric(2)) + } - val res2 = MstGroup.mstInGroup { - binaryOperationFunction("+")( - unaryOperationFunction("+")( - number(3.toByte()) - (number(2.toByte()) + (scale( - add(number(1), number(1)), - 2.0 - ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) - ), - - number(1) - ) + bindSymbol("x") + zero - }.compile()("x" to MST.Numeric(2)) - - assertEquals(res1, res2) + assertEquals( + mst.interpret(MstGroup, Symbol.x to MST.Numeric(2)), + mst.compile(MstGroup, Symbol.x to MST.Numeric(2)) + ) } @Test fun byteRing() { - val res1 = ByteRing.mstInRing { + val mst = MstRing { binaryOperationFunction("+")( unaryOperationFunction("+")( (bindSymbol("x") - (2.toByte() + (scale( @@ -54,62 +47,43 @@ internal class TestESTreeConsistencyWithInterpreter { number(1) ) * number(2) - }("x" to 3.toByte()) + } - val res2 = ByteRing.mstInRing { - binaryOperationFunction("+")( - unaryOperationFunction("+")( - (bindSymbol("x") - (2.toByte() + (scale( - add(number(1), number(1)), - 2.0 - ) + 1.toByte()))) * 3.0 - 1.toByte() - ), - number(1) - ) * number(2) - }.compile()("x" to 3.toByte()) - - assertEquals(res1, res2) + assertEquals( + mst.interpret(ByteRing, Symbol.x to 3.toByte()), + mst.compile(ByteRing, Symbol.x to 3.toByte()) + ) } @Test fun realField() { - val res1 = DoubleField.mstInField { + val mst = MstField { +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( (3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one ) + zero - }("x" to 2.0) + } - val res2 = DoubleField.mstInField { - +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( - (3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 - + number(1), - number(1) / 2 + number(2.0) * one - ) + zero - }.compile()("x" to 2.0) - - assertEquals(res1, res2) + assertEquals( + mst.interpret(DoubleField, Symbol.x to 2.0), + mst.compile(DoubleField, Symbol.x to 2.0) + ) } @Test fun complexField() { - val res1 = ComplexField.mstInField { + val mst = MstField { +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( (3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one ) + zero - }("x" to 2.0.toComplex()) + } - val res2 = ComplexField.mstInField { - +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( - (3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 - + number(1), - number(1) / 2 + number(2.0) * one - ) + zero - }.compile()("x" to 2.0.toComplex()) - - assertEquals(res1, res2) + assertEquals( + mst.interpret(ComplexField, Symbol.x to 2.0.toComplex()), + mst.compile(ComplexField, Symbol.x to 2.0.toComplex()) + ) } } diff --git a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeOperationsSupport.kt b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeOperationsSupport.kt index d59c048b6..24c003e3e 100644 --- a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeOperationsSupport.kt +++ b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeOperationsSupport.kt @@ -1,10 +1,9 @@ package space.kscience.kmath.estree -import space.kscience.kmath.ast.mstInExtendedField -import space.kscience.kmath.ast.mstInField -import space.kscience.kmath.ast.mstInGroup +import space.kscience.kmath.ast.MstExtendedField import space.kscience.kmath.expressions.invoke import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.invoke import kotlin.random.Random import kotlin.test.Test import kotlin.test.assertEquals @@ -12,29 +11,29 @@ import kotlin.test.assertEquals internal class TestESTreeOperationsSupport { @Test fun testUnaryOperationInvocation() { - val expression = DoubleField.mstInGroup { -bindSymbol("x") }.compile() + val expression = MstExtendedField { -bindSymbol("x") }.compileToExpression(DoubleField) val res = expression("x" to 2.0) assertEquals(-2.0, res) } @Test fun testBinaryOperationInvocation() { - val expression = DoubleField.mstInGroup { -bindSymbol("x") + number(1.0) }.compile() + val expression = MstExtendedField { -bindSymbol("x") + number(1.0) }.compileToExpression(DoubleField) val res = expression("x" to 2.0) assertEquals(-1.0, res) } @Test fun testConstProductInvocation() { - val res = DoubleField.mstInField { bindSymbol("x") * 2 }("x" to 2.0) + val res = MstExtendedField { bindSymbol("x") * 2 }.compileToExpression(DoubleField)("x" to 2.0) assertEquals(4.0, res) } @Test fun testMultipleCalls() { val e = - DoubleField.mstInExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) } - .compile() + MstExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) } + .compileToExpression(DoubleField) val r = Random(0) var s = 0.0 repeat(1000000) { s += e("x" to r.nextDouble()) } diff --git a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeSpecialization.kt b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeSpecialization.kt index 6be963175..c83fbc391 100644 --- a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeSpecialization.kt +++ b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeSpecialization.kt @@ -1,53 +1,63 @@ package space.kscience.kmath.estree -import space.kscience.kmath.ast.mstInField +import space.kscience.kmath.ast.MstExtendedField import space.kscience.kmath.expressions.invoke import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals internal class TestESTreeSpecialization { @Test fun testUnaryPlus() { - val expr = DoubleField.mstInField { unaryOperationFunction("+")(bindSymbol("x")) }.compile() + val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol("x")) }.compileToExpression(DoubleField) assertEquals(2.0, expr("x" to 2.0)) } @Test fun testUnaryMinus() { - val expr = DoubleField.mstInField { unaryOperationFunction("-")(bindSymbol("x")) }.compile() + val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol("x")) }.compileToExpression(DoubleField) assertEquals(-2.0, expr("x" to 2.0)) } @Test fun testAdd() { - val expr = DoubleField.mstInField { binaryOperationFunction("+")(bindSymbol("x"), bindSymbol("x")) }.compile() + val expr = MstExtendedField { + binaryOperationFunction("+")(bindSymbol("x"), + bindSymbol("x")) + }.compileToExpression(DoubleField) assertEquals(4.0, expr("x" to 2.0)) } @Test fun testSine() { - val expr = DoubleField.mstInField { unaryOperationFunction("sin")(bindSymbol("x")) }.compile() + val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol("x")) }.compileToExpression(DoubleField) assertEquals(0.0, expr("x" to 0.0)) } @Test fun testMinus() { - val expr = DoubleField.mstInField { binaryOperationFunction("-")(bindSymbol("x"), bindSymbol("x")) }.compile() + val expr = MstExtendedField { + binaryOperationFunction("-")(bindSymbol("x"), + bindSymbol("x")) + }.compileToExpression(DoubleField) assertEquals(0.0, expr("x" to 2.0)) } @Test fun testDivide() { - val expr = DoubleField.mstInField { binaryOperationFunction("/")(bindSymbol("x"), bindSymbol("x")) }.compile() + val expr = MstExtendedField { + binaryOperationFunction("/")(bindSymbol("x"), + bindSymbol("x")) + }.compileToExpression(DoubleField) assertEquals(1.0, expr("x" to 2.0)) } @Test fun testPower() { - val expr = DoubleField - .mstInField { binaryOperationFunction("pow")(bindSymbol("x"), number(2)) } - .compile() + val expr = MstExtendedField { + binaryOperationFunction("pow")(bindSymbol("x"), number(2)) + }.compileToExpression(DoubleField) assertEquals(4.0, expr("x" to 2.0)) } diff --git a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeVariables.kt b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeVariables.kt index ee8f4c6f5..0b1c1c33e 100644 --- a/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeVariables.kt +++ b/kmath-ast/src/jsTest/kotlin/space/kscience/kmath/estree/TestESTreeVariables.kt @@ -1,8 +1,9 @@ package space.kscience.kmath.estree -import space.kscience.kmath.ast.mstInRing +import space.kscience.kmath.ast.MstRing import space.kscience.kmath.expressions.invoke import space.kscience.kmath.operations.ByteRing +import space.kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -10,13 +11,13 @@ import kotlin.test.assertFailsWith internal class TestESTreeVariables { @Test fun testVariable() { - val expr = ByteRing.mstInRing { bindSymbol("x") }.compile() + val expr = MstRing{ bindSymbol("x") }.compileToExpression(ByteRing) assertEquals(1.toByte(), expr("x" to 1.toByte())) } @Test fun testUndefinedVariableFails() { - val expr = ByteRing.mstInRing { bindSymbol("x") }.compile() + val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing) assertFailsWith { expr() } } } diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt index 369fe136b..5324d74a1 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt @@ -4,8 +4,9 @@ import space.kscience.kmath.asm.internal.AsmBuilder import space.kscience.kmath.asm.internal.buildName import space.kscience.kmath.ast.MST import space.kscience.kmath.ast.MST.* -import space.kscience.kmath.ast.MstExpression import space.kscience.kmath.expressions.Expression +import space.kscience.kmath.expressions.invoke +import space.kscience.kmath.misc.Symbol import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.NumericAlgebra @@ -70,18 +71,22 @@ internal fun MST.compileWith(type: Class, algebra: Algebra): Exp return AsmBuilder(type, buildName(this)) { visit(this@compileWith) }.instance } -/** - * Compiles an [MST] to ASM using given algebra. - * - * @author Alexander Nozik - */ -public inline fun Algebra.expression(mst: MST): Expression = - mst.compileWith(T::class.java, this) /** - * Optimizes performance of an [MstExpression] using ASM codegen. - * - * @author Alexander Nozik + * Create a compiled expression with given [MST] and given [algebra]. */ -public inline fun MstExpression>.compile(): Expression = - mst.compileWith(T::class.java, algebra) +public inline fun MST.compileToExpression(algebra: Algebra): Expression = + compileWith(T::class.java, algebra) + + +/** + * Compile given MST to expression and evaluate it against [arguments] + */ +public inline fun MST.compile(algebra: Algebra, arguments: Map): T = + compileToExpression(algebra).invoke(arguments) + +/** + * Compile given MST to expression and evaluate it against [arguments] + */ +public inline fun MST.compile(algebra: Algebra, vararg arguments: Pair): T = + compileToExpression(algebra).invoke(*arguments) diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmConsistencyWithInterpreter.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmConsistencyWithInterpreter.kt index abc320360..096bf4447 100644 --- a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmConsistencyWithInterpreter.kt +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmConsistencyWithInterpreter.kt @@ -3,16 +3,19 @@ package space.kscience.kmath.asm import space.kscience.kmath.ast.* import space.kscience.kmath.complex.ComplexField import space.kscience.kmath.complex.toComplex -import space.kscience.kmath.expressions.invoke +import space.kscience.kmath.misc.Symbol.Companion.x import space.kscience.kmath.operations.ByteRing import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals internal class TestAsmConsistencyWithInterpreter { + @Test fun mstSpace() { - val res1 = MstGroup.mstInGroup { + + val mst = MstGroup { binaryOperationFunction("+")( unaryOperationFunction("+")( number(3.toByte()) - (number(2.toByte()) + (scale( @@ -23,27 +26,17 @@ internal class TestAsmConsistencyWithInterpreter { number(1) ) + bindSymbol("x") + zero - }("x" to MST.Numeric(2)) + } - val res2 = MstGroup.mstInGroup { - binaryOperationFunction("+")( - unaryOperationFunction("+")( - number(3.toByte()) - (number(2.toByte()) + (scale( - add(number(1), number(1)), - 2.0 - ) + number(1.toByte()) * 3.toByte() - number(1.toByte()))) - ), - - number(1) - ) + bindSymbol("x") + zero - }.compile()("x" to MST.Numeric(2)) - - assertEquals(res1, res2) + assertEquals( + mst.interpret(MstGroup, x to MST.Numeric(2)), + mst.compile(MstGroup, x to MST.Numeric(2)) + ) } @Test fun byteRing() { - val res1 = ByteRing.mstInRing { + val mst = MstRing { binaryOperationFunction("+")( unaryOperationFunction("+")( (bindSymbol("x") - (2.toByte() + (scale( @@ -54,62 +47,43 @@ internal class TestAsmConsistencyWithInterpreter { number(1) ) * number(2) - }("x" to 3.toByte()) + } - val res2 = ByteRing.mstInRing { - binaryOperationFunction("+")( - unaryOperationFunction("+")( - (bindSymbol("x") - (2.toByte() + (scale( - add(number(1), number(1)), - 2.0 - ) + 1.toByte()))) * 3.0 - 1.toByte() - ), - number(1) - ) * number(2) - }.compile()("x" to 3.toByte()) - - assertEquals(res1, res2) + assertEquals( + mst.interpret(ByteRing, x to 3.toByte()), + mst.compile(ByteRing, x to 3.toByte()) + ) } @Test fun realField() { - val res1 = DoubleField.mstInField { + val mst = MstField { +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( (3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one ) + zero - }("x" to 2.0) + } - val res2 = DoubleField.mstInField { - +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( - (3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 - + number(1), - number(1) / 2 + number(2.0) * one - ) + zero - }.compile()("x" to 2.0) - - assertEquals(res1, res2) + assertEquals( + mst.interpret(DoubleField, x to 2.0), + mst.compile(DoubleField, x to 2.0) + ) } @Test fun complexField() { - val res1 = ComplexField.mstInField { + val mst = MstField { +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( (3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 + number(1), number(1) / 2 + number(2.0) * one ) + zero - }("x" to 2.0.toComplex()) + } - val res2 = ComplexField.mstInField { - +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( - (3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 - + number(1), - number(1) / 2 + number(2.0) * one - ) + zero - }.compile()("x" to 2.0.toComplex()) - - assertEquals(res1, res2) + assertEquals( + mst.interpret(ComplexField, x to 2.0.toComplex()), + mst.compile(ComplexField, x to 2.0.toComplex()) + ) } } diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmOperationsSupport.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmOperationsSupport.kt index 5d70cb76b..d1a216ede 100644 --- a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmOperationsSupport.kt +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmOperationsSupport.kt @@ -1,10 +1,11 @@ package space.kscience.kmath.asm -import space.kscience.kmath.ast.mstInExtendedField -import space.kscience.kmath.ast.mstInField -import space.kscience.kmath.ast.mstInGroup +import space.kscience.kmath.ast.MstExtendedField +import space.kscience.kmath.ast.MstField +import space.kscience.kmath.ast.MstGroup import space.kscience.kmath.expressions.invoke import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.invoke import kotlin.random.Random import kotlin.test.Test import kotlin.test.assertEquals @@ -12,29 +13,29 @@ import kotlin.test.assertEquals internal class TestAsmOperationsSupport { @Test fun testUnaryOperationInvocation() { - val expression = DoubleField.mstInGroup { -bindSymbol("x") }.compile() + val expression = MstGroup { -bindSymbol("x") }.compileToExpression(DoubleField) val res = expression("x" to 2.0) assertEquals(-2.0, res) } @Test fun testBinaryOperationInvocation() { - val expression = DoubleField.mstInGroup { -bindSymbol("x") + number(1.0) }.compile() + val expression = MstGroup { -bindSymbol("x") + number(1.0) }.compileToExpression(DoubleField) val res = expression("x" to 2.0) assertEquals(-1.0, res) } @Test fun testConstProductInvocation() { - val res = DoubleField.mstInField { bindSymbol("x") * 2 }("x" to 2.0) + val res = MstField { bindSymbol("x") * 2 }.compileToExpression(DoubleField)("x" to 2.0) assertEquals(4.0, res) } @Test fun testMultipleCalls() { val e = - DoubleField.mstInExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) } - .compile() + MstExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) } + .compileToExpression(DoubleField) val r = Random(0) var s = 0.0 repeat(1000000) { s += e("x" to r.nextDouble()) } diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmSpecialization.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmSpecialization.kt index f485260c9..75a3ffaee 100644 --- a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmSpecialization.kt +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmSpecialization.kt @@ -1,53 +1,63 @@ package space.kscience.kmath.asm -import space.kscience.kmath.ast.mstInField +import space.kscience.kmath.ast.MstExtendedField import space.kscience.kmath.expressions.invoke import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals internal class TestAsmSpecialization { @Test fun testUnaryPlus() { - val expr = DoubleField.mstInField { unaryOperationFunction("+")(bindSymbol("x")) }.compile() + val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol("x")) }.compileToExpression(DoubleField) assertEquals(2.0, expr("x" to 2.0)) } @Test fun testUnaryMinus() { - val expr = DoubleField.mstInField { unaryOperationFunction("-")(bindSymbol("x")) }.compile() + val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol("x")) }.compileToExpression(DoubleField) assertEquals(-2.0, expr("x" to 2.0)) } @Test fun testAdd() { - val expr = DoubleField.mstInField { binaryOperationFunction("+")(bindSymbol("x"), bindSymbol("x")) }.compile() + val expr = MstExtendedField { + binaryOperationFunction("+")(bindSymbol("x"), + bindSymbol("x")) + }.compileToExpression(DoubleField) assertEquals(4.0, expr("x" to 2.0)) } @Test fun testSine() { - val expr = DoubleField.mstInField { unaryOperationFunction("sin")(bindSymbol("x")) }.compile() + val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol("x")) }.compileToExpression(DoubleField) assertEquals(0.0, expr("x" to 0.0)) } @Test fun testMinus() { - val expr = DoubleField.mstInField { binaryOperationFunction("-")(bindSymbol("x"), bindSymbol("x")) }.compile() + val expr = MstExtendedField { + binaryOperationFunction("-")(bindSymbol("x"), + bindSymbol("x")) + }.compileToExpression(DoubleField) assertEquals(0.0, expr("x" to 2.0)) } @Test fun testDivide() { - val expr = DoubleField.mstInField { binaryOperationFunction("/")(bindSymbol("x"), bindSymbol("x")) }.compile() + val expr = MstExtendedField { + binaryOperationFunction("/")(bindSymbol("x"), + bindSymbol("x")) + }.compileToExpression(DoubleField) assertEquals(1.0, expr("x" to 2.0)) } @Test fun testPower() { - val expr = DoubleField - .mstInField { binaryOperationFunction("pow")(bindSymbol("x"), number(2)) } - .compile() + val expr = MstExtendedField { + binaryOperationFunction("pow")(bindSymbol("x"), number(2)) + }.compileToExpression(DoubleField) assertEquals(4.0, expr("x" to 2.0)) } diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmVariables.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmVariables.kt index d1aaefffe..144d63eea 100644 --- a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmVariables.kt +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/asm/TestAsmVariables.kt @@ -1,8 +1,9 @@ package space.kscience.kmath.asm -import space.kscience.kmath.ast.mstInRing +import space.kscience.kmath.ast.MstRing import space.kscience.kmath.expressions.invoke import space.kscience.kmath.operations.ByteRing +import space.kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -10,13 +11,13 @@ import kotlin.test.assertFailsWith internal class TestAsmVariables { @Test fun testVariable() { - val expr = ByteRing.mstInRing { bindSymbol("x") }.compile() + val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing) assertEquals(1.toByte(), expr("x" to 1.toByte())) } @Test fun testUndefinedVariableFails() { - val expr = ByteRing.mstInRing { bindSymbol("x") }.compile() + val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing) assertFailsWith { expr() } } } diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/ParserTest.kt index 3d5449043..74f5e7e10 100644 --- a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/ParserTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/ParserTest.kt @@ -2,9 +2,9 @@ package space.kscience.kmath.ast import space.kscience.kmath.complex.Complex import space.kscience.kmath.complex.ComplexField -import space.kscience.kmath.expressions.invoke import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals @@ -18,7 +18,7 @@ internal class ParserTest { @Test fun `evaluate MSTExpression`() { - val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }() + val res = MstField.invoke { number(2) + number(2) * (number(2) + number(2)) }.interpret(ComplexField) assertEquals(Complex(10.0, 0.0), res) } diff --git a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt index 1275b0c90..d5b55e031 100644 --- a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt +++ b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt @@ -3,8 +3,9 @@ package space.kscience.kmath.kotlingrad import edu.umontreal.kotlingrad.api.SFun import space.kscience.kmath.ast.MST import space.kscience.kmath.ast.MstAlgebra -import space.kscience.kmath.ast.MstExpression +import space.kscience.kmath.ast.interpret import space.kscience.kmath.expressions.DifferentiableExpression +import space.kscience.kmath.expressions.Expression import space.kscience.kmath.misc.Symbol import space.kscience.kmath.operations.NumericAlgebra @@ -18,38 +19,26 @@ import space.kscience.kmath.operations.NumericAlgebra * @param A the [NumericAlgebra] of [T]. * @property expr the underlying [MstExpression]. */ -public inline class DifferentiableMstExpression( - public val expr: MstExpression, -) : DifferentiableExpression> where A : NumericAlgebra { +public class DifferentiableMstExpression>( + public val algebra: A, + public val mst: MST, +) : DifferentiableExpression> { - public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst)) + public override fun invoke(arguments: Map): T = mst.interpret(algebra, arguments) - /** - * The [MstExpression.algebra] of [expr]. - */ - public val algebra: A - get() = expr.algebra - - /** - * The [MstExpression.mst] of [expr]. - */ - public val mst: MST - get() = expr.mst - - public override fun invoke(arguments: Map): T = expr(arguments) - - public override fun derivativeOrNull(symbols: List): MstExpression = MstExpression( - algebra, - symbols.map(Symbol::identity) - .map(MstAlgebra::bindSymbol) - .map { it.toSVar>() } - .fold(mst.toSFun(), SFun>::d) - .toMst(), - ) + public override fun derivativeOrNull(symbols: List): DifferentiableMstExpression = + DifferentiableMstExpression( + algebra, + symbols.map(Symbol::identity) + .map(MstAlgebra::bindSymbol) + .map { it.toSVar>() } + .fold(mst.toSFun(), SFun>::d) + .toMst(), + ) } /** - * Wraps this [MstExpression] into [DifferentiableMstExpression]. + * Wraps this [MST] into [DifferentiableMstExpression]. */ -public fun > MstExpression.differentiable(): DifferentiableMstExpression = - DifferentiableMstExpression(this) +public fun > MST.toDiffExpression(algebra: A): DifferentiableMstExpression = + DifferentiableMstExpression(algebra, this) diff --git a/kmath-kotlingrad/src/test/kotlin/space/kscience/kmath/kotlingrad/AdaptingTests.kt b/kmath-kotlingrad/src/test/kotlin/space/kscience/kmath/kotlingrad/AdaptingTests.kt index 7cd3276b8..c4c25d789 100644 --- a/kmath-kotlingrad/src/test/kotlin/space/kscience/kmath/kotlingrad/AdaptingTests.kt +++ b/kmath-kotlingrad/src/test/kotlin/space/kscience/kmath/kotlingrad/AdaptingTests.kt @@ -1,9 +1,8 @@ package space.kscience.kmath.kotlingrad import edu.umontreal.kotlingrad.api.* -import space.kscience.kmath.asm.compile +import space.kscience.kmath.asm.compileToExpression import space.kscience.kmath.ast.MstAlgebra -import space.kscience.kmath.ast.MstExpression import space.kscience.kmath.ast.parseMath import space.kscience.kmath.expressions.invoke import space.kscience.kmath.operations.DoubleField @@ -43,8 +42,8 @@ internal class AdaptingTests { fun simpleFunctionDerivative() { val x = MstAlgebra.bindSymbol("x").toSVar>() val quadratic = "x^2-4*x-44".parseMath().toSFun>() - val actualDerivative = MstExpression(DoubleField, quadratic.d(x).toMst()).compile() - val expectedDerivative = MstExpression(DoubleField, "2*x-4".parseMath()).compile() + val actualDerivative = quadratic.d(x).toMst().compileToExpression(DoubleField) + val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField) assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0)) } @@ -52,12 +51,11 @@ internal class AdaptingTests { fun moreComplexDerivative() { val x = MstAlgebra.bindSymbol("x").toSVar>() val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun>() - val actualDerivative = MstExpression(DoubleField, composition.d(x).toMst()).compile() + val actualDerivative = composition.d(x).toMst().compileToExpression(DoubleField) + + val expectedDerivative = + "-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath().compileToExpression(DoubleField) - val expectedDerivative = MstExpression( - DoubleField, - "-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath() - ).compile() assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1)) } From a91d468b743c9a6df90bbd2cc3865aefd4241992 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 1 Apr 2021 21:27:30 +0300 Subject: [PATCH 4/9] Refactor Algebra and ExpressionAlgebra. Introduce bindSymbolOrNull on the top level --- .../ExpressionsInterpretersBenchmark.kt | 14 +++-- .../kmath/commons/fit/fitWithAutoDiff.kt | 4 +- .../kotlin/space/kscience/kmath/ast/MST.kt | 8 +-- .../space/kscience/kmath/ast/MstAlgebra.kt | 11 ++-- .../space/kscisnce/kmath/ast/InterpretTest.kt | 22 ++++++++ .../space/kscience/kmath/ast/ParserTest.kt | 2 +- .../DerivativeStructureExpression.kt | 10 ++-- .../DerivativeStructureExpressionTest.kt | 2 +- .../commons/optimization/OptimizeTest.kt | 5 +- .../space/kscience/kmath/complex/Complex.kt | 3 +- .../kscience/kmath/complex/Quaternion.kt | 4 +- .../complex/ExpressionFieldForComplexTest.kt | 2 +- kmath-core/api/kmath-core.api | 56 ++++++++++++++++--- .../kscience/kmath/expressions/Expression.kt | 17 +----- .../FunctionalExpressionAlgebra.kt | 9 ++- .../kmath/expressions/SimpleAutoDiff.kt | 2 +- .../kscience/kmath/operations/Algebra.kt | 14 ++++- .../kmath/expressions/SimpleAutoDiffTest.kt | 1 + kmath-viktor/api/kmath-viktor.api | 2 + 19 files changed, 123 insertions(+), 65 deletions(-) create mode 100644 kmath-ast/src/commonTest/kotlin/space/kscisnce/kmath/ast/InterpretTest.kt diff --git a/examples/src/benchmarks/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt b/examples/src/benchmarks/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt index 2438e3979..ad2a57597 100644 --- a/examples/src/benchmarks/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt +++ b/examples/src/benchmarks/kotlin/space/kscience/kmath/benchmarks/ExpressionsInterpretersBenchmark.kt @@ -4,14 +4,16 @@ import kotlinx.benchmark.Benchmark import kotlinx.benchmark.Blackhole import kotlinx.benchmark.Scope import kotlinx.benchmark.State -import space.kscience.kmath.asm.compile -import space.kscience.kmath.ast.mstInField +import space.kscience.kmath.asm.compileToExpression +import space.kscience.kmath.ast.MstField +import space.kscience.kmath.ast.toExpression import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.expressionInField import space.kscience.kmath.expressions.invoke import space.kscience.kmath.misc.symbol import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.bindSymbol +import space.kscience.kmath.operations.invoke import kotlin.random.Random @State(Scope.Benchmark) @@ -28,20 +30,20 @@ internal class ExpressionsInterpretersBenchmark { @Benchmark fun mstExpression(blackhole: Blackhole) { - val expr = algebra.mstInField { + val expr = MstField { val x = bindSymbol(x) x * 2.0 + number(2.0) / x - 16.0 - } + }.toExpression(algebra) invokeAndSum(expr, blackhole) } @Benchmark fun asmExpression(blackhole: Blackhole) { - val expr = algebra.mstInField { + val expr = MstField { val x = bindSymbol(x) x * 2.0 + number(2.0) / x - 16.0 - }.compile() + }.compileToExpression(algebra) invokeAndSum(expr, blackhole) } diff --git a/examples/src/main/kotlin/space/kscience/kmath/commons/fit/fitWithAutoDiff.kt b/examples/src/main/kotlin/space/kscience/kmath/commons/fit/fitWithAutoDiff.kt index 02534ac98..813310680 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/commons/fit/fitWithAutoDiff.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/commons/fit/fitWithAutoDiff.kt @@ -62,8 +62,8 @@ suspend fun main() { // compute differentiable chi^2 sum for given model ax^2 + bx + c val chi2 = FunctionOptimization.chiSquared(x, y, yErr) { x1 -> //bind variables to autodiff context - val a = bind(a) - val b = bind(b) + val a = bindSymbol(a) + val b = bindSymbol(b) //Include default value for c if it is not provided as a parameter val c = bindSymbolOrNull(c) ?: one a * x1.pow(2) + b * x1 + c diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt index b8c2aadf7..4c37b09f4 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt @@ -58,7 +58,7 @@ public fun Algebra.evaluate(node: MST): T = when (node) { is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) ?: error("Numeric nodes are not supported by $this") - is MST.Symbolic -> bindSymbol(node.value) + is MST.Symbolic -> bindSymbol(node.value) ?: error("Symbol '${node.value}' is not supported in $this") is MST.Unary -> when { this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value)) @@ -80,11 +80,7 @@ public fun Algebra.evaluate(node: MST): T = when (node) { } internal class InnerAlgebra(val algebra: Algebra, val arguments: Map) : NumericAlgebra { - override fun bindSymbol(value: String): T = try { - algebra.bindSymbol(value) - } catch (ignored: IllegalStateException) { - null - } ?: arguments.getValue(StringSymbol(value)) + override fun bindSymbolOrNull(value: String): T? = algebra.bindSymbolOrNull(value) ?: arguments[StringSymbol(value)] override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstAlgebra.kt index c1aeae90e..edac0f9bd 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstAlgebra.kt @@ -8,7 +8,8 @@ import space.kscience.kmath.operations.* */ public object MstAlgebra : NumericAlgebra { public override fun number(value: Number): MST.Numeric = MST.Numeric(value) - public override fun bindSymbol(value: String): MST.Symbolic = MST.Symbolic(value) + public override fun bindSymbolOrNull(value: String): MST.Symbolic = MST.Symbolic(value) + override fun bindSymbol(value: String): MST.Symbolic = bindSymbolOrNull(value) public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = { arg -> MST.Unary(operation, arg) } @@ -24,7 +25,7 @@ public object MstGroup : Group, NumericAlgebra, ScaleOperations { public override val zero: MST.Numeric = number(0.0) public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) - public override fun bindSymbol(value: String): MST.Symbolic = MstAlgebra.bindSymbol(value) + public override fun bindSymbolOrNull(value: String): MST.Symbolic = MstAlgebra.bindSymbolOrNull(value) public override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOperations.PLUS_OPERATION)(a, b) public override operator fun MST.unaryPlus(): MST.Unary = unaryOperationFunction(GroupOperations.PLUS_OPERATION)(this) @@ -54,7 +55,7 @@ public object MstRing : Ring, NumbersAddOperations, ScaleOperations, NumbersAddOperations, ScaleOperations< public override val one: MST.Numeric get() = MstRing.one - public override fun bindSymbol(value: String): MST.Symbolic = MstAlgebra.bindSymbol(value) + public override fun bindSymbolOrNull(value: String): MST.Symbolic = MstAlgebra.bindSymbolOrNull(value) public override fun number(value: Number): MST.Numeric = MstRing.number(value) public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) @@ -112,7 +113,7 @@ public object MstExtendedField : ExtendedField, NumericAlgebra { public override val zero: MST.Numeric get() = MstField.zero public override val one: MST.Numeric get() = MstField.one - public override fun bindSymbol(value: String): MST.Symbolic = MstAlgebra.bindSymbol(value) + public override fun bindSymbolOrNull(value: String): MST.Symbolic = MstAlgebra.bindSymbolOrNull(value) public override fun number(value: Number): MST.Numeric = MstRing.number(value) public override fun sin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg) public override fun cos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg) diff --git a/kmath-ast/src/commonTest/kotlin/space/kscisnce/kmath/ast/InterpretTest.kt b/kmath-ast/src/commonTest/kotlin/space/kscisnce/kmath/ast/InterpretTest.kt new file mode 100644 index 000000000..1b8ec1490 --- /dev/null +++ b/kmath-ast/src/commonTest/kotlin/space/kscisnce/kmath/ast/InterpretTest.kt @@ -0,0 +1,22 @@ +package space.kscisnce.kmath.ast + +import space.kscience.kmath.ast.MstField +import space.kscience.kmath.ast.toExpression +import space.kscience.kmath.expressions.invoke +import space.kscience.kmath.misc.Symbol.Companion.x +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.bindSymbol +import space.kscience.kmath.operations.invoke +import kotlin.test.Test + +class InterpretTest { + + @Test + fun interpretation(){ + val expr = MstField { + val x = bindSymbol(x) + x * 2.0 + number(2.0) / x - 16.0 + }.toExpression(DoubleField) + expr(x to 2.2) + } +} \ No newline at end of file diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/ParserTest.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/ParserTest.kt index 74f5e7e10..2b83e566e 100644 --- a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/ParserTest.kt +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/ParserTest.kt @@ -40,7 +40,7 @@ internal class ParserTest { @Test fun `evaluate MST with binary function`() { val magicalAlgebra = object : Algebra { - override fun bindSymbol(value: String): String = value + override fun bindSymbolOrNull(value: String): String = value override fun unaryOperationFunction(operation: String): (arg: String) -> String { throw NotImplementedError() diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 58e9687e5..76f6c6ff5 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -2,7 +2,6 @@ package space.kscience.kmath.commons.expressions import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import space.kscience.kmath.expressions.* -import space.kscience.kmath.misc.StringSymbol import space.kscience.kmath.misc.Symbol import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.operations.ExtendedField @@ -51,11 +50,11 @@ public class DerivativeStructureField( override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value) - public override fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] + override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol? = variables[value] + override fun bindSymbol(value: String): DerivativeStructureSymbol = variables.getValue(value) - public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity) - - override fun bindSymbol(value: String): DerivativeStructureSymbol = bind(StringSymbol(value)) + public fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] + public fun bindSymbol(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity) public fun DerivativeStructure.derivative(symbols: List): Double { require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" } @@ -108,7 +107,6 @@ public class DerivativeStructureField( } } - /** * A constructs that creates a derivative structure with required order on-demand */ diff --git a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt index b19eb5950..ad0c0b7eb 100644 --- a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt +++ b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpressionTest.kt @@ -27,7 +27,7 @@ internal class AutoDiffTest { @Test fun derivativeStructureFieldTest() { diff(2, x to 1.0, y to 1.0) { - val x = bind(x)//by binding() + val x = bindSymbol(x)//by binding() val y = bindSymbol("y") val z = x * (-sin(x * y) + y) + 2.0 println(z.derivative(x)) diff --git a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt index a51c407c2..de22c066b 100644 --- a/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -14,7 +14,8 @@ internal class OptimizeTest { val y by symbol val normal = DerivativeStructureExpression { - exp(-bind(x).pow(2) / 2) + exp(-bind(y).pow(2) / 2) + exp(-bindSymbol(x).pow(2) / 2) + exp(-bindSymbol(y) + .pow(2) / 2) } @Test @@ -58,7 +59,7 @@ internal class OptimizeTest { val chi2 = FunctionOptimization.chiSquared(x, y, yErr) { x1 -> val cWithDefault = bindSymbolOrNull(c) ?: one - bind(a) * x1.pow(2) + bind(b) * x1 + cWithDefault + bindSymbol(a) * x1.pow(2) + bindSymbol(b) * x1 + cWithDefault } val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0) diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt index a73fb0201..e98b41b9b 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt @@ -165,8 +165,7 @@ public object ComplexField : ExtendedField, Norm, Num public override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg) - public override fun bindSymbol(value: String): Complex = - if (value == "i") i else super.bindSymbol(value) + public override fun bindSymbolOrNull(value: String): Complex? = if (value == "i") i else null } /** diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt index 9a0346ca7..a8189dfe8 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt @@ -165,11 +165,11 @@ public object QuaternionField : Field, Norm, public override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z) public override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg) - public override fun bindSymbol(value: String): Quaternion = when (value) { + public override fun bindSymbolOrNull(value: String): Quaternion? = when (value) { "i" -> i "j" -> j "k" -> k - else -> super.bindSymbol(value) + else -> null } override fun number(value: Number): Quaternion = value.toQuaternion() diff --git a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ExpressionFieldForComplexTest.kt b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ExpressionFieldForComplexTest.kt index 3837b0d40..c08e73800 100644 --- a/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ExpressionFieldForComplexTest.kt +++ b/kmath-complex/src/commonTest/kotlin/space/kscience/kmath/complex/ExpressionFieldForComplexTest.kt @@ -1,9 +1,9 @@ package space.kscience.kmath.complex import space.kscience.kmath.expressions.FunctionalExpressionField -import space.kscience.kmath.expressions.bindSymbol import space.kscience.kmath.expressions.invoke import space.kscience.kmath.misc.symbol +import space.kscience.kmath.operations.bindSymbol import kotlin.test.Test import kotlin.test.assertEquals diff --git a/kmath-core/api/kmath-core.api b/kmath-core/api/kmath-core.api index e6f4697aa..f4724a50e 100644 --- a/kmath-core/api/kmath-core.api +++ b/kmath-core/api/kmath-core.api @@ -50,8 +50,6 @@ public abstract interface class space/kscience/kmath/expressions/Expression { } public abstract interface class space/kscience/kmath/expressions/ExpressionAlgebra : space/kscience/kmath/operations/Algebra { - public abstract fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; - public abstract fun bindSymbolOrNull (Lspace/kscience/kmath/misc/Symbol;)Ljava/lang/Object; public abstract fun const (Ljava/lang/Object;)Ljava/lang/Object; } @@ -59,6 +57,7 @@ public final class space/kscience/kmath/expressions/ExpressionAlgebra$DefaultImp public static fun binaryOperation (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Ljava/lang/String;)Ljava/lang/Object; public static fun unaryOperation (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object; public static fun unaryOperationFunction (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function1; } @@ -71,7 +70,6 @@ public final class space/kscience/kmath/expressions/ExpressionBuildersKt { } public final class space/kscience/kmath/expressions/ExpressionKt { - public static final fun bindSymbol (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Lspace/kscience/kmath/misc/Symbol;)Ljava/lang/Object; public static final fun binding (Lspace/kscience/kmath/expressions/ExpressionAlgebra;)Lkotlin/properties/ReadOnlyProperty; public static final fun callByString (Lspace/kscience/kmath/expressions/Expression;[Lkotlin/Pair;)Ljava/lang/Object; public static final fun callBySymbol (Lspace/kscience/kmath/expressions/Expression;[Lkotlin/Pair;)Ljava/lang/Object; @@ -91,8 +89,8 @@ public abstract class space/kscience/kmath/expressions/FunctionalExpressionAlgeb public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/expressions/Expression; - public synthetic fun bindSymbolOrNull (Lspace/kscience/kmath/misc/Symbol;)Ljava/lang/Object; - public fun bindSymbolOrNull (Lspace/kscience/kmath/misc/Symbol;)Lspace/kscience/kmath/expressions/Expression; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/expressions/Expression; public synthetic fun const (Ljava/lang/Object;)Ljava/lang/Object; public fun const (Ljava/lang/Object;)Lspace/kscience/kmath/expressions/Expression; public final fun getAlgebra ()Lspace/kscience/kmath/operations/Algebra; @@ -278,8 +276,8 @@ public class space/kscience/kmath/expressions/SimpleAutoDiffField : space/kscien public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/expressions/AutoDiffValue; - public synthetic fun bindSymbolOrNull (Lspace/kscience/kmath/misc/Symbol;)Ljava/lang/Object; - public fun bindSymbolOrNull (Lspace/kscience/kmath/misc/Symbol;)Lspace/kscience/kmath/expressions/AutoDiffValue; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/expressions/AutoDiffValue; public synthetic fun const (Ljava/lang/Object;)Ljava/lang/Object; public fun const (Ljava/lang/Object;)Lspace/kscience/kmath/expressions/AutoDiffValue; public final fun const (Lkotlin/jvm/functions/Function1;)Lspace/kscience/kmath/expressions/AutoDiffValue; @@ -743,6 +741,8 @@ public class space/kscience/kmath/nd/BufferedGroupND : space/kscience/kmath/nd/B public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; public fun combine (Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;Lkotlin/jvm/functions/Function3;)Lspace/kscience/kmath/nd/BufferND; public synthetic fun combine (Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;Lkotlin/jvm/functions/Function3;)Lspace/kscience/kmath/nd/StructureND; public fun getBuffer (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/structures/Buffer; @@ -890,6 +890,7 @@ public final class space/kscience/kmath/nd/FieldND$DefaultImpls { public static fun binaryOperation (Lspace/kscience/kmath/nd/FieldND;Ljava/lang/String;Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; public static fun binaryOperationFunction (Lspace/kscience/kmath/nd/FieldND;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/nd/FieldND;Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/nd/FieldND;Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; public static fun div (Lspace/kscience/kmath/nd/FieldND;Ljava/lang/Object;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; public static fun div (Lspace/kscience/kmath/nd/FieldND;Lspace/kscience/kmath/nd/StructureND;Ljava/lang/Number;)Lspace/kscience/kmath/nd/StructureND; public static fun div (Lspace/kscience/kmath/nd/FieldND;Lspace/kscience/kmath/nd/StructureND;Ljava/lang/Object;)Lspace/kscience/kmath/nd/StructureND; @@ -935,6 +936,7 @@ public final class space/kscience/kmath/nd/GroupND$DefaultImpls { public static fun binaryOperation (Lspace/kscience/kmath/nd/GroupND;Ljava/lang/String;Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; public static fun binaryOperationFunction (Lspace/kscience/kmath/nd/GroupND;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/nd/GroupND;Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/nd/GroupND;Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; public static fun invoke (Lspace/kscience/kmath/nd/GroupND;Lkotlin/jvm/functions/Function1;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; public static fun minus (Lspace/kscience/kmath/nd/GroupND;Ljava/lang/Object;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; public static fun minus (Lspace/kscience/kmath/nd/GroupND;Lspace/kscience/kmath/nd/StructureND;Ljava/lang/Object;)Lspace/kscience/kmath/nd/StructureND; @@ -970,6 +972,7 @@ public final class space/kscience/kmath/nd/RingND$DefaultImpls { public static fun binaryOperation (Lspace/kscience/kmath/nd/RingND;Ljava/lang/String;Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; public static fun binaryOperationFunction (Lspace/kscience/kmath/nd/RingND;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/nd/RingND;Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/nd/RingND;Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; public static fun invoke (Lspace/kscience/kmath/nd/RingND;Lkotlin/jvm/functions/Function1;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; public static fun minus (Lspace/kscience/kmath/nd/RingND;Ljava/lang/Object;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; public static fun minus (Lspace/kscience/kmath/nd/RingND;Lspace/kscience/kmath/nd/StructureND;Ljava/lang/Object;)Lspace/kscience/kmath/nd/StructureND; @@ -1118,6 +1121,7 @@ public abstract interface class space/kscience/kmath/operations/Algebra { public abstract fun binaryOperation (Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public abstract fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public abstract fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; + public abstract fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; public abstract fun unaryOperation (Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object; public abstract fun unaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function1; } @@ -1126,6 +1130,7 @@ public final class space/kscience/kmath/operations/Algebra$DefaultImpls { public static fun binaryOperation (Lspace/kscience/kmath/operations/Algebra;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/Algebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/Algebra;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/Algebra;Ljava/lang/String;)Ljava/lang/Object; public static fun unaryOperation (Lspace/kscience/kmath/operations/Algebra;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object; public static fun unaryOperationFunction (Lspace/kscience/kmath/operations/Algebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function1; } @@ -1156,6 +1161,7 @@ public final class space/kscience/kmath/operations/AlgebraExtensionsKt { public final class space/kscience/kmath/operations/AlgebraKt { public static final fun bindSymbol (Lspace/kscience/kmath/operations/Algebra;Lspace/kscience/kmath/misc/Symbol;)Ljava/lang/Object; + public static final fun bindSymbolOrNull (Lspace/kscience/kmath/operations/Algebra;Lspace/kscience/kmath/misc/Symbol;)Ljava/lang/Object; public static final fun invoke (Lspace/kscience/kmath/operations/Algebra;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; } @@ -1201,6 +1207,8 @@ public final class space/kscience/kmath/operations/BigIntField : space/kscience/ public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/operations/BigInt; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/operations/BigInt; public synthetic fun div (Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object; public synthetic fun div (Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public fun div (Lspace/kscience/kmath/operations/BigInt;Ljava/lang/Number;)Lspace/kscience/kmath/operations/BigInt; @@ -1274,6 +1282,8 @@ public final class space/kscience/kmath/operations/ByteRing : space/kscience/kma public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public fun bindSymbol (Ljava/lang/String;)Ljava/lang/Byte; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Byte; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; public fun getOne ()Ljava/lang/Byte; public synthetic fun getOne ()Ljava/lang/Object; public fun getZero ()Ljava/lang/Byte; @@ -1326,6 +1336,8 @@ public final class space/kscience/kmath/operations/DoubleField : space/kscience/ public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public fun bindSymbol (Ljava/lang/String;)Ljava/lang/Double; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Double; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; public fun cos (D)Ljava/lang/Double; public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object; public fun cosh (D)Ljava/lang/Double; @@ -1426,6 +1438,7 @@ public final class space/kscience/kmath/operations/ExponentialOperations$Default public static fun binaryOperation (Lspace/kscience/kmath/operations/ExponentialOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/ExponentialOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/ExponentialOperations;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/ExponentialOperations;Ljava/lang/String;)Ljava/lang/Object; public static fun unaryOperation (Lspace/kscience/kmath/operations/ExponentialOperations;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object; public static fun unaryOperationFunction (Lspace/kscience/kmath/operations/ExponentialOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function1; } @@ -1447,6 +1460,7 @@ public final class space/kscience/kmath/operations/ExtendedField$DefaultImpls { public static fun binaryOperation (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/String;)Ljava/lang/Object; public static fun cosh (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/Object;)Ljava/lang/Object; public static fun div (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object; public static fun div (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; @@ -1480,6 +1494,7 @@ public final class space/kscience/kmath/operations/ExtendedFieldOperations$Defau public static fun binaryOperation (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/String;)Ljava/lang/Object; public static fun div (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun minus (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun plus (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; @@ -1501,6 +1516,7 @@ public final class space/kscience/kmath/operations/Field$DefaultImpls { public static fun binaryOperation (Lspace/kscience/kmath/operations/Field;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/Field;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/Field;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/Field;Ljava/lang/String;)Ljava/lang/Object; public static fun div (Lspace/kscience/kmath/operations/Field;Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object; public static fun div (Lspace/kscience/kmath/operations/Field;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun leftSideNumberOperation (Lspace/kscience/kmath/operations/Field;Ljava/lang/String;Ljava/lang/Number;Ljava/lang/Object;)Ljava/lang/Object; @@ -1534,6 +1550,7 @@ public final class space/kscience/kmath/operations/FieldOperations$DefaultImpls public static fun binaryOperation (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/String;)Ljava/lang/Object; public static fun div (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun minus (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun plus (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; @@ -1564,6 +1581,8 @@ public final class space/kscience/kmath/operations/FloatField : space/kscience/k public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public fun bindSymbol (Ljava/lang/String;)Ljava/lang/Float; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Float; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; public fun cos (F)Ljava/lang/Float; public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object; public fun cosh (F)Ljava/lang/Float; @@ -1637,6 +1656,7 @@ public final class space/kscience/kmath/operations/Group$DefaultImpls { public static fun binaryOperation (Lspace/kscience/kmath/operations/Group;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/Group;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/Group;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/Group;Ljava/lang/String;)Ljava/lang/Object; public static fun minus (Lspace/kscience/kmath/operations/Group;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun plus (Lspace/kscience/kmath/operations/Group;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun unaryOperation (Lspace/kscience/kmath/operations/Group;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object; @@ -1666,6 +1686,7 @@ public final class space/kscience/kmath/operations/GroupOperations$DefaultImpls public static fun binaryOperation (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/String;)Ljava/lang/Object; public static fun minus (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun plus (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun unaryOperation (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object; @@ -1682,6 +1703,8 @@ public final class space/kscience/kmath/operations/IntRing : space/kscience/kmat public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public fun bindSymbol (Ljava/lang/String;)Ljava/lang/Integer; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Integer; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; public fun getOne ()Ljava/lang/Integer; public synthetic fun getOne ()Ljava/lang/Object; public fun getZero ()Ljava/lang/Integer; @@ -1732,6 +1755,8 @@ public abstract class space/kscience/kmath/operations/JBigDecimalFieldBase : spa public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; public fun bindSymbol (Ljava/lang/String;)Ljava/math/BigDecimal; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/math/BigDecimal; public synthetic fun div (Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object; public synthetic fun div (Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public fun div (Ljava/math/BigDecimal;Ljava/lang/Number;)Ljava/math/BigDecimal; @@ -1788,6 +1813,8 @@ public final class space/kscience/kmath/operations/JBigIntegerField : space/ksci public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; public fun bindSymbol (Ljava/lang/String;)Ljava/math/BigInteger; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/math/BigInteger; public synthetic fun getOne ()Ljava/lang/Object; public fun getOne ()Ljava/math/BigInteger; public synthetic fun getZero ()Ljava/lang/Object; @@ -1829,6 +1856,8 @@ public final class space/kscience/kmath/operations/LongRing : space/kscience/kma public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public fun bindSymbol (Ljava/lang/String;)Ljava/lang/Long; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Long; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; public fun getOne ()Ljava/lang/Long; public synthetic fun getOne ()Ljava/lang/Object; public fun getZero ()Ljava/lang/Long; @@ -1868,6 +1897,7 @@ public final class space/kscience/kmath/operations/NumbersAddOperations$DefaultI public static fun binaryOperation (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/String;)Ljava/lang/Object; public static fun leftSideNumberOperation (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/String;Ljava/lang/Number;Ljava/lang/Object;)Ljava/lang/Object; public static fun leftSideNumberOperationFunction (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun minus (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/Number;Ljava/lang/Object;)Ljava/lang/Object; @@ -1895,6 +1925,7 @@ public final class space/kscience/kmath/operations/NumericAlgebra$DefaultImpls { public static fun binaryOperation (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;)Ljava/lang/Object; public static fun leftSideNumberOperation (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;Ljava/lang/Number;Ljava/lang/Object;)Ljava/lang/Object; public static fun leftSideNumberOperationFunction (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun rightSideNumberOperation (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object; @@ -1924,6 +1955,7 @@ public final class space/kscience/kmath/operations/PowerOperations$DefaultImpls public static fun binaryOperation (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/String;)Ljava/lang/Object; public static fun pow (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object; public static fun sqrt (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/Object;)Ljava/lang/Object; public static fun unaryOperation (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object; @@ -1938,6 +1970,7 @@ public final class space/kscience/kmath/operations/Ring$DefaultImpls { public static fun binaryOperation (Lspace/kscience/kmath/operations/Ring;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/Ring;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/Ring;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/Ring;Ljava/lang/String;)Ljava/lang/Object; public static fun minus (Lspace/kscience/kmath/operations/Ring;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun plus (Lspace/kscience/kmath/operations/Ring;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun times (Lspace/kscience/kmath/operations/Ring;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; @@ -1962,6 +1995,7 @@ public final class space/kscience/kmath/operations/RingOperations$DefaultImpls { public static fun binaryOperation (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/String;)Ljava/lang/Object; public static fun minus (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun plus (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun times (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; @@ -1981,6 +2015,7 @@ public final class space/kscience/kmath/operations/ScaleOperations$DefaultImpls public static fun binaryOperation (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/String;)Ljava/lang/Object; public static fun div (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object; public static fun times (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/Number;Ljava/lang/Object;)Ljava/lang/Object; public static fun times (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object; @@ -1997,6 +2032,8 @@ public final class space/kscience/kmath/operations/ShortRing : space/kscience/km public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; public fun bindSymbol (Ljava/lang/String;)Ljava/lang/Short; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Short; public synthetic fun getOne ()Ljava/lang/Object; public fun getOne ()Ljava/lang/Short; public synthetic fun getZero ()Ljava/lang/Object; @@ -2057,6 +2094,7 @@ public final class space/kscience/kmath/operations/TrigonometricOperations$Defau public static fun binaryOperation (Lspace/kscience/kmath/operations/TrigonometricOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/TrigonometricOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public static fun bindSymbol (Lspace/kscience/kmath/operations/TrigonometricOperations;Ljava/lang/String;)Ljava/lang/Object; + public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/TrigonometricOperations;Ljava/lang/String;)Ljava/lang/Object; public static fun unaryOperation (Lspace/kscience/kmath/operations/TrigonometricOperations;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object; public static fun unaryOperationFunction (Lspace/kscience/kmath/operations/TrigonometricOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function1; } @@ -2147,6 +2185,8 @@ public final class space/kscience/kmath/structures/DoubleBufferField : space/ksc public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/structures/Buffer; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/structures/Buffer; public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object; public fun cos-Udx-57Q (Lspace/kscience/kmath/structures/Buffer;)[D public synthetic fun cosh (Ljava/lang/Object;)Ljava/lang/Object; @@ -2232,6 +2272,8 @@ public final class space/kscience/kmath/structures/DoubleBufferFieldOperations : public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/structures/Buffer; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/structures/Buffer; public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object; public fun cos-Udx-57Q (Lspace/kscience/kmath/structures/Buffer;)[D public synthetic fun cosh (Ljava/lang/Object;)Ljava/lang/Object; diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt index 7918f199e..fc49a0fae 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt @@ -55,15 +55,6 @@ public operator fun Expression.invoke(vararg pairs: Pair): T = * @param E type of the actual expression state */ public interface ExpressionAlgebra : Algebra { - /** - * Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context. - */ - public fun bindSymbolOrNull(symbol: Symbol): E? - - /** - * Bind a string to a context using [StringSymbol] - */ - override fun bindSymbol(value: String): E = bindSymbol(StringSymbol(value)) /** * A constant expression which does not depend on arguments @@ -71,15 +62,9 @@ public interface ExpressionAlgebra : Algebra { public fun const(value: T): E } -/** - * Bind a given [Symbol] to this context variable and produce context-specific object. - */ -public fun ExpressionAlgebra.bindSymbol(symbol: Symbol): E = - bindSymbolOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this") - /** * Bind a symbol by name inside the [ExpressionAlgebra] */ public fun ExpressionAlgebra.binding(): ReadOnlyProperty = ReadOnlyProperty { _, property -> - bindSymbol(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist") + bindSymbol(property.name) ?: error("A variable with name ${property.name} does not exist") } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index ebd9e7f22..775a49aad 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -1,6 +1,6 @@ package space.kscience.kmath.expressions -import space.kscience.kmath.misc.Symbol +import space.kscience.kmath.misc.StringSymbol import space.kscience.kmath.operations.* /** @@ -19,8 +19,8 @@ public abstract class FunctionalExpressionAlgebra>( /** * Builds an Expression to access a variable. */ - public override fun bindSymbolOrNull(symbol: Symbol): Expression? = Expression { arguments -> - arguments[symbol] ?: error("Argument not found: $symbol") + override fun bindSymbolOrNull(value: String): Expression? = Expression { arguments -> + arguments[StringSymbol(value)] ?: error("Argument not found: $value") } /** @@ -101,8 +101,7 @@ public open class FunctionalExpressionRing>( public open class FunctionalExpressionField>( algebra: A, -) : FunctionalExpressionRing(algebra), Field>, - ScaleOperations> { +) : FunctionalExpressionRing(algebra), Field>, ScaleOperations> { /** * Builds an Expression of division an expression by another one. */ diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt index d9be4a92e..d3b65107d 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -85,7 +85,7 @@ public open class SimpleAutoDiffField>( override fun hashCode(): Int = identity.hashCode() } - public override fun bindSymbolOrNull(symbol: Symbol): AutoDiffValue? = bindings[symbol.identity] + override fun bindSymbolOrNull(value: String): AutoDiffValue? = bindings[value] private fun getDerivative(variable: AutoDiffValue): T = (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt index 492ec8e88..78ada6f5c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt @@ -23,10 +23,18 @@ public interface Algebra { * * In case if algebra can't parse the string, this method must throw [kotlin.IllegalStateException]. * + * Returns `null` if symbol could not be bound to the context + * * @param value the raw string. * @return an object. */ - public fun bindSymbol(value: String): T = error("Wrapping of '$value' is not supported in $this") + public fun bindSymbolOrNull(value: String): T? = null + + /** + * The same as [bindSymbolOrNull] but throws an error if symbol could not be bound + */ + public fun bindSymbol(value: String): T = + bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this") /** * Dynamically dispatches an unary operation with the certain name. @@ -91,7 +99,9 @@ public interface Algebra { binaryOperationFunction(operation)(left, right) } -public fun Algebra.bindSymbol(symbol: Symbol): T = bindSymbol(symbol.identity) +public fun Algebra.bindSymbolOrNull(symbol: Symbol): T? = bindSymbolOrNull(symbol.identity) + +public fun Algebra.bindSymbol(symbol: Symbol): T = bindSymbol(symbol.identity) /** * Call a block with an [Algebra] as receiver. diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/SimpleAutoDiffTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/SimpleAutoDiffTest.kt index 666db13d8..0cac510d0 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/SimpleAutoDiffTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/expressions/SimpleAutoDiffTest.kt @@ -3,6 +3,7 @@ package space.kscience.kmath.expressions import space.kscience.kmath.misc.Symbol import space.kscience.kmath.misc.symbol import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.asBuffer import kotlin.math.E diff --git a/kmath-viktor/api/kmath-viktor.api b/kmath-viktor/api/kmath-viktor.api index 0b9ea1b48..e209c863c 100644 --- a/kmath-viktor/api/kmath-viktor.api +++ b/kmath-viktor/api/kmath-viktor.api @@ -46,6 +46,8 @@ public final class space/kscience/kmath/viktor/ViktorFieldND : space/kscience/km public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; public synthetic fun combine (Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;Lkotlin/jvm/functions/Function3;)Lspace/kscience/kmath/nd/StructureND; public fun combine-WKhNzhk (Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;Lkotlin/jvm/functions/Function3;)Lorg/jetbrains/bio/viktor/F64Array; public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object; From e6921025d11b20d8bd18d29242035ae4dbdb8455 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Fri, 2 Apr 2021 16:46:12 +0700 Subject: [PATCH 5/9] Remove redundant try-catch expressions --- .../src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt | 2 +- .../src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt | 6 +----- .../src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt | 6 +----- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt index 4c37b09f4..538db0caa 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MST.kt @@ -58,7 +58,7 @@ public fun Algebra.evaluate(node: MST): T = when (node) { is MST.Numeric -> (this as? NumericAlgebra)?.number(node.value) ?: error("Numeric nodes are not supported by $this") - is MST.Symbolic -> bindSymbol(node.value) ?: error("Symbol '${node.value}' is not supported in $this") + is MST.Symbolic -> bindSymbol(node.value) is MST.Unary -> when { this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value)) diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt index 93b2d54c8..796ffce1e 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt @@ -14,11 +14,7 @@ import space.kscience.kmath.operations.NumericAlgebra internal fun MST.compileWith(algebra: Algebra): Expression { fun ESTreeBuilder.visit(node: MST): BaseExpression = when (node) { is Symbolic -> { - val symbol = try { - algebra.bindSymbol(node.value) - } catch (ignored: IllegalStateException) { - null - } + val symbol = algebra.bindSymbolOrNull(node.value) if (symbol != null) constant(symbol) diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt index 5324d74a1..ee2b6fb54 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt @@ -22,11 +22,7 @@ import space.kscience.kmath.operations.NumericAlgebra internal fun MST.compileWith(type: Class, algebra: Algebra): Expression { fun AsmBuilder.visit(node: MST): Unit = when (node) { is Symbolic -> { - val symbol = try { - algebra.bindSymbol(node.value) - } catch (ignored: IllegalStateException) { - null - } + val symbol = algebra.bindSymbolOrNull(node.value) if (symbol != null) loadObjectConstant(symbol as Any) From f7e792faffe6d08c09d4d11f39be4265bb431256 Mon Sep 17 00:00:00 2001 From: darksnake Date: Fri, 2 Apr 2021 19:09:35 +0300 Subject: [PATCH 6/9] Add test for grid iteration. --- .../commonTest/kotlin/kaceince/kmath/real/GridTest.kt | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/GridTest.kt b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/GridTest.kt index 91ee517ab..a7c4d30e2 100644 --- a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/GridTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/GridTest.kt @@ -15,4 +15,13 @@ class GridTest { assertEquals(6, grid.size) assertTrue { (grid - DoubleVector(0.0, 0.2, 0.4, 0.6, 0.8, 1.0)).norm < 1e-4 } } + + @Test + fun testIterateGrid(){ + var res = 0.0 + for(d in 0.0..1.0 step 0.2){ + res = d + } + assertEquals(1.0, res) + } } \ No newline at end of file From cf91da1a988d6bbc507a9c86a463b43f1ecf6af2 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Wed, 31 Mar 2021 21:51:54 +0700 Subject: [PATCH 7/9] Add pi and e constants, some unrelated changes --- .../kmath/structures/StreamDoubleFieldND.kt | 12 ++--- .../space/kscience/kmath/ast/MstAlgebra.kt | 14 +++--- .../ast/rendering/LatexSyntaxRenderer.kt | 1 + .../ast/rendering/MathMLSyntaxRenderer.kt | 1 + .../kmath/ast/rendering/MathRenderer.kt | 5 +- .../kmath/ast/rendering/MathSyntax.kt | 5 ++ .../kscience/kmath/ast/rendering/features.kt | 19 +++++++ .../kmath/ast/rendering/TestFeatures.kt | 5 ++ .../kscience/kmath/ast/rendering/TestLatex.kt | 5 +- .../kmath/ast/rendering/TestMathML.kt | 5 +- .../DerivativeStructureExpression.kt | 18 +++---- .../space/kscience/kmath/complex/Complex.kt | 8 +-- .../kscience/kmath/complex/ComplexFieldND.kt | 36 +++++++------- kmath-core/api/kmath-core.api | 19 +++++++ .../FunctionalExpressionAlgebra.kt | 19 +++++-- .../kmath/expressions/SimpleAutoDiff.kt | 6 ++- .../space/kscience/kmath/nd/AlgebraND.kt | 12 ++--- .../space/kscience/kmath/nd/DoubleFieldND.kt | 49 +++++++++---------- .../kscience/kmath/operations/Algebra.kt | 12 ++--- .../kmath/operations/AlgebraElements.kt | 7 +-- .../kmath/operations/NumericAlgebra.kt | 28 +++++++++-- .../kscience/kmath/operations/numbers.kt | 48 +++++++++--------- .../kscience/kmath/stat/StatisticTest.kt | 1 - kmath-viktor/api/kmath-viktor.api | 2 +- .../kmath/viktor/ViktorStructureND.kt | 28 +++++------ 25 files changed, 225 insertions(+), 140 deletions(-) diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt index 6741209fc..162c63df9 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt @@ -1,6 +1,5 @@ package space.kscience.kmath.structures -import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.* import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.ExtendedField @@ -9,12 +8,10 @@ import java.util.* import java.util.stream.IntStream /** - * A demonstration implementation of NDField over Real using Java [DoubleStream] for parallel execution + * A demonstration implementation of NDField over Real using Java [java.util.stream.DoubleStream] for parallel + * execution. */ -@OptIn(UnstableKMathAPI::class) -class StreamDoubleFieldND( - override val shape: IntArray, -) : FieldND, +class StreamDoubleFieldND(override val shape: IntArray) : FieldND, NumbersAddOperations>, ExtendedField> { @@ -38,7 +35,6 @@ class StreamDoubleFieldND( else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) } } - override fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND { val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset -> val index = strides.index(offset) @@ -104,4 +100,4 @@ class StreamDoubleFieldND( override fun atanh(arg: StructureND): BufferND = arg.map { atanh(it) } } -fun AlgebraND.Companion.realWithStream(vararg shape: Int): StreamDoubleFieldND = StreamDoubleFieldND(shape) \ No newline at end of file +fun AlgebraND.Companion.realWithStream(vararg shape: Int): StreamDoubleFieldND = StreamDoubleFieldND(shape) diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstAlgebra.kt index edac0f9bd..33fca7521 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/MstAlgebra.kt @@ -49,9 +49,10 @@ public object MstGroup : Group, NumericAlgebra, ScaleOperations { /** * [Ring] over [MST] nodes. */ +@Suppress("OVERRIDE_BY_INLINE") @OptIn(UnstableKMathAPI::class) public object MstRing : Ring, NumbersAddOperations, ScaleOperations { - public override val zero: MST.Numeric get() = MstGroup.zero + public override inline val zero: MST.Numeric get() = MstGroup.zero public override val one: MST.Numeric = number(1.0) public override fun number(value: Number): MST.Numeric = MstGroup.number(value) @@ -78,11 +79,11 @@ public object MstRing : Ring, NumbersAddOperations, ScaleOperations, NumbersAddOperations, ScaleOperations { - public override val zero: MST.Numeric get() = MstRing.zero - - public override val one: MST.Numeric get() = MstRing.one + public override inline val zero: MST.Numeric get() = MstRing.zero + public override inline val one: MST.Numeric get() = MstRing.one public override fun bindSymbolOrNull(value: String): MST.Symbolic = MstAlgebra.bindSymbolOrNull(value) public override fun number(value: Number): MST.Numeric = MstRing.number(value) @@ -109,9 +110,10 @@ public object MstField : Field, NumbersAddOperations, ScaleOperations< /** * [ExtendedField] over [MST] nodes. */ +@Suppress("OVERRIDE_BY_INLINE") public object MstExtendedField : ExtendedField, NumericAlgebra { - public override val zero: MST.Numeric get() = MstField.zero - public override val one: MST.Numeric get() = MstField.one + public override inline val zero: MST.Numeric get() = MstField.zero + public override inline val one: MST.Numeric get() = MstField.one public override fun bindSymbolOrNull(value: String): MST.Symbolic = MstAlgebra.bindSymbolOrNull(value) public override fun number(value: Number): MST.Numeric = MstRing.number(value) diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/LatexSyntaxRenderer.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/LatexSyntaxRenderer.kt index 914da6d9f..5d40097b6 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/LatexSyntaxRenderer.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/LatexSyntaxRenderer.kt @@ -34,6 +34,7 @@ public object LatexSyntaxRenderer : SyntaxRenderer { is SpecialSymbolSyntax -> when (node.kind) { SpecialSymbolSyntax.Kind.INFINITY -> append("\\infty") + SpecialSymbolSyntax.Kind.SMALL_PI -> append("\\pi") } is OperandSyntax -> { diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathMLSyntaxRenderer.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathMLSyntaxRenderer.kt index 6f194be86..d1d3c82e3 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathMLSyntaxRenderer.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathMLSyntaxRenderer.kt @@ -48,6 +48,7 @@ public object MathMLSyntaxRenderer : SyntaxRenderer { is SpecialSymbolSyntax -> when (node.kind) { SpecialSymbolSyntax.Kind.INFINITY -> tag("mo") { append("∞") } + SpecialSymbolSyntax.Kind.SMALL_PI -> tag("mo") { append("π") } } is OperandSyntax -> if (node.parentheses) { diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt index afdf12b04..14e14404c 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathRenderer.kt @@ -85,9 +85,10 @@ public open class FeaturedMathRendererWithPostProcess( BinaryOperator.Default, UnaryOperator.Default, - // Pretty printing for numerics + // Pretty printing for some objects PrettyPrintFloats.Default, PrettyPrintIntegers.Default, + PrettyPrintPi.Default, // Printing terminal nodes as string PrintNumeric, @@ -96,7 +97,7 @@ public open class FeaturedMathRendererWithPostProcess( listOf( SimplifyParentheses.Default, BetterMultiplication, - ) + ), ) } } diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathSyntax.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathSyntax.kt index 4c85adcfc..febb6e5af 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathSyntax.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/MathSyntax.kt @@ -101,6 +101,11 @@ public data class SpecialSymbolSyntax(public var kind: Kind) : TerminalSyntax() * The infinity (∞) symbol. */ INFINITY, + + /** + * The Pi (π) symbol. + */ + SMALL_PI; } } diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt index 6e66d3ca3..95108ba45 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/rendering/features.kt @@ -118,6 +118,25 @@ public class PrettyPrintIntegers(public val types: Set>) : Re } } +/** + * Special printing for symbols meaning Pi. + * + * @property symbols The allowed symbols. + */ +public class PrettyPrintPi(public val symbols: Set) : RenderFeature { + public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? { + if (node !is MST.Symbolic || node.value !in symbols) return null + return SpecialSymbolSyntax(kind = SpecialSymbolSyntax.Kind.SMALL_PI) + } + + public companion object { + /** + * The default instance containing `pi`. + */ + public val Default: PrettyPrintPi = PrettyPrintPi(setOf("pi")) + } +} + /** * Abstract printing of unary operations which discards [MST] if their operation is not in [operations] or its type is * not [MST.Unary]. diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestFeatures.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestFeatures.kt index b10f7ed4e..5850ea23d 100644 --- a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestFeatures.kt +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestFeatures.kt @@ -45,6 +45,11 @@ internal class TestFeatures { testLatex(Numeric(-42), "-42") } + @Test + fun prettyPrintPi() { + testLatex("pi", "\\pi") + } + @Test fun binaryPlus() = testLatex("2+2", "2+2") diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt index 9c1009042..599bee436 100644 --- a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestLatex.kt @@ -16,7 +16,10 @@ internal class TestLatex { fun operatorName() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)") @Test - fun specialSymbol() = testLatex(MST.Numeric(Double.POSITIVE_INFINITY), "\\infty") + fun specialSymbol() { + testLatex(MST.Numeric(Double.POSITIVE_INFINITY), "\\infty") + testLatex("pi", "\\pi") + } @Test fun operand() { diff --git a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt index c9a462840..6fadef6cd 100644 --- a/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt +++ b/kmath-ast/src/jvmTest/kotlin/space/kscience/kmath/ast/rendering/TestMathML.kt @@ -19,7 +19,10 @@ internal class TestMathML { ) @Test - fun specialSymbol() = testMathML(MST.Numeric(Double.POSITIVE_INFINITY), "") + fun specialSymbol() { + testMathML(MST.Numeric(Double.POSITIVE_INFINITY), "") + testMathML("pi", "π") + } @Test fun operand() { diff --git a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 76f6c6ff5..4f229cabd 100644 --- a/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/space/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -24,7 +24,7 @@ public class DerivativeStructureField( public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) } - override fun number(value: Number): DerivativeStructure = const(value.toDouble()) + public override fun number(value: Number): DerivativeStructure = const(value.toDouble()) /** * A class that implements both [DerivativeStructure] and a [Symbol] @@ -35,10 +35,10 @@ public class DerivativeStructureField( symbol: Symbol, value: Double, ) : DerivativeStructure(size, order, index, value), Symbol { - override val identity: String = symbol.identity - override fun toString(): String = identity - override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity - override fun hashCode(): Int = identity.hashCode() + public override val identity: String = symbol.identity + public override fun toString(): String = identity + public override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity + public override fun hashCode(): Int = identity.hashCode() } /** @@ -48,10 +48,10 @@ public class DerivativeStructureField( key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value) }.toMap() - override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value) + public override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value) - override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol? = variables[value] - override fun bindSymbol(value: String): DerivativeStructureSymbol = variables.getValue(value) + public override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol? = variables[value] + public override fun bindSymbol(value: String): DerivativeStructureSymbol = variables.getValue(value) public fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] public fun bindSymbol(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity) @@ -64,7 +64,7 @@ public class DerivativeStructureField( public fun DerivativeStructure.derivative(vararg symbols: Symbol): Double = derivative(symbols.toList()) - override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate() + public override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate() public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b) diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt index e98b41b9b..aa97c6463 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt @@ -121,8 +121,8 @@ public object ComplexField : ExtendedField, Norm, Num /** * Adds complex number to real one. * - * @receiver the addend. - * @param c the augend. + * @receiver the augend. + * @param c the addend. * @return the sum. */ public operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c) @@ -139,8 +139,8 @@ public object ComplexField : ExtendedField, Norm, Num /** * Adds real number to complex one. * - * @receiver the addend. - * @param d the augend. + * @receiver the augend. + * @param d the addend. * @return the sum. */ public operator fun Complex.plus(d: Double): Complex = d + this diff --git a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt index d11f2b7db..701b77df1 100644 --- a/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt +++ b/kmath-complex/src/commonMain/kotlin/space/kscience/kmath/complex/ComplexFieldND.kt @@ -22,10 +22,10 @@ public class ComplexFieldND( NumbersAddOperations>, ExtendedField> { - override val zero: BufferND by lazy { produce { zero } } - override val one: BufferND by lazy { produce { one } } + public override val zero: BufferND by lazy { produce { zero } } + public override val one: BufferND by lazy { produce { one } } - override fun number(value: Number): BufferND { + public override fun number(value: Number): BufferND { val d = value.toComplex() // minimize conversions return produce { d } } @@ -76,25 +76,25 @@ public class ComplexFieldND( // return BufferedNDFieldElement(this, buffer) // } - override fun power(arg: StructureND, pow: Number): BufferND = arg.map { power(it, pow) } + public override fun power(arg: StructureND, pow: Number): BufferND = arg.map { power(it, pow) } - override fun exp(arg: StructureND): BufferND = arg.map { exp(it) } + public override fun exp(arg: StructureND): BufferND = arg.map { exp(it) } - override fun ln(arg: StructureND): BufferND = arg.map { ln(it) } + public override fun ln(arg: StructureND): BufferND = arg.map { ln(it) } - override fun sin(arg: StructureND): BufferND = arg.map { sin(it) } - override fun cos(arg: StructureND): BufferND = arg.map { cos(it) } - override fun tan(arg: StructureND): BufferND = arg.map { tan(it) } - override fun asin(arg: StructureND): BufferND = arg.map { asin(it) } - override fun acos(arg: StructureND): BufferND = arg.map { acos(it) } - override fun atan(arg: StructureND): BufferND = arg.map { atan(it) } + public override fun sin(arg: StructureND): BufferND = arg.map { sin(it) } + public override fun cos(arg: StructureND): BufferND = arg.map { cos(it) } + public override fun tan(arg: StructureND): BufferND = arg.map { tan(it) } + public override fun asin(arg: StructureND): BufferND = arg.map { asin(it) } + public override fun acos(arg: StructureND): BufferND = arg.map { acos(it) } + public override fun atan(arg: StructureND): BufferND = arg.map { atan(it) } - override fun sinh(arg: StructureND): BufferND = arg.map { sinh(it) } - override fun cosh(arg: StructureND): BufferND = arg.map { cosh(it) } - override fun tanh(arg: StructureND): BufferND = arg.map { tanh(it) } - override fun asinh(arg: StructureND): BufferND = arg.map { asinh(it) } - override fun acosh(arg: StructureND): BufferND = arg.map { acosh(it) } - override fun atanh(arg: StructureND): BufferND = arg.map { atanh(it) } + public override fun sinh(arg: StructureND): BufferND = arg.map { sinh(it) } + public override fun cosh(arg: StructureND): BufferND = arg.map { cosh(it) } + public override fun tanh(arg: StructureND): BufferND = arg.map { tanh(it) } + public override fun asinh(arg: StructureND): BufferND = arg.map { asinh(it) } + public override fun acosh(arg: StructureND): BufferND = arg.map { acosh(it) } + public override fun atanh(arg: StructureND): BufferND = arg.map { atanh(it) } } diff --git a/kmath-core/api/kmath-core.api b/kmath-core/api/kmath-core.api index f4724a50e..f76372a3d 100644 --- a/kmath-core/api/kmath-core.api +++ b/kmath-core/api/kmath-core.api @@ -121,6 +121,8 @@ public class space/kscience/kmath/expressions/FunctionalExpressionExtendedField public synthetic fun atanh (Ljava/lang/Object;)Ljava/lang/Object; public fun atanh (Lspace/kscience/kmath/expressions/Expression;)Lspace/kscience/kmath/expressions/Expression; public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; + public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/expressions/Expression; public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object; public fun cos (Lspace/kscience/kmath/expressions/Expression;)Lspace/kscience/kmath/expressions/Expression; public synthetic fun cosh (Ljava/lang/Object;)Ljava/lang/Object; @@ -152,6 +154,8 @@ public class space/kscience/kmath/expressions/FunctionalExpressionExtendedField public class space/kscience/kmath/expressions/FunctionalExpressionField : space/kscience/kmath/expressions/FunctionalExpressionRing, space/kscience/kmath/operations/Field, space/kscience/kmath/operations/ScaleOperations { public fun (Lspace/kscience/kmath/operations/Field;)V public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/expressions/Expression; public synthetic fun div (Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object; public synthetic fun div (Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public final fun div (Ljava/lang/Object;Lspace/kscience/kmath/expressions/Expression;)Lspace/kscience/kmath/expressions/Expression; @@ -235,6 +239,8 @@ public final class space/kscience/kmath/expressions/SimpleAutoDiffExtendedField public fun atan (Lspace/kscience/kmath/expressions/AutoDiffValue;)Lspace/kscience/kmath/expressions/AutoDiffValue; public synthetic fun atanh (Ljava/lang/Object;)Ljava/lang/Object; public fun atanh (Lspace/kscience/kmath/expressions/AutoDiffValue;)Lspace/kscience/kmath/expressions/AutoDiffValue; + public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/expressions/AutoDiffValue; public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object; public fun cos (Lspace/kscience/kmath/expressions/AutoDiffValue;)Lspace/kscience/kmath/expressions/AutoDiffValue; public synthetic fun cosh (Ljava/lang/Object;)Ljava/lang/Object; @@ -708,6 +714,8 @@ public final class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/St public class space/kscience/kmath/nd/BufferedFieldND : space/kscience/kmath/nd/BufferedRingND, space/kscience/kmath/nd/FieldND { public fun ([ILspace/kscience/kmath/operations/Field;Lkotlin/jvm/functions/Function2;)V public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; public synthetic fun div (Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object; public synthetic fun div (Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object; public fun div (Ljava/lang/Object;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; @@ -821,6 +829,8 @@ public final class space/kscience/kmath/nd/DoubleFieldND : space/kscience/kmath/ public fun atan (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/BufferND; public synthetic fun atanh (Ljava/lang/Object;)Ljava/lang/Object; public fun atanh (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/BufferND; + public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; public fun combine (Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;Lkotlin/jvm/functions/Function3;)Lspace/kscience/kmath/nd/BufferND; public synthetic fun combine (Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;Lkotlin/jvm/functions/Function3;)Lspace/kscience/kmath/nd/StructureND; public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object; @@ -997,6 +1007,8 @@ public final class space/kscience/kmath/nd/ShapeMismatchException : java/lang/Ru public final class space/kscience/kmath/nd/ShortRingND : space/kscience/kmath/nd/BufferedRingND, space/kscience/kmath/operations/NumbersAddOperations { public fun ([I)V + public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; + public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND; public synthetic fun getOne ()Ljava/lang/Object; public fun getOne ()Lspace/kscience/kmath/nd/BufferND; public synthetic fun getZero ()Ljava/lang/Object; @@ -1447,6 +1459,7 @@ public abstract interface class space/kscience/kmath/operations/ExtendedField : public abstract fun acosh (Ljava/lang/Object;)Ljava/lang/Object; public abstract fun asinh (Ljava/lang/Object;)Ljava/lang/Object; public abstract fun atanh (Ljava/lang/Object;)Ljava/lang/Object; + public abstract fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object; public abstract fun cosh (Ljava/lang/Object;)Ljava/lang/Object; public abstract fun rightSideNumberOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public abstract fun sinh (Ljava/lang/Object;)Ljava/lang/Object; @@ -1914,6 +1927,7 @@ public final class space/kscience/kmath/operations/NumbersAddOperations$DefaultI } public abstract interface class space/kscience/kmath/operations/NumericAlgebra : space/kscience/kmath/operations/Algebra { + public abstract fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object; public abstract fun leftSideNumberOperation (Ljava/lang/String;Ljava/lang/Number;Ljava/lang/Object;)Ljava/lang/Object; public abstract fun leftSideNumberOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2; public abstract fun number (Ljava/lang/Number;)Ljava/lang/Object; @@ -1934,6 +1948,11 @@ public final class space/kscience/kmath/operations/NumericAlgebra$DefaultImpls { public static fun unaryOperationFunction (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function1; } +public final class space/kscience/kmath/operations/NumericAlgebraKt { + public static final fun getE (Lspace/kscience/kmath/operations/NumericAlgebra;)Ljava/lang/Object; + public static final fun getPi (Lspace/kscience/kmath/operations/NumericAlgebra;)Ljava/lang/Object; +} + public final class space/kscience/kmath/operations/OptionalOperationsKt { } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 775a49aad..9fb8f28c8 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -19,8 +19,10 @@ public abstract class FunctionalExpressionAlgebra>( /** * Builds an Expression to access a variable. */ - override fun bindSymbolOrNull(value: String): Expression? = Expression { arguments -> - arguments[StringSymbol(value)] ?: error("Argument not found: $value") + public override fun bindSymbolOrNull(value: String): Expression? = Expression { arguments -> + algebra.bindSymbolOrNull(value) + ?: arguments[StringSymbol(value)] + ?: error("Symbol '$value' is not supported in $this") } /** @@ -49,7 +51,7 @@ public open class FunctionalExpressionGroup>( ) : FunctionalExpressionAlgebra(algebra), Group> { public override val zero: Expression get() = const(algebra.zero) - override fun Expression.unaryMinus(): Expression = + public override fun Expression.unaryMinus(): Expression = unaryOperation(GroupOperations.MINUS_OPERATION, this) /** @@ -117,16 +119,21 @@ public open class FunctionalExpressionField>( public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = super.binaryOperationFunction(operation) - override fun scale(a: Expression, value: Double): Expression = algebra { + public override fun scale(a: Expression, value: Double): Expression = algebra { Expression { args -> a(args) * value } } + + public override fun bindSymbolOrNull(value: String): Expression? = + super.bindSymbolOrNull(value) } public open class FunctionalExpressionExtendedField>( algebra: A, ) : FunctionalExpressionField(algebra), ExtendedField> { + public override fun number(value: Number): Expression = const(algebra.number(value)) - override fun number(value: Number): Expression = const(algebra.number(value)) + public override fun sqrt(arg: Expression): Expression = + unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg) public override fun sin(arg: Expression): Expression = unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg) @@ -157,6 +164,8 @@ public open class FunctionalExpressionExtendedField>( public override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = super.binaryOperationFunction(operation) + + public override fun bindSymbol(value: String): Expression = super.bindSymbol(value) } public inline fun > A.expressionInSpace(block: FunctionalExpressionGroup.() -> Expression): Expression = diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt index d3b65107d..a832daa14 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -337,9 +337,11 @@ public class SimpleAutoDiffExtendedField>( ) : ExtendedField>, ScaleOperations>, SimpleAutoDiffField(context, bindings) { - override fun number(value: Number): AutoDiffValue = const { number(value) } + override fun bindSymbol(value: String): AutoDiffValue = super.bindSymbol(value) - override fun scale(a: AutoDiffValue, value: Double): AutoDiffValue = a * number(value) + public override fun number(value: Number): AutoDiffValue = const { number(value) } + + public override fun scale(a: AutoDiffValue, value: Double): AutoDiffValue = a * number(value) // x ^ 2 public fun sqr(x: AutoDiffValue): AutoDiffValue = diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt index 2821a6648..b5aa56bd3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt @@ -120,8 +120,8 @@ public interface GroupND> : Group>, AlgebraND, b: StructureND): StructureND = @@ -141,8 +141,8 @@ public interface GroupND> : Group>, AlgebraND.plus(arg: T): StructureND = this.map { value -> add(arg, value) } @@ -159,8 +159,8 @@ public interface GroupND> : Group>, AlgebraND): StructureND = arg.map { value -> add(this@plus, value) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt index d38ed02da..40d16cd91 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt @@ -17,15 +17,15 @@ public class DoubleFieldND( ScaleOperations>, ExtendedField> { - override val zero: BufferND by lazy { produce { zero } } - override val one: BufferND by lazy { produce { one } } + public override val zero: BufferND by lazy { produce { zero } } + public override val one: BufferND by lazy { produce { one } } - override fun number(value: Number): BufferND { + public override fun number(value: Number): BufferND { val d = value.toDouble() // minimize conversions return produce { d } } - override val StructureND.buffer: DoubleBuffer + public override val StructureND.buffer: DoubleBuffer get() = when { !shape.contentEquals(this@DoubleFieldND.shape) -> throw ShapeMismatchException( this@DoubleFieldND.shape, @@ -36,7 +36,7 @@ public class DoubleFieldND( } @Suppress("OVERRIDE_BY_INLINE") - override inline fun StructureND.map( + public override inline fun StructureND.map( transform: DoubleField.(Double) -> Double, ): BufferND { val buffer = DoubleBuffer(strides.linearSize) { offset -> DoubleField.transform(buffer.array[offset]) } @@ -44,7 +44,7 @@ public class DoubleFieldND( } @Suppress("OVERRIDE_BY_INLINE") - override inline fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND { + public override inline fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND { val array = DoubleArray(strides.linearSize) { offset -> val index = strides.index(offset) DoubleField.initializer(index) @@ -53,7 +53,7 @@ public class DoubleFieldND( } @Suppress("OVERRIDE_BY_INLINE") - override inline fun StructureND.mapIndexed( + public override inline fun StructureND.mapIndexed( transform: DoubleField.(index: IntArray, Double) -> Double, ): BufferND = BufferND( strides, @@ -65,7 +65,7 @@ public class DoubleFieldND( }) @Suppress("OVERRIDE_BY_INLINE") - override inline fun combine( + public override inline fun combine( a: StructureND, b: StructureND, transform: DoubleField.(Double, Double) -> Double, @@ -76,27 +76,26 @@ public class DoubleFieldND( return BufferND(strides, buffer) } - override fun scale(a: StructureND, value: Double): StructureND = a.map { it * value } + public override fun scale(a: StructureND, value: Double): StructureND = a.map { it * value } - override fun power(arg: StructureND, pow: Number): BufferND = arg.map { power(it, pow) } + public override fun power(arg: StructureND, pow: Number): BufferND = arg.map { power(it, pow) } - override fun exp(arg: StructureND): BufferND = arg.map { exp(it) } + public override fun exp(arg: StructureND): BufferND = arg.map { exp(it) } + public override fun ln(arg: StructureND): BufferND = arg.map { ln(it) } - override fun ln(arg: StructureND): BufferND = arg.map { ln(it) } + public override fun sin(arg: StructureND): BufferND = arg.map { sin(it) } + public override fun cos(arg: StructureND): BufferND = arg.map { cos(it) } + public override fun tan(arg: StructureND): BufferND = arg.map { tan(it) } + public override fun asin(arg: StructureND): BufferND = arg.map { asin(it) } + public override fun acos(arg: StructureND): BufferND = arg.map { acos(it) } + public override fun atan(arg: StructureND): BufferND = arg.map { atan(it) } - override fun sin(arg: StructureND): BufferND = arg.map { sin(it) } - override fun cos(arg: StructureND): BufferND = arg.map { cos(it) } - override fun tan(arg: StructureND): BufferND = arg.map { tan(it) } - override fun asin(arg: StructureND): BufferND = arg.map { asin(it) } - override fun acos(arg: StructureND): BufferND = arg.map { acos(it) } - override fun atan(arg: StructureND): BufferND = arg.map { atan(it) } - - override fun sinh(arg: StructureND): BufferND = arg.map { sinh(it) } - override fun cosh(arg: StructureND): BufferND = arg.map { cosh(it) } - override fun tanh(arg: StructureND): BufferND = arg.map { tanh(it) } - override fun asinh(arg: StructureND): BufferND = arg.map { asinh(it) } - override fun acosh(arg: StructureND): BufferND = arg.map { acosh(it) } - override fun atanh(arg: StructureND): BufferND = arg.map { atanh(it) } + public override fun sinh(arg: StructureND): BufferND = arg.map { sinh(it) } + public override fun cosh(arg: StructureND): BufferND = arg.map { cosh(it) } + public override fun tanh(arg: StructureND): BufferND = arg.map { tanh(it) } + public override fun asinh(arg: StructureND): BufferND = arg.map { asinh(it) } + public override fun acosh(arg: StructureND): BufferND = arg.map { acosh(it) } + public override fun atanh(arg: StructureND): BufferND = arg.map { atanh(it) } } public fun AlgebraND.Companion.real(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt index 78ada6f5c..1b84b2c63 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/Algebra.kt @@ -119,8 +119,8 @@ public interface GroupOperations : Algebra { /** * Addition of two elements. * - * @param a the addend. - * @param b the augend. + * @param a the augend. + * @param b the addend. * @return the sum. */ public fun add(a: T, b: T): T @@ -146,8 +146,8 @@ public interface GroupOperations : Algebra { /** * Addition of two elements. * - * @receiver the addend. - * @param b the augend. + * @receiver the augend. + * @param b the addend. * @return the sum. */ public operator fun T.plus(b: T): T = add(this, b) @@ -293,5 +293,5 @@ public interface FieldOperations : RingOperations { * @param T the type of element of this field. */ public interface Field : Ring, FieldOperations, ScaleOperations, NumericAlgebra { - override fun number(value: Number): T = scale(one, value.toDouble()) -} \ No newline at end of file + public override fun number(value: Number): T = scale(one, value.toDouble()) +} diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/AlgebraElements.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/AlgebraElements.kt index b2b5911df..c0380a197 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/AlgebraElements.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/AlgebraElements.kt @@ -46,7 +46,8 @@ public operator fun , S : NumbersAddOperations> T.mi /** * Adds element to this one. * - * @param b the augend. + * @receiver the augend. + * @param b the addend. * @return the sum. */ public operator fun , S : Group> T.plus(b: T): T = @@ -58,11 +59,11 @@ public operator fun , S : Group> T.plus(b: T): T = //public operator fun , S : Space> Number.times(element: T): T = // element.times(this) - /** * Multiplies this element by another one. * - * @param b the multiplicand. + * @receiver the multiplicand. + * @param b the multiplier. * @return the product. */ public operator fun , R : Ring> T.times(b: T): T = diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt index bd5f5951f..84d4f8064 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/NumericAlgebra.kt @@ -1,6 +1,8 @@ package space.kscience.kmath.operations import space.kscience.kmath.misc.UnstableKMathAPI +import kotlin.math.E +import kotlin.math.PI /** * An algebraic structure where elements can have numeric representation. @@ -79,8 +81,26 @@ public interface NumericAlgebra : Algebra { */ public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = rightSideNumberOperationFunction(operation)(left, right) + + public override fun bindSymbolOrNull(value: String): T? = when (value) { + "pi" -> number(PI) + "e" -> number(E) + else -> super.bindSymbolOrNull(value) + } } +/** + * The π mathematical constant. + */ +public val NumericAlgebra.pi: T + get() = bindSymbolOrNull("pi") ?: number(PI) + +/** + * The *e* mathematical constant. + */ +public val NumericAlgebra.e: T + get() = number(E) + /** * Scale by scalar operations */ @@ -131,16 +151,16 @@ public interface NumbersAddOperations : Group, NumericAlgebra { /** * Addition of element and scalar. * - * @receiver the addend. - * @param b the augend. + * @receiver the augend. + * @param b the addend. */ public operator fun T.plus(b: Number): T = this + number(b) /** * Addition of scalar and element. * - * @receiver the addend. - * @param b the augend. + * @receiver the augend. + * @param b the addend. */ public operator fun Number.plus(b: T): T = b + this diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt index 0101b058a..37257f0cf 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt @@ -44,6 +44,12 @@ public interface ExtendedField : ExtendedFieldOperations, Field, Numeri public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2.0 + public override fun bindSymbol(value: String): T = when (value) { + "pi" -> pi + "e" -> e + else -> super.bindSymbol(value) + } + public override fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T = when (operation) { PowerOperations.POW_OPERATION -> ::power @@ -56,10 +62,10 @@ public interface ExtendedField : ExtendedFieldOperations, Field, Numeri */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object DoubleField : ExtendedField, Norm, ScaleOperations { - public override val zero: Double = 0.0 - public override val one: Double = 1.0 + public override inline val zero: Double get() = 0.0 + public override inline val one: Double get() = 1.0 - override fun number(value: Number): Double = value.toDouble() + public override fun number(value: Number): Double = value.toDouble() public override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = when (operation) { @@ -68,13 +74,11 @@ public object DoubleField : ExtendedField, Norm, ScaleOp } public override inline fun add(a: Double, b: Double): Double = a + b -// public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() -// override fun divide(a: Double, k: Number): Double = a / k.toDouble() public override inline fun multiply(a: Double, b: Double): Double = a * b public override inline fun divide(a: Double, b: Double): Double = a / b - override fun scale(a: Double, value: Double): Double = a * value + public override fun scale(a: Double, value: Double): Double = a * value public override inline fun sin(arg: Double): Double = kotlin.math.sin(arg) public override inline fun cos(arg: Double): Double = kotlin.math.cos(arg) @@ -108,10 +112,10 @@ public object DoubleField : ExtendedField, Norm, ScaleOp */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object FloatField : ExtendedField, Norm { - public override val zero: Float = 0.0f - public override val one: Float = 1.0f + public override inline val zero: Float get() = 0.0f + public override inline val one: Float get() = 1.0f - override fun number(value: Number): Float = value.toFloat() + public override fun number(value: Number): Float = value.toFloat() public override fun binaryOperationFunction(operation: String): (left: Float, right: Float) -> Float = when (operation) { @@ -120,7 +124,7 @@ public object FloatField : ExtendedField, Norm { } public override inline fun add(a: Float, b: Float): Float = a + b - override fun scale(a: Float, value: Double): Float = a * value.toFloat() + public override fun scale(a: Float, value: Double): Float = a * value.toFloat() public override inline fun multiply(a: Float, b: Float): Float = a * b @@ -158,13 +162,13 @@ public object FloatField : ExtendedField, Norm { */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object IntRing : Ring, Norm, NumericAlgebra { - public override val zero: Int + public override inline val zero: Int get() = 0 - public override val one: Int + public override inline val one: Int get() = 1 - override fun number(value: Number): Int = value.toInt() + public override fun number(value: Number): Int = value.toInt() public override inline fun add(a: Int, b: Int): Int = a + b public override inline fun multiply(a: Int, b: Int): Int = a * b public override inline fun norm(arg: Int): Int = abs(arg) @@ -180,13 +184,13 @@ public object IntRing : Ring, Norm, NumericAlgebra { */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object ShortRing : Ring, Norm, NumericAlgebra { - public override val zero: Short + public override inline val zero: Short get() = 0 - public override val one: Short + public override inline val one: Short get() = 1 - override fun number(value: Number): Short = value.toShort() + public override fun number(value: Number): Short = value.toShort() public override inline fun add(a: Short, b: Short): Short = (a + b).toShort() public override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort() public override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() @@ -202,13 +206,13 @@ public object ShortRing : Ring, Norm, NumericAlgebra */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object ByteRing : Ring, Norm, NumericAlgebra { - public override val zero: Byte + public override inline val zero: Byte get() = 0 - public override val one: Byte + public override inline val one: Byte get() = 1 - override fun number(value: Number): Byte = value.toByte() + public override fun number(value: Number): Byte = value.toByte() public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() public override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() public override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() @@ -224,13 +228,13 @@ public object ByteRing : Ring, Norm, NumericAlgebra { */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object LongRing : Ring, Norm, NumericAlgebra { - public override val zero: Long + public override inline val zero: Long get() = 0L - public override val one: Long + public override inline val one: Long get() = 1L - override fun number(value: Number): Long = value.toLong() + public override fun number(value: Number): Long = value.toLong() public override inline fun add(a: Long, b: Long): Long = a + b public override inline fun multiply(a: Long, b: Long): Long = a * b public override fun norm(arg: Long): Long = abs(arg) diff --git a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/StatisticTest.kt b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/StatisticTest.kt index 3c9d6a2e4..908c5775b 100644 --- a/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/StatisticTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/space/kscience/kmath/stat/StatisticTest.kt @@ -3,7 +3,6 @@ package space.kscience.kmath.stat import kotlinx.coroutines.flow.drop import kotlinx.coroutines.flow.first import kotlinx.coroutines.runBlocking - import space.kscience.kmath.streaming.chunked import kotlin.test.Test diff --git a/kmath-viktor/api/kmath-viktor.api b/kmath-viktor/api/kmath-viktor.api index e209c863c..0e4eac77e 100644 --- a/kmath-viktor/api/kmath-viktor.api +++ b/kmath-viktor/api/kmath-viktor.api @@ -126,7 +126,7 @@ public final class space/kscience/kmath/viktor/ViktorFieldND : space/kscience/km public synthetic fun sqrt (Ljava/lang/Object;)Ljava/lang/Object; public fun sqrt (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; public synthetic fun tan (Ljava/lang/Object;)Ljava/lang/Object; - public fun tan (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; + public fun tan-8UOKELU (Lspace/kscience/kmath/nd/StructureND;)Lorg/jetbrains/bio/viktor/F64Array; public synthetic fun tanh (Ljava/lang/Object;)Ljava/lang/Object; public fun tanh (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; public fun times (DLspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND; diff --git a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt index 420bcac90..49cd3ebd9 100644 --- a/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt +++ b/kmath-viktor/src/main/kotlin/space/kscience/kmath/viktor/ViktorStructureND.kt @@ -41,7 +41,6 @@ public class ViktorFieldND(public override val shape: IntArray) : FieldND.unaryMinus(): StructureND = -1 * this + public override fun StructureND.unaryMinus(): StructureND = -1 * this public override fun StructureND.map(transform: DoubleField.(Double) -> Double): ViktorStructureND = F64Array(*this@ViktorFieldND.shape).apply { @@ -100,24 +99,21 @@ public class ViktorFieldND(public override val shape: IntArray) : FieldND.plus(arg: Double): ViktorStructureND = (f64Buffer.plus(arg)).asStructure() - override fun number(value: Number): ViktorStructureND = + public override fun number(value: Number): ViktorStructureND = F64Array.full(init = value.toDouble(), shape = shape).asStructure() - override fun sin(arg: StructureND): ViktorStructureND = arg.map { sin(it) } + public override fun sin(arg: StructureND): ViktorStructureND = arg.map { sin(it) } + public override fun cos(arg: StructureND): ViktorStructureND = arg.map { cos(it) } + public override fun tan(arg: StructureND): ViktorStructureND = arg.map { tan(it) } + public override fun asin(arg: StructureND): ViktorStructureND = arg.map { asin(it) } + public override fun acos(arg: StructureND): ViktorStructureND = arg.map { acos(it) } + public override fun atan(arg: StructureND): ViktorStructureND = arg.map { atan(it) } - override fun cos(arg: StructureND): ViktorStructureND = arg.map { cos(it) } + public override fun power(arg: StructureND, pow: Number): ViktorStructureND = arg.map { it.pow(pow) } - override fun asin(arg: StructureND): ViktorStructureND = arg.map { asin(it) } + public override fun exp(arg: StructureND): ViktorStructureND = arg.f64Buffer.exp().asStructure() - override fun acos(arg: StructureND): ViktorStructureND = arg.map { acos(it) } - - override fun atan(arg: StructureND): ViktorStructureND = arg.map { atan(it) } - - override fun power(arg: StructureND, pow: Number): ViktorStructureND = arg.map { it.pow(pow) } - - override fun exp(arg: StructureND): ViktorStructureND = arg.f64Buffer.exp().asStructure() - - override fun ln(arg: StructureND): ViktorStructureND = arg.f64Buffer.log().asStructure() + public override fun ln(arg: StructureND): ViktorStructureND = arg.f64Buffer.log().asStructure() } -public fun ViktorNDField(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape) \ No newline at end of file +public fun ViktorNDField(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape) From 45301d9172c5b02792288f2b50d1496be22f4c5a Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 6 Apr 2021 11:15:47 +0300 Subject: [PATCH 8/9] Update build.yml Add timeout to build --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 25f2cfd0d..f39e12a12 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -8,6 +8,7 @@ jobs: matrix: os: [ macOS-latest, windows-latest ] runs-on: ${{matrix.os}} + timeout-minutes: 30 steps: - name: Checkout the repo uses: actions/checkout@v2 From 5bdc02d18cae34ea05e4931a4714ddb63f2950d5 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 6 Apr 2021 17:17:43 +0300 Subject: [PATCH 9/9] fix for #272 --- CHANGELOG.md | 1 + build.gradle.kts | 2 +- .../kscience/kmath/kotlingrad/DifferentiableMstExpression.kt | 3 +-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c4d3b93e9..fdace591c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Field extends ScaleOperations - Basic integration API - Basic MPP distributions and samplers +- bindSymbolOrNull ### Changed - Exponential operations merged with hyperbolic functions diff --git a/build.gradle.kts b/build.gradle.kts index cc863a957..59e93e67f 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -18,7 +18,7 @@ allprojects { } group = "space.kscience" - version = "0.3.0-dev-4" + version = "0.3.0-dev-5" } subprojects { diff --git a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt index d5b55e031..ab3547cda 100644 --- a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt +++ b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/DifferentiableMstExpression.kt @@ -5,7 +5,6 @@ import space.kscience.kmath.ast.MST import space.kscience.kmath.ast.MstAlgebra import space.kscience.kmath.ast.interpret import space.kscience.kmath.expressions.DifferentiableExpression -import space.kscience.kmath.expressions.Expression import space.kscience.kmath.misc.Symbol import space.kscience.kmath.operations.NumericAlgebra @@ -22,7 +21,7 @@ import space.kscience.kmath.operations.NumericAlgebra public class DifferentiableMstExpression>( public val algebra: A, public val mst: MST, -) : DifferentiableExpression> { +) : DifferentiableExpression> { public override fun invoke(arguments: Map): T = mst.interpret(algebra, arguments)