Merge branch 'dev' into feature/tensor-algebra

This commit is contained in:
Roland Grinis 2021-04-07 12:05:56 +01:00
commit 383376080e
112 changed files with 3154 additions and 1398 deletions

View File

@ -8,6 +8,7 @@ jobs:
matrix:
os: [ macOS-latest, windows-latest ]
runs-on: ${{matrix.os}}
timeout-minutes: 30
steps:
- name: Checkout the repo
uses: actions/checkout@v2

View File

@ -5,6 +5,8 @@
- ScaleOperations interface
- Field extends ScaleOperations
- Basic integration API
- Basic MPP distributions and samplers
- bindSymbolOrNull
### Changed
- Exponential operations merged with hyperbolic functions
@ -14,6 +16,8 @@
- NDStructure and NDAlgebra to StructureND and AlgebraND respectively
- Real -> Double
- DataSets are moved from functions to core
- Redesign advanced Chain API
- Redesign MST. Remove MSTExpression.
### Deprecated
@ -21,6 +25,7 @@
- Nearest in Domain. To be implemented in geometry package.
- Number multiplication and division in main Algebra chain
- `contentEquals` from Buffer. It moved to the companion.
- MSTExpression
### Fixed

View File

@ -259,8 +259,8 @@ repositories {
}
dependencies {
api("space.kscience:kmath-core:0.3.0-dev-3")
// api("kscience.kmath:kmath-core-jvm:0.3.0-dev-3") for jvm-specific version
api("space.kscience:kmath-core:0.3.0-dev-4")
// api("kscience.kmath:kmath-core-jvm:0.3.0-dev-4") for jvm-specific version
}
```

View File

@ -18,7 +18,7 @@ allprojects {
}
group = "space.kscience"
version = "0.3.0-dev-4"
version = "0.3.0-dev-5"
}
subprojects {

View File

@ -4,14 +4,16 @@ import kotlinx.benchmark.Benchmark
import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope
import kotlinx.benchmark.State
import space.kscience.kmath.asm.compile
import space.kscience.kmath.ast.mstInField
import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.ast.MstField
import space.kscience.kmath.ast.toExpression
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.expressionInField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.random.Random
@State(Scope.Benchmark)
@ -28,20 +30,20 @@ internal class ExpressionsInterpretersBenchmark {
@Benchmark
fun mstExpression(blackhole: Blackhole) {
val expr = algebra.mstInField {
val expr = MstField {
val x = bindSymbol(x)
x * 2.0 + number(2.0) / x - 16.0
}
}.toExpression(algebra)
invokeAndSum(expr, blackhole)
}
@Benchmark
fun asmExpression(blackhole: Blackhole) {
val expr = algebra.mstInField {
val expr = MstField {
val x = bindSymbol(x)
x * 2.0 + number(2.0) / x - 16.0
}.compile()
}.compileToExpression(algebra)
invokeAndSum(expr, blackhole)
}

View File

@ -0,0 +1,21 @@
package space.kscience.kmath.ast
import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess
import space.kscience.kmath.ast.rendering.LatexSyntaxRenderer
import space.kscience.kmath.ast.rendering.MathMLSyntaxRenderer
import space.kscience.kmath.ast.rendering.renderWithStringBuilder
public fun main() {
val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath()
val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst)
println("MathSyntax:")
println(syntax)
println()
val latex = LatexSyntaxRenderer.renderWithStringBuilder(syntax)
println("LaTeX:")
println(latex)
println()
val mathML = MathMLSyntaxRenderer.renderWithStringBuilder(syntax)
println("MathML:")
println(mathML)
}

View File

@ -1,15 +1,17 @@
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
fun main() {
val expr = DoubleField.mstInField {
val x = bindSymbol("x")
val expr = MstField {
val x = bindSymbol(x)
x * 2.0 + number(2.0) / x - 16.0
}
repeat(10000000) {
expr.invoke("x" to 1.0)
expr.interpret(DoubleField, x to 1.0)
}
}

View File

@ -1,9 +1,9 @@
package space.kscience.kmath.ast
import space.kscience.kmath.asm.compile
import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.derivative
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.kotlingrad.differentiable
import space.kscience.kmath.kotlingrad.toDiffExpression
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
@ -14,11 +14,11 @@ import space.kscience.kmath.operations.DoubleField
fun main() {
val x by symbol
val actualDerivative = MstExpression(DoubleField, "x^2-4*x-44".parseMath())
.differentiable()
val actualDerivative = "x^2-4*x-44".parseMath()
.toDiffExpression(DoubleField)
.derivative(x)
.compile()
val expectedDerivative = MstExpression(DoubleField, "2*x-4".parseMath()).compile()
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
}

View File

@ -7,6 +7,7 @@ import kscience.plotly.models.ScatterMode
import kscience.plotly.models.TraceValues
import space.kscience.kmath.commons.optimization.chiSquared
import space.kscience.kmath.commons.optimization.minimize
import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.optimization.FunctionOptimization
import space.kscience.kmath.optimization.OptimizationResult
@ -14,7 +15,6 @@ import space.kscience.kmath.real.DoubleVector
import space.kscience.kmath.real.map
import space.kscience.kmath.real.step
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.distributions.NormalDistribution
import space.kscience.kmath.structures.asIterable
import space.kscience.kmath.structures.toList
import kotlin.math.pow
@ -62,8 +62,8 @@ suspend fun main() {
// compute differentiable chi^2 sum for given model ax^2 + bx + c
val chi2 = FunctionOptimization.chiSquared(x, y, yErr) { x1 ->
//bind variables to autodiff context
val a = bind(a)
val b = bind(b)
val a = bindSymbol(a)
val b = bindSymbol(b)
//Include default value for c if it is not provided as a parameter
val c = bindSymbolOrNull(c) ?: one
a * x1.pow(2) + b * x1 + c

View File

@ -3,8 +3,8 @@ package space.kscience.kmath.stat
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import space.kscience.kmath.stat.samplers.GaussianSampler
import org.apache.commons.rng.simple.RandomSource
import space.kscience.kmath.samplers.GaussianSampler
import java.time.Duration
import java.time.Instant
import org.apache.commons.rng.sampling.distribution.GaussianSampler as CMGaussianSampler
@ -12,8 +12,8 @@ import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSa
private suspend fun runKMathChained(): Duration {
val generator = RandomGenerator.fromSource(RandomSource.MT, 123L)
val normal = GaussianSampler.of(7.0, 2.0)
val chain = normal.sample(generator).blocking()
val normal = GaussianSampler(7.0, 2.0)
val chain = normal.sample(generator)
val startTime = Instant.now()
var sum = 0.0

View File

@ -3,7 +3,7 @@ package space.kscience.kmath.stat
import kotlinx.coroutines.runBlocking
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.chains.collectWithState
import space.kscience.kmath.stat.distributions.NormalDistribution
import space.kscience.kmath.distributions.NormalDistribution
/**
* The state of distribution averager.

View File

@ -1,6 +1,5 @@
package space.kscience.kmath.structures
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedField
@ -9,12 +8,10 @@ import java.util.*
import java.util.stream.IntStream
/**
* A demonstration implementation of NDField over Real using Java [DoubleStream] for parallel execution
* A demonstration implementation of NDField over Real using Java [java.util.stream.DoubleStream] for parallel
* execution.
*/
@OptIn(UnstableKMathAPI::class)
class StreamDoubleFieldND(
override val shape: IntArray,
) : FieldND<Double, DoubleField>,
class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
NumbersAddOperations<StructureND<Double>>,
ExtendedField<StructureND<Double>> {
@ -38,7 +35,6 @@ class StreamDoubleFieldND(
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
}
override fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
val index = strides.index(offset)
@ -104,4 +100,4 @@ class StreamDoubleFieldND(
override fun atanh(arg: StructureND<Double>): BufferND<Double> = arg.map { atanh(it) }
}
fun AlgebraND.Companion.realWithStream(vararg shape: Int): StreamDoubleFieldND = StreamDoubleFieldND(shape)
fun AlgebraND.Companion.realWithStream(vararg shape: Int): StreamDoubleFieldND = StreamDoubleFieldND(shape)

View File

@ -12,7 +12,7 @@ Abstract syntax tree expression representation and related optimizations.
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-3`.
The Maven coordinates of this project are `space.kscience:kmath-ast:0.3.0-dev-4`.
**Gradle:**
```gradle
@ -23,7 +23,7 @@ repositories {
}
dependencies {
implementation 'space.kscience:kmath-ast:0.3.0-dev-3'
implementation 'space.kscience:kmath-ast:0.3.0-dev-4'
}
```
**Gradle Kotlin DSL:**
@ -35,7 +35,7 @@ repositories {
}
dependencies {
implementation("space.kscience:kmath-ast:0.3.0-dev-3")
implementation("space.kscience:kmath-ast:0.3.0-dev-4")
}
```
@ -111,3 +111,39 @@ var executable = function (constants, arguments) {
#### Known issues
- This feature uses `eval` which can be unavailable in several environments.
## Rendering expressions
kmath-ast also includes an extensible engine to display expressions in LaTeX or MathML syntax.
Example usage:
```kotlin
import space.kscience.kmath.ast.*
import space.kscience.kmath.ast.rendering.*
public fun main() {
val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath()
val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst)
val latex = LatexSyntaxRenderer.renderWithStringBuilder(syntax)
println("LaTeX:")
println(latex)
println()
val mathML = MathMLSyntaxRenderer.renderWithStringBuilder(syntax)
println("MathML:")
println(mathML)
}
```
Result LaTeX:
![](http://chart.googleapis.com/chart?cht=tx&chl=e%5E%7B%5Csqrt%7Bx%7D%7D-%5Cfrac%7B%5Cfrac%7B%5Coperatorname%7Bsin%7D%5E%7B-1%7D%5C,%5Cleft(2%5C,x%5Cright)%7D%7B2%5Ctimes10%5E%7B10%7D%2Bx%5E%7B3%7D%7D%7D%7B-12%7D)
Result MathML (embedding MathML is not allowed by GitHub Markdown):
```html
<mrow><msup><mrow><mi>e</mi></mrow><mrow><msqrt><mi>x</mi></msqrt></mrow></msup><mo>-</mo><mfrac><mrow><mfrac><mrow><msup><mrow><mo>sin</mo></mrow><mrow><mo>-</mo><mn>1</mn></mrow></msup><mspace width="0.167em"></mspace><mfenced open="(" close=")" separators=""><mn>2</mn><mspace width="0.167em"></mspace><mi>x</mi></mfenced></mrow><mrow><mn>2</mn><mo>&times;</mo><msup><mrow><mn>10</mn></mrow><mrow><mn>10</mn></mrow></msup><mo>+</mo><msup><mrow><mi>x</mi></mrow><mrow><mn>3</mn></mrow></msup></mrow></mfrac></mrow><mrow><mo>-</mo><mn>12</mn></mrow></mfrac></mrow>
```
It is also possible to create custom algorithms of render, and even add support of other markup languages
(see API reference).

View File

@ -78,3 +78,39 @@ var executable = function (constants, arguments) {
#### Known issues
- This feature uses `eval` which can be unavailable in several environments.
## Rendering expressions
kmath-ast also includes an extensible engine to display expressions in LaTeX or MathML syntax.
Example usage:
```kotlin
import space.kscience.kmath.ast.*
import space.kscience.kmath.ast.rendering.*
public fun main() {
val mst = "exp(sqrt(x))-asin(2*x)/(2e10+x^3)/(-12)".parseMath()
val syntax = FeaturedMathRendererWithPostProcess.Default.render(mst)
val latex = LatexSyntaxRenderer.renderWithStringBuilder(syntax)
println("LaTeX:")
println(latex)
println()
val mathML = MathMLSyntaxRenderer.renderWithStringBuilder(syntax)
println("MathML:")
println(mathML)
}
```
Result LaTeX:
![](http://chart.googleapis.com/chart?cht=tx&chl=e%5E%7B%5Csqrt%7Bx%7D%7D-%5Cfrac%7B%5Cfrac%7B%5Coperatorname%7Bsin%7D%5E%7B-1%7D%5C,%5Cleft(2%5C,x%5Cright)%7D%7B2%5Ctimes10%5E%7B10%7D%2Bx%5E%7B3%7D%7D%7D%7B-12%7D)
Result MathML (embedding MathML is not allowed by GitHub Markdown):
```html
<mrow><msup><mrow><mi>e</mi></mrow><mrow><msqrt><mi>x</mi></msqrt></mrow></msup><mo>-</mo><mfrac><mrow><mfrac><mrow><msup><mrow><mo>sin</mo></mrow><mrow><mo>-</mo><mn>1</mn></mrow></msup><mspace width="0.167em"></mspace><mfenced open="(" close=")" separators=""><mn>2</mn><mspace width="0.167em"></mspace><mi>x</mi></mfenced></mrow><mrow><mn>2</mn><mo>&times;</mo><msup><mrow><mn>10</mn></mrow><mrow><mn>10</mn></mrow></msup><mo>+</mo><msup><mrow><mi>x</mi></mrow><mrow><mn>3</mn></mrow></msup></mrow></mfrac></mrow><mrow><mo>-</mo><mn>12</mn></mrow></mfrac></mrow>
```
It is also possible to create custom algorithms of render, and even add support of other markup languages
(see API reference).

View File

@ -1,5 +1,8 @@
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.misc.StringSymbol
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
@ -76,11 +79,47 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
}
}
internal class InnerAlgebra<T : Any>(val algebra: Algebra<T>, val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun bindSymbolOrNull(value: String): T? = algebra.bindSymbolOrNull(value) ?: arguments[StringSymbol(value)]
override fun unaryOperation(operation: String, arg: T): T =
algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right)
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
algebra.unaryOperationFunction(operation)
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
algebra.binaryOperationFunction(operation)
@Suppress("UNCHECKED_CAST")
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
(algebra as NumericAlgebra<T>).number(value)
else
error("Numeric nodes are not supported by $this")
}
/**
* Interprets the [MST] node with this [Algebra].
* Interprets the [MST] node with this [Algebra] and optional [arguments]
*/
public fun <T : Any> MST.interpret(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
InnerAlgebra(algebra, arguments).evaluate(this)
/**
* Interprets the [MST] node with this [Algebra] and optional [arguments]
*
* @receiver the node to evaluate.
* @param algebra the algebra that provides operations.
* @return the value of expression.
*/
public fun <T> MST.interpret(algebra: Algebra<T>): T = algebra.evaluate(this)
public fun <T : Any> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol, T>): T =
interpret(algebra, mapOf(*arguments))
/**
* Interpret this [MST] as expression.
*/
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
interpret(algebra, arguments)
}

View File

@ -8,7 +8,8 @@ import space.kscience.kmath.operations.*
*/
public object MstAlgebra : NumericAlgebra<MST> {
public override fun number(value: Number): MST.Numeric = MST.Numeric(value)
public override fun bindSymbol(value: String): MST.Symbolic = MST.Symbolic(value)
public override fun bindSymbolOrNull(value: String): MST.Symbolic = MST.Symbolic(value)
override fun bindSymbol(value: String): MST.Symbolic = bindSymbolOrNull(value)
public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary =
{ arg -> MST.Unary(operation, arg) }
@ -24,7 +25,7 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
public override val zero: MST.Numeric = number(0.0)
public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
public override fun bindSymbol(value: String): MST.Symbolic = MstAlgebra.bindSymbol(value)
public override fun bindSymbolOrNull(value: String): MST.Symbolic = MstAlgebra.bindSymbolOrNull(value)
public override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOperations.PLUS_OPERATION)(a, b)
public override operator fun MST.unaryPlus(): MST.Unary =
unaryOperationFunction(GroupOperations.PLUS_OPERATION)(this)
@ -48,13 +49,14 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
/**
* [Ring] over [MST] nodes.
*/
@Suppress("OVERRIDE_BY_INLINE")
@OptIn(UnstableKMathAPI::class)
public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MST> {
public override val zero: MST.Numeric get() = MstGroup.zero
public override inline val zero: MST.Numeric get() = MstGroup.zero
public override val one: MST.Numeric = number(1.0)
public override fun number(value: Number): MST.Numeric = MstGroup.number(value)
public override fun bindSymbol(value: String): MST.Symbolic = MstAlgebra.bindSymbol(value)
public override fun bindSymbolOrNull(value: String): MST.Symbolic = MstAlgebra.bindSymbolOrNull(value)
public override fun add(a: MST, b: MST): MST.Binary = MstGroup.add(a, b)
public override fun scale(a: MST, value: Double): MST.Binary =
@ -77,13 +79,13 @@ public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MS
/**
* [Field] over [MST] nodes.
*/
@Suppress("OVERRIDE_BY_INLINE")
@OptIn(UnstableKMathAPI::class)
public object MstField : Field<MST>, NumbersAddOperations<MST>, ScaleOperations<MST> {
public override val zero: MST.Numeric get() = MstRing.zero
public override inline val zero: MST.Numeric get() = MstRing.zero
public override inline val one: MST.Numeric get() = MstRing.one
public override val one: MST.Numeric get() = MstRing.one
public override fun bindSymbol(value: String): MST.Symbolic = MstAlgebra.bindSymbol(value)
public override fun bindSymbolOrNull(value: String): MST.Symbolic = MstAlgebra.bindSymbolOrNull(value)
public override fun number(value: Number): MST.Numeric = MstRing.number(value)
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
@ -108,11 +110,12 @@ public object MstField : Field<MST>, NumbersAddOperations<MST>, ScaleOperations<
/**
* [ExtendedField] over [MST] nodes.
*/
@Suppress("OVERRIDE_BY_INLINE")
public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
public override val zero: MST.Numeric get() = MstField.zero
public override val one: MST.Numeric get() = MstField.one
public override inline val zero: MST.Numeric get() = MstField.zero
public override inline val one: MST.Numeric get() = MstField.one
public override fun bindSymbol(value: String): MST.Symbolic = MstAlgebra.bindSymbol(value)
public override fun bindSymbolOrNull(value: String): MST.Symbolic = MstAlgebra.bindSymbolOrNull(value)
public override fun number(value: Number): MST.Numeric = MstRing.number(value)
public override fun sin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg)
public override fun cos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg)

View File

@ -1,138 +0,0 @@
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.StringSymbol
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.*
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
/**
* The expression evaluates MST on-flight. Should be much faster than functional expression, but slower than
* ASM-generated expressions.
*
* @property algebra the algebra that provides operations.
* @property mst the [MST] node.
* @author Alexander Nozik
*/
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun bindSymbol(value: String): T = try {
algebra.bindSymbol(value)
} catch (ignored: IllegalStateException) {
null
} ?: arguments.getValue(StringSymbol(value))
override fun unaryOperation(operation: String, arg: T): T =
algebra.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right)
override fun unaryOperationFunction(operation: String): (arg: T) -> T =
algebra.unaryOperationFunction(operation)
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T =
algebra.binaryOperationFunction(operation)
@Suppress("UNCHECKED_CAST")
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
(algebra as NumericAlgebra<T>).number(value)
else
error("Numeric nodes are not supported by $this")
}
override operator fun invoke(arguments: Map<Symbol, T>): T = InnerAlgebra(arguments).evaluate(mst)
}
/**
* Builds [MstExpression] over [Algebra].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E,
block: E.() -> MST,
): MstExpression<T, A> = MstExpression(this, mstAlgebra.block())
/**
* Builds [MstExpression] over [Group].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Group<T>> A.mstInGroup(block: MstGroup.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstGroup.block())
}
/**
* Builds [MstExpression] over [Ring].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Ring<T>> A.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstRing.block())
}
/**
* Builds [MstExpression] over [Field].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Field<T>> A.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstField.block())
}
/**
* Builds [MstExpression] over [ExtendedField].
*
* @author Iaroslav Postovalov
*/
public inline fun <reified T : Any, A : ExtendedField<T>> A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstExtendedField.block())
}
/**
* Builds [MstExpression] over [FunctionalExpressionGroup].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Group<T>> FunctionalExpressionGroup<T, A>.mstInGroup(block: MstGroup.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInGroup(block)
}
/**
* Builds [MstExpression] over [FunctionalExpressionRing].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInRing(block)
}
/**
* Builds [MstExpression] over [FunctionalExpressionField].
*
* @author Alexander Nozik
*/
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInField(block)
}
/**
* Builds [MstExpression] over [FunctionalExpressionExtendedField].
*
* @author Iaroslav Postovalov
*/
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
block: MstExtendedField.() -> MST,
): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInExtendedField(block)
}

View File

@ -0,0 +1,129 @@
package space.kscience.kmath.ast.rendering
/**
* [SyntaxRenderer] implementation for LaTeX.
*
* The generated string is a valid LaTeX fragment to be used in the Math Mode.
*
* Example usage:
*
* ```
* \documentclass{article}
* \begin{document}
* \begin{equation}
* %code generated by the syntax renderer
* \end{equation}
* \end{document}
* ```
*
* @author Iaroslav Postovalov
*/
public object LatexSyntaxRenderer : SyntaxRenderer {
public override fun render(node: MathSyntax, output: Appendable): Unit = output.run {
fun render(syntax: MathSyntax) = render(syntax, output)
when (node) {
is NumberSyntax -> append(node.string)
is SymbolSyntax -> append(node.string)
is OperatorNameSyntax -> {
append("\\operatorname{")
append(node.name)
append('}')
}
is SpecialSymbolSyntax -> when (node.kind) {
SpecialSymbolSyntax.Kind.INFINITY -> append("\\infty")
SpecialSymbolSyntax.Kind.SMALL_PI -> append("\\pi")
}
is OperandSyntax -> {
if (node.parentheses) append("\\left(")
render(node.operand)
if (node.parentheses) append("\\right)")
}
is UnaryOperatorSyntax -> {
render(node.prefix)
append("\\,")
render(node.operand)
}
is UnaryPlusSyntax -> {
append('+')
render(node.operand)
}
is UnaryMinusSyntax -> {
append('-')
render(node.operand)
}
is RadicalSyntax -> {
append("\\sqrt")
append('{')
render(node.operand)
append('}')
}
is SuperscriptSyntax -> {
render(node.left)
append("^{")
render(node.right)
append('}')
}
is SubscriptSyntax -> {
render(node.left)
append("_{")
render(node.right)
append('}')
}
is BinaryOperatorSyntax -> {
render(node.prefix)
append("\\left(")
render(node.left)
append(',')
render(node.right)
append("\\right)")
}
is BinaryPlusSyntax -> {
render(node.left)
append('+')
render(node.right)
}
is BinaryMinusSyntax -> {
render(node.left)
append('-')
render(node.right)
}
is FractionSyntax -> {
append("\\frac{")
render(node.left)
append("}{")
render(node.right)
append('}')
}
is RadicalWithIndexSyntax -> {
append("\\sqrt")
append('[')
render(node.left)
append(']')
append('{')
render(node.right)
append('}')
}
is MultiplicationSyntax -> {
render(node.left)
append(if (node.times) "\\times" else "\\,")
render(node.right)
}
}
}
}

View File

@ -0,0 +1,134 @@
package space.kscience.kmath.ast.rendering
/**
* [SyntaxRenderer] implementation for MathML.
*
* The generated XML string is a valid MathML instance.
*
* @author Iaroslav Postovalov
*/
public object MathMLSyntaxRenderer : SyntaxRenderer {
public override fun render(node: MathSyntax, output: Appendable) {
output.append("<math xmlns=\"http://www.w3.org/1998/Math/MathML\"><mrow>")
render0(node, output)
output.append("</mrow></math>")
}
private fun render0(node: MathSyntax, output: Appendable): Unit = output.run {
fun tag(tagName: String, vararg attr: Pair<String, String>, block: () -> Unit = {}) {
append('<')
append(tagName)
if (attr.isNotEmpty()) {
append(' ')
var count = 0
for ((name, value) in attr) {
if (++count > 1) append(' ')
append(name)
append("=\"")
append(value)
append('"')
}
}
append('>')
block()
append("</")
append(tagName)
append('>')
}
fun render(syntax: MathSyntax) = render0(syntax, output)
when (node) {
is NumberSyntax -> tag("mn") { append(node.string) }
is SymbolSyntax -> tag("mi") { append(node.string) }
is OperatorNameSyntax -> tag("mo") { append(node.name) }
is SpecialSymbolSyntax -> when (node.kind) {
SpecialSymbolSyntax.Kind.INFINITY -> tag("mo") { append("&infin;") }
SpecialSymbolSyntax.Kind.SMALL_PI -> tag("mo") { append("&pi;") }
}
is OperandSyntax -> if (node.parentheses) {
tag("mfenced", "open" to "(", "close" to ")", "separators" to "") {
render(node.operand)
}
} else {
render(node.operand)
}
is UnaryOperatorSyntax -> {
render(node.prefix)
tag("mspace", "width" to "0.167em")
render(node.operand)
}
is UnaryPlusSyntax -> {
tag("mo") { append('+') }
render(node.operand)
}
is UnaryMinusSyntax -> {
tag("mo") { append("-") }
render(node.operand)
}
is RadicalSyntax -> tag("msqrt") { render(node.operand) }
is SuperscriptSyntax -> tag("msup") {
tag("mrow") { render(node.left) }
tag("mrow") { render(node.right) }
}
is SubscriptSyntax -> tag("msub") {
tag("mrow") { render(node.left) }
tag("mrow") { render(node.right) }
}
is BinaryOperatorSyntax -> {
render(node.prefix)
tag("mfenced", "open" to "(", "close" to ")", "separators" to "") {
render(node.left)
tag("mo") { append(',') }
render(node.right)
}
}
is BinaryPlusSyntax -> {
render(node.left)
tag("mo") { append('+') }
render(node.right)
}
is BinaryMinusSyntax -> {
render(node.left)
tag("mo") { append('-') }
render(node.right)
}
is FractionSyntax -> tag("mfrac") {
tag("mrow") {
render(node.left)
}
tag("mrow") {
render(node.right)
}
}
is RadicalWithIndexSyntax -> tag("mroot") {
tag("mrow") { render(node.right) }
tag("mrow") { render(node.left) }
}
is MultiplicationSyntax -> {
render(node.left)
if (node.times) tag("mo") { append("&times;") } else tag("mspace", "width" to "0.167em")
render(node.right)
}
}
}
}

View File

@ -0,0 +1,103 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.MST
/**
* Renders [MST] to [MathSyntax].
*
* @author Iaroslav Postovalov
*/
public fun interface MathRenderer {
/**
* Renders [MST] to [MathSyntax].
*/
public fun render(mst: MST): MathSyntax
}
/**
* Implements [MST] render process with sequence of features.
*
* @property features The applied features.
* @author Iaroslav Postovalov
*/
public open class FeaturedMathRenderer(public val features: List<RenderFeature>) : MathRenderer {
public override fun render(mst: MST): MathSyntax {
for (feature in features) feature.render(this, mst)?.let { return it }
throw UnsupportedOperationException("Renderer $this has no appropriate feature to render node $mst.")
}
/**
* Logical unit of [MST] rendering.
*/
public fun interface RenderFeature {
/**
* Renders [MST] to [MathSyntax] in the context of owning renderer.
*/
public fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax?
}
}
/**
* Extends [FeaturedMathRenderer] by adding post-processing stages.
*
* @property stages The applied stages.
* @author Iaroslav Postovalov
*/
public open class FeaturedMathRendererWithPostProcess(
features: List<RenderFeature>,
public val stages: List<PostProcessStage>,
) : FeaturedMathRenderer(features) {
public override fun render(mst: MST): MathSyntax {
val res = super.render(mst)
for (stage in stages) stage.perform(res)
return res
}
/**
* Logical unit of [MathSyntax] post-processing.
*/
public fun interface PostProcessStage {
/**
* Performs the specified action over [MathSyntax].
*/
public fun perform(node: MathSyntax)
}
public companion object {
/**
* The default setup of [FeaturedMathRendererWithPostProcess].
*/
public val Default: FeaturedMathRendererWithPostProcess = FeaturedMathRendererWithPostProcess(
listOf(
// Printing known operations
BinaryPlus.Default,
BinaryMinus.Default,
UnaryPlus.Default,
UnaryMinus.Default,
Multiplication.Default,
Fraction.Default,
Power.Default,
SquareRoot.Default,
Exponential.Default,
InverseTrigonometricOperations.Default,
// Fallback option for unknown operations - printing them as operator
BinaryOperator.Default,
UnaryOperator.Default,
// Pretty printing for some objects
PrettyPrintFloats.Default,
PrettyPrintIntegers.Default,
PrettyPrintPi.Default,
// Printing terminal nodes as string
PrintNumeric,
PrintSymbolic,
),
listOf(
SimplifyParentheses.Default,
BetterMultiplication,
),
)
}
}

View File

@ -0,0 +1,331 @@
package space.kscience.kmath.ast.rendering
/**
* Mathematical typography syntax node.
*
* @author Iaroslav Postovalov
*/
public sealed class MathSyntax {
/**
* The parent node of this syntax node.
*/
public var parent: MathSyntax? = null
}
/**
* Terminal node, which should not have any children nodes.
*
* @author Iaroslav Postovalov
*/
public sealed class TerminalSyntax : MathSyntax()
/**
* Node containing a certain operation.
*
* @author Iaroslav Postovalov
*/
public sealed class OperationSyntax : MathSyntax() {
/**
* The operation token.
*/
public abstract val operation: String
}
/**
* Unary node, which has only one child.
*
* @author Iaroslav Postovalov
*/
public sealed class UnarySyntax : OperationSyntax() {
/**
* The operand of this node.
*/
public abstract val operand: MathSyntax
}
/**
* Binary node, which has only two children.
*
* @author Iaroslav Postovalov
*/
public sealed class BinarySyntax : OperationSyntax() {
/**
* The left-hand side operand.
*/
public abstract val left: MathSyntax
/**
* The right-hand side operand.
*/
public abstract val right: MathSyntax
}
/**
* Represents a number.
*
* @property string The digits of number.
* @author Iaroslav Postovalov
*/
public data class NumberSyntax(public var string: String) : TerminalSyntax()
/**
* Represents a symbol.
*
* @property string The symbol.
* @author Iaroslav Postovalov
*/
public data class SymbolSyntax(public var string: String) : TerminalSyntax()
/**
* Represents special typing for operator name.
*
* @property name The operator name.
* @see BinaryOperatorSyntax
* @see UnaryOperatorSyntax
* @author Iaroslav Postovalov
*/
public data class OperatorNameSyntax(public var name: String) : TerminalSyntax()
/**
* Represents a usage of special symbols.
*
* @property kind The kind of symbol.
* @author Iaroslav Postovalov
*/
public data class SpecialSymbolSyntax(public var kind: Kind) : TerminalSyntax() {
/**
* The kind of symbol.
*/
public enum class Kind {
/**
* The infinity (&infin;) symbol.
*/
INFINITY,
/**
* The Pi (&pi;) symbol.
*/
SMALL_PI;
}
}
/**
* Represents operand of a certain operator wrapped with parentheses or not.
*
* @property operand The operand.
* @property parentheses Whether the operand should be wrapped with parentheses.
* @author Iaroslav Postovalov
*/
public data class OperandSyntax(
public val operand: MathSyntax,
public var parentheses: Boolean,
) : MathSyntax() {
init {
operand.parent = this
}
}
/**
* Represents unary, prefix operator syntax (like f x).
*
* @property prefix The prefix.
* @author Iaroslav Postovalov
*/
public data class UnaryOperatorSyntax(
public override val operation: String,
public var prefix: MathSyntax,
public override val operand: OperandSyntax,
) : UnarySyntax() {
init {
operand.parent = this
}
}
/**
* Represents prefix, unary plus operator.
*
* @author Iaroslav Postovalov
*/
public data class UnaryPlusSyntax(
public override val operation: String,
public override val operand: OperandSyntax,
) : UnarySyntax() {
init {
operand.parent = this
}
}
/**
* Represents prefix, unary minus operator.
*
* @author Iaroslav Postovalov
*/
public data class UnaryMinusSyntax(
public override val operation: String,
public override val operand: OperandSyntax,
) : UnarySyntax() {
init {
operand.parent = this
}
}
/**
* Represents radical with a node inside it.
*
* @property operand The radicand.
* @author Iaroslav Postovalov
*/
public data class RadicalSyntax(
public override val operation: String,
public override val operand: MathSyntax,
) : UnarySyntax() {
init {
operand.parent = this
}
}
/**
* Represents a syntax node with superscript (usually, for exponentiation).
*
* @property left The node.
* @property right The superscript.
* @author Iaroslav Postovalov
*/
public data class SuperscriptSyntax(
public override val operation: String,
public override val left: MathSyntax,
public override val right: MathSyntax,
) : BinarySyntax() {
init {
left.parent = this
right.parent = this
}
}
/**
* Represents a syntax node with subscript.
*
* @property left The node.
* @property right The subscript.
* @author Iaroslav Postovalov
*/
public data class SubscriptSyntax(
public override val operation: String,
public override val left: MathSyntax,
public override val right: MathSyntax,
) : BinarySyntax() {
init {
left.parent = this
right.parent = this
}
}
/**
* Represents binary, prefix operator syntax (like f(a, b)).
*
* @property prefix The prefix.
* @author Iaroslav Postovalov
*/
public data class BinaryOperatorSyntax(
public override val operation: String,
public var prefix: MathSyntax,
public override val left: MathSyntax,
public override val right: MathSyntax,
) : BinarySyntax() {
init {
left.parent = this
right.parent = this
}
}
/**
* Represents binary, infix addition.
*
* @param left The augend.
* @param right The addend.
* @author Iaroslav Postovalov
*/
public data class BinaryPlusSyntax(
public override val operation: String,
public override val left: OperandSyntax,
public override val right: OperandSyntax,
) : BinarySyntax() {
init {
left.parent = this
right.parent = this
}
}
/**
* Represents binary, infix subtraction.
*
* @param left The minuend.
* @param right The subtrahend.
* @author Iaroslav Postovalov
*/
public data class BinaryMinusSyntax(
public override val operation: String,
public override val left: OperandSyntax,
public override val right: OperandSyntax,
) : BinarySyntax() {
init {
left.parent = this
right.parent = this
}
}
/**
* Represents fraction with numerator and denominator.
*
* @property left The numerator.
* @property right The denominator.
* @author Iaroslav Postovalov
*/
public data class FractionSyntax(
public override val operation: String,
public override val left: MathSyntax,
public override val right: MathSyntax,
) : BinarySyntax() {
init {
left.parent = this
right.parent = this
}
}
/**
* Represents radical syntax with index.
*
* @property left The index.
* @property right The radicand.
* @author Iaroslav Postovalov
*/
public data class RadicalWithIndexSyntax(
public override val operation: String,
public override val left: MathSyntax,
public override val right: MathSyntax,
) : BinarySyntax() {
init {
left.parent = this
right.parent = this
}
}
/**
* Represents binary, infix multiplication in the form of coefficient (2 x) or with operator (x&times;2).
*
* @property left The multiplicand.
* @property right The multiplier.
* @property times whether the times (&times;) symbol should be used.
* @author Iaroslav Postovalov
*/
public data class MultiplicationSyntax(
public override val operation: String,
public override val left: OperandSyntax,
public override val right: OperandSyntax,
public var times: Boolean,
) : BinarySyntax() {
init {
left.parent = this
right.parent = this
}
}

View File

@ -0,0 +1,25 @@
package space.kscience.kmath.ast.rendering
/**
* Abstraction of writing [MathSyntax] as a string of an actual markup language. Typical implementation should
* involve traversal of MathSyntax with handling each its subtype.
*
* @author Iaroslav Postovalov
*/
public fun interface SyntaxRenderer {
/**
* Renders the [MathSyntax] to [output].
*/
public fun render(node: MathSyntax, output: Appendable)
}
/**
* Calls [SyntaxRenderer.render] with given [node] and a new [StringBuilder] instance, and returns its content.
*
* @author Iaroslav Postovalov
*/
public fun SyntaxRenderer.renderWithStringBuilder(node: MathSyntax): String {
val sb = StringBuilder()
render(node, sb)
return sb.toString()
}

View File

@ -0,0 +1,331 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.rendering.FeaturedMathRenderer.RenderFeature
import space.kscience.kmath.operations.*
import kotlin.reflect.KClass
/**
* Prints any [MST.Symbolic] as a [SymbolSyntax] containing the [MST.Symbolic.value] of it.
*
* @author Iaroslav Postovalov
*/
public object PrintSymbolic : RenderFeature {
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Symbolic) return null
return SymbolSyntax(string = node.value)
}
}
/**
* Prints any [MST.Numeric] as a [NumberSyntax] containing the [Any.toString] result of it.
*
* @author Iaroslav Postovalov
*/
public object PrintNumeric : RenderFeature {
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Numeric) return null
return NumberSyntax(string = node.value.toString())
}
}
private fun printSignedNumberString(s: String): MathSyntax {
if (s.startsWith('-'))
return UnaryMinusSyntax(
operation = GroupOperations.MINUS_OPERATION,
operand = OperandSyntax(
operand = NumberSyntax(string = s.removePrefix("-")),
parentheses = true,
),
)
return NumberSyntax(string = s)
}
/**
* Special printing for numeric types which are printed in form of
* *('-'? (DIGIT+ ('.' DIGIT+)? ('E' '-'? DIGIT+)? | 'Infinity')) | 'NaN'*.
*
* @property types The suitable types.
*/
public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : RenderFeature {
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Numeric || node.value::class !in types) return null
val toString = node.value.toString().removeSuffix(".0")
if ('E' in toString) {
val (beforeE, afterE) = toString.split('E')
val significand = beforeE.toDouble().toString().removeSuffix(".0")
val exponent = afterE.toDouble().toString().removeSuffix(".0")
return MultiplicationSyntax(
operation = RingOperations.TIMES_OPERATION,
left = OperandSyntax(operand = NumberSyntax(significand), parentheses = true),
right = OperandSyntax(
operand = SuperscriptSyntax(
operation = PowerOperations.POW_OPERATION,
left = NumberSyntax(string = "10"),
right = printSignedNumberString(exponent),
),
parentheses = true,
),
times = true,
)
}
if (toString.endsWith("Infinity")) {
val infty = SpecialSymbolSyntax(SpecialSymbolSyntax.Kind.INFINITY)
if (toString.startsWith('-'))
return UnaryMinusSyntax(
operation = GroupOperations.MINUS_OPERATION,
operand = OperandSyntax(operand = infty, parentheses = true),
)
return infty
}
return printSignedNumberString(toString)
}
public companion object {
/**
* The default instance containing [Float], and [Double].
*/
public val Default: PrettyPrintFloats = PrettyPrintFloats(setOf(Float::class, Double::class))
}
}
/**
* Special printing for numeric types which are printed in form of *'-'? DIGIT+*.
*
* @property types The suitable types.
*/
public class PrettyPrintIntegers(public val types: Set<KClass<out Number>>) : RenderFeature {
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Numeric || node.value::class !in types)
return null
return printSignedNumberString(node.value.toString())
}
public companion object {
/**
* The default instance containing [Byte], [Short], [Int], and [Long].
*/
public val Default: PrettyPrintIntegers =
PrettyPrintIntegers(setOf(Byte::class, Short::class, Int::class, Long::class))
}
}
/**
* Special printing for symbols meaning Pi.
*
* @property symbols The allowed symbols.
*/
public class PrettyPrintPi(public val symbols: Set<String>) : RenderFeature {
public override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Symbolic || node.value !in symbols) return null
return SpecialSymbolSyntax(kind = SpecialSymbolSyntax.Kind.SMALL_PI)
}
public companion object {
/**
* The default instance containing `pi`.
*/
public val Default: PrettyPrintPi = PrettyPrintPi(setOf("pi"))
}
}
/**
* Abstract printing of unary operations which discards [MST] if their operation is not in [operations] or its type is
* not [MST.Unary].
*
* @param operations the allowed operations. If `null`, any operation is accepted.
*/
public abstract class Unary(public val operations: Collection<String>?) : RenderFeature {
/**
* The actual render function.
*/
protected abstract fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax?
public final override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Unary || operations != null && node.operation !in operations) return null
return render0(renderer, node)
}
}
/**
* Abstract printing of unary operations which discards [MST] if their operation is not in [operations] or its type is
* not [MST.Binary].
*
* @property operations the allowed operations. If `null`, any operation is accepted.
*/
public abstract class Binary(public val operations: Collection<String>?) : RenderFeature {
/**
* The actual render function.
*/
protected abstract fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax?
public final override fun render(renderer: FeaturedMathRenderer, node: MST): MathSyntax? {
if (node !is MST.Binary || operations != null && node.operation !in operations) return null
return render0(renderer, node)
}
}
public class BinaryPlus(operations: Collection<String>?) : Binary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryPlusSyntax(
operation = node.operation,
left = OperandSyntax(parent.render(node.left), true),
right = OperandSyntax(parent.render(node.right), true),
)
public companion object {
public val Default: BinaryPlus = BinaryPlus(setOf(GroupOperations.PLUS_OPERATION))
}
}
public class BinaryMinus(operations: Collection<String>?) : Binary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryMinusSyntax(
operation = node.operation,
left = OperandSyntax(operand = parent.render(node.left), parentheses = true),
right = OperandSyntax(operand = parent.render(node.right), parentheses = true),
)
public companion object {
public val Default: BinaryMinus = BinaryMinus(setOf(GroupOperations.MINUS_OPERATION))
}
}
public class UnaryPlus(operations: Collection<String>?) : Unary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryPlusSyntax(
operation = node.operation,
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
)
public companion object {
public val Default: UnaryPlus = UnaryPlus(setOf(GroupOperations.PLUS_OPERATION))
}
}
public class UnaryMinus(operations: Collection<String>?) : Unary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryMinusSyntax(
operation = node.operation,
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
)
public companion object {
public val Default: UnaryMinus = UnaryMinus(setOf(GroupOperations.MINUS_OPERATION))
}
}
public class Fraction(operations: Collection<String>?) : Binary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = FractionSyntax(
operation = node.operation,
left = parent.render(node.left),
right = parent.render(node.right),
)
public companion object {
public val Default: Fraction = Fraction(setOf(FieldOperations.DIV_OPERATION))
}
}
public class BinaryOperator(operations: Collection<String>?) : Binary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = BinaryOperatorSyntax(
operation = node.operation,
prefix = OperatorNameSyntax(name = node.operation),
left = parent.render(node.left),
right = parent.render(node.right),
)
public companion object {
public val Default: BinaryOperator = BinaryOperator(null)
}
}
public class UnaryOperator(operations: Collection<String>?) : Unary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryOperatorSyntax(
operation = node.operation,
prefix = OperatorNameSyntax(node.operation),
operand = OperandSyntax(parent.render(node.value), true),
)
public companion object {
public val Default: UnaryOperator = UnaryOperator(null)
}
}
public class Power(operations: Collection<String>?) : Binary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = SuperscriptSyntax(
operation = node.operation,
left = OperandSyntax(parent.render(node.left), true),
right = OperandSyntax(parent.render(node.right), true),
)
public companion object {
public val Default: Power = Power(setOf(PowerOperations.POW_OPERATION))
}
}
public class SquareRoot(operations: Collection<String>?) : Unary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax =
RadicalSyntax(operation = node.operation, operand = parent.render(node.value))
public companion object {
public val Default: SquareRoot = SquareRoot(setOf(PowerOperations.SQRT_OPERATION))
}
}
public class Exponential(operations: Collection<String>?) : Unary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = SuperscriptSyntax(
operation = node.operation,
left = SymbolSyntax(string = "e"),
right = parent.render(node.value),
)
public companion object {
public val Default: Exponential = Exponential(setOf(ExponentialOperations.EXP_OPERATION))
}
}
public class Multiplication(operations: Collection<String>?) : Binary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Binary): MathSyntax = MultiplicationSyntax(
operation = node.operation,
left = OperandSyntax(operand = parent.render(node.left), parentheses = true),
right = OperandSyntax(operand = parent.render(node.right), parentheses = true),
times = true,
)
public companion object {
public val Default: Multiplication = Multiplication(setOf(
RingOperations.TIMES_OPERATION,
))
}
}
public class InverseTrigonometricOperations(operations: Collection<String>?) : Unary(operations) {
public override fun render0(parent: FeaturedMathRenderer, node: MST.Unary): MathSyntax = UnaryOperatorSyntax(
operation = node.operation,
prefix = SuperscriptSyntax(
operation = PowerOperations.POW_OPERATION,
left = OperatorNameSyntax(name = node.operation.removePrefix("a")),
right = UnaryMinusSyntax(
operation = GroupOperations.MINUS_OPERATION,
operand = OperandSyntax(operand = NumberSyntax(string = "1"), parentheses = true),
),
),
operand = OperandSyntax(operand = parent.render(node.value), parentheses = true),
)
public companion object {
public val Default: InverseTrigonometricOperations = InverseTrigonometricOperations(setOf(
TrigonometricOperations.ACOS_OPERATION,
TrigonometricOperations.ASIN_OPERATION,
TrigonometricOperations.ATAN_OPERATION,
ExponentialOperations.ACOSH_OPERATION,
ExponentialOperations.ASINH_OPERATION,
ExponentialOperations.ATANH_OPERATION,
))
}
}

View File

@ -0,0 +1,197 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.operations.FieldOperations
import space.kscience.kmath.operations.GroupOperations
import space.kscience.kmath.operations.PowerOperations
import space.kscience.kmath.operations.RingOperations
/**
* Removes unnecessary times (&times;) symbols from [MultiplicationSyntax].
*
* @author Iaroslav Postovalov
*/
public object BetterMultiplication : FeaturedMathRendererWithPostProcess.PostProcessStage {
public override fun perform(node: MathSyntax) {
when (node) {
is NumberSyntax -> Unit
is SymbolSyntax -> Unit
is OperatorNameSyntax -> Unit
is SpecialSymbolSyntax -> Unit
is OperandSyntax -> perform(node.operand)
is UnaryOperatorSyntax -> {
perform(node.prefix)
perform(node.operand)
}
is UnaryPlusSyntax -> perform(node.operand)
is UnaryMinusSyntax -> perform(node.operand)
is RadicalSyntax -> perform(node.operand)
is SuperscriptSyntax -> {
perform(node.left)
perform(node.right)
}
is SubscriptSyntax -> {
perform(node.left)
perform(node.right)
}
is BinaryOperatorSyntax -> {
perform(node.prefix)
perform(node.left)
perform(node.right)
}
is BinaryPlusSyntax -> {
perform(node.left)
perform(node.right)
}
is BinaryMinusSyntax -> {
perform(node.left)
perform(node.right)
}
is FractionSyntax -> {
perform(node.left)
perform(node.right)
}
is RadicalWithIndexSyntax -> {
perform(node.left)
perform(node.right)
}
is MultiplicationSyntax -> {
node.times = node.right.operand is NumberSyntax && !node.right.parentheses
|| node.left.operand is NumberSyntax && node.right.operand is FractionSyntax
|| node.left.operand is NumberSyntax && node.right.operand is NumberSyntax
|| node.left.operand is NumberSyntax && node.right.operand is SuperscriptSyntax && node.right.operand.left is NumberSyntax
perform(node.left)
perform(node.right)
}
}
}
}
/**
* Removes unnecessary parentheses from [OperandSyntax].
*
* @property precedenceFunction Returns the precedence number for syntax node. Higher number is lower priority.
* @author Iaroslav Postovalov
*/
public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) -> Int) :
FeaturedMathRendererWithPostProcess.PostProcessStage {
public override fun perform(node: MathSyntax) {
when (node) {
is NumberSyntax -> Unit
is SymbolSyntax -> Unit
is OperatorNameSyntax -> Unit
is SpecialSymbolSyntax -> Unit
is OperandSyntax -> {
val isRightOfSuperscript =
(node.parent is SuperscriptSyntax) && (node.parent as SuperscriptSyntax).right === node
val precedence = precedenceFunction(node.operand)
val needParenthesesByPrecedence = when (val parent = node.parent) {
null -> false
is BinarySyntax -> {
val parentPrecedence = precedenceFunction(parent)
parentPrecedence < precedence ||
parentPrecedence == precedence && parentPrecedence != 0 && node === parent.right
}
else -> precedence > precedenceFunction(parent)
}
node.parentheses = !isRightOfSuperscript
&& (needParenthesesByPrecedence || node.parent is UnaryOperatorSyntax)
perform(node.operand)
}
is UnaryOperatorSyntax -> {
perform(node.prefix)
perform(node.operand)
}
is UnaryPlusSyntax -> perform(node.operand)
is UnaryMinusSyntax -> {
perform(node.operand)
}
is RadicalSyntax -> perform(node.operand)
is SuperscriptSyntax -> {
perform(node.left)
perform(node.right)
}
is SubscriptSyntax -> {
perform(node.left)
perform(node.right)
}
is BinaryOperatorSyntax -> {
perform(node.prefix)
perform(node.left)
perform(node.right)
}
is BinaryPlusSyntax -> {
perform(node.left)
perform(node.right)
}
is BinaryMinusSyntax -> {
perform(node.left)
perform(node.right)
}
is FractionSyntax -> {
perform(node.left)
perform(node.right)
}
is MultiplicationSyntax -> {
perform(node.left)
perform(node.right)
}
is RadicalWithIndexSyntax -> {
perform(node.left)
perform(node.right)
}
}
}
public companion object {
/**
* The default configuration of [SimplifyParentheses] where power is 1, multiplicative operations are 2,
* additive operations are 3.
*/
public val Default: SimplifyParentheses = SimplifyParentheses {
when (it) {
is TerminalSyntax -> 0
is UnarySyntax -> 2
is BinarySyntax -> when (it.operation) {
PowerOperations.POW_OPERATION -> 1
RingOperations.TIMES_OPERATION -> 3
FieldOperations.DIV_OPERATION -> 3
GroupOperations.MINUS_OPERATION -> 4
GroupOperations.PLUS_OPERATION -> 4
else -> 0
}
else -> 0
}
}
}
}

View File

@ -0,0 +1,22 @@
package space.kscisnce.kmath.ast
import space.kscience.kmath.ast.MstField
import space.kscience.kmath.ast.toExpression
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
class InterpretTest {
@Test
fun interpretation(){
val expr = MstField {
val x = bindSymbol(x)
x * 2.0 + number(2.0) / x - 16.0
}.toExpression(DoubleField)
expr(x to 2.2)
}
}

View File

@ -2,10 +2,11 @@ package space.kscience.kmath.estree
import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.MST.*
import space.kscience.kmath.ast.MstExpression
import space.kscience.kmath.estree.internal.ESTreeBuilder
import space.kscience.kmath.estree.internal.estree.BaseExpression
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
@ -13,11 +14,7 @@ import space.kscience.kmath.operations.NumericAlgebra
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
is Symbolic -> {
val symbol = try {
algebra.bindSymbol(node.value)
} catch (ignored: IllegalStateException) {
null
}
val symbol = algebra.bindSymbolOrNull(node.value)
if (symbol != null)
constant(symbol)
@ -64,19 +61,21 @@ internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
return ESTreeBuilder<T> { visit(this@compileWith) }.instance
}
/**
* Create a compiled expression with given [MST] and given [algebra].
*/
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> = compileWith(algebra)
/**
* Compiles an [MST] to ESTree generated expression using given algebra.
*
* @author Alexander Nozik.
* Compile given MST to expression and evaluate it against [arguments]
*/
public fun <T : Any> Algebra<T>.expression(mst: MST): Expression<T> =
mst.compileWith(this)
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
compileToExpression(algebra).invoke(arguments)
/**
* Optimizes performance of an [MstExpression] by compiling it into ESTree generated expression.
*
* @author Alexander Nozik.
* Compile given MST to expression and evaluate it against [arguments]
*/
public fun <T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
mst.compileWith(algebra)
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol,T>): T =
compileToExpression(algebra).invoke(*arguments)

View File

@ -3,16 +3,19 @@ package space.kscience.kmath.estree
import space.kscience.kmath.ast.*
import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.complex.toComplex
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestESTreeConsistencyWithInterpreter {
@Test
fun mstSpace() {
val res1 = MstGroup.mstInGroup {
val mst = MstGroup {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale(
@ -23,27 +26,17 @@ internal class TestESTreeConsistencyWithInterpreter {
number(1)
) + bindSymbol("x") + zero
}("x" to MST.Numeric(2))
}
val res2 = MstGroup.mstInGroup {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale(
add(number(1), number(1)),
2.0
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + bindSymbol("x") + zero
}.compile()("x" to MST.Numeric(2))
assertEquals(res1, res2)
assertEquals(
mst.interpret(MstGroup, Symbol.x to MST.Numeric(2)),
mst.compile(MstGroup, Symbol.x to MST.Numeric(2))
)
}
@Test
fun byteRing() {
val res1 = ByteRing.mstInRing {
val mst = MstRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol("x") - (2.toByte() + (scale(
@ -54,62 +47,43 @@ internal class TestESTreeConsistencyWithInterpreter {
number(1)
) * number(2)
}("x" to 3.toByte())
}
val res2 = ByteRing.mstInRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol("x") - (2.toByte() + (scale(
add(number(1), number(1)),
2.0
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}.compile()("x" to 3.toByte())
assertEquals(res1, res2)
assertEquals(
mst.interpret(ByteRing, Symbol.x to 3.toByte()),
mst.compile(ByteRing, Symbol.x to 3.toByte())
)
}
@Test
fun realField() {
val res1 = DoubleField.mstInField {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}("x" to 2.0)
}
val res2 = DoubleField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0)
assertEquals(res1, res2)
assertEquals(
mst.interpret(DoubleField, Symbol.x to 2.0),
mst.compile(DoubleField, Symbol.x to 2.0)
)
}
@Test
fun complexField() {
val res1 = ComplexField.mstInField {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}("x" to 2.0.toComplex())
}
val res2 = ComplexField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0.toComplex())
assertEquals(res1, res2)
assertEquals(
mst.interpret(ComplexField, Symbol.x to 2.0.toComplex()),
mst.compile(ComplexField, Symbol.x to 2.0.toComplex())
)
}
}

View File

@ -1,10 +1,9 @@
package space.kscience.kmath.estree
import space.kscience.kmath.ast.mstInExtendedField
import space.kscience.kmath.ast.mstInField
import space.kscience.kmath.ast.mstInGroup
import space.kscience.kmath.ast.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertEquals
@ -12,29 +11,29 @@ import kotlin.test.assertEquals
internal class TestESTreeOperationsSupport {
@Test
fun testUnaryOperationInvocation() {
val expression = DoubleField.mstInGroup { -bindSymbol("x") }.compile()
val expression = MstExtendedField { -bindSymbol("x") }.compileToExpression(DoubleField)
val res = expression("x" to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testBinaryOperationInvocation() {
val expression = DoubleField.mstInGroup { -bindSymbol("x") + number(1.0) }.compile()
val expression = MstExtendedField { -bindSymbol("x") + number(1.0) }.compileToExpression(DoubleField)
val res = expression("x" to 2.0)
assertEquals(-1.0, res)
}
@Test
fun testConstProductInvocation() {
val res = DoubleField.mstInField { bindSymbol("x") * 2 }("x" to 2.0)
val res = MstExtendedField { bindSymbol("x") * 2 }.compileToExpression(DoubleField)("x" to 2.0)
assertEquals(4.0, res)
}
@Test
fun testMultipleCalls() {
val e =
DoubleField.mstInExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
.compile()
MstExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
.compileToExpression(DoubleField)
val r = Random(0)
var s = 0.0
repeat(1000000) { s += e("x" to r.nextDouble()) }

View File

@ -1,53 +1,63 @@
package space.kscience.kmath.estree
import space.kscience.kmath.ast.mstInField
import space.kscience.kmath.ast.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestESTreeSpecialization {
@Test
fun testUnaryPlus() {
val expr = DoubleField.mstInField { unaryOperationFunction("+")(bindSymbol("x")) }.compile()
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(2.0, expr("x" to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = DoubleField.mstInField { unaryOperationFunction("-")(bindSymbol("x")) }.compile()
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr("x" to 2.0))
}
@Test
fun testAdd() {
val expr = DoubleField.mstInField { binaryOperationFunction("+")(bindSymbol("x"), bindSymbol("x")) }.compile()
val expr = MstExtendedField {
binaryOperationFunction("+")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr("x" to 2.0))
}
@Test
fun testSine() {
val expr = DoubleField.mstInField { unaryOperationFunction("sin")(bindSymbol("x")) }.compile()
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(0.0, expr("x" to 0.0))
}
@Test
fun testMinus() {
val expr = DoubleField.mstInField { binaryOperationFunction("-")(bindSymbol("x"), bindSymbol("x")) }.compile()
val expr = MstExtendedField {
binaryOperationFunction("-")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(0.0, expr("x" to 2.0))
}
@Test
fun testDivide() {
val expr = DoubleField.mstInField { binaryOperationFunction("/")(bindSymbol("x"), bindSymbol("x")) }.compile()
val expr = MstExtendedField {
binaryOperationFunction("/")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(1.0, expr("x" to 2.0))
}
@Test
fun testPower() {
val expr = DoubleField
.mstInField { binaryOperationFunction("pow")(bindSymbol("x"), number(2)) }
.compile()
val expr = MstExtendedField {
binaryOperationFunction("pow")(bindSymbol("x"), number(2))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr("x" to 2.0))
}

View File

@ -1,8 +1,9 @@
package space.kscience.kmath.estree
import space.kscience.kmath.ast.mstInRing
import space.kscience.kmath.ast.MstRing
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
@ -10,13 +11,13 @@ import kotlin.test.assertFailsWith
internal class TestESTreeVariables {
@Test
fun testVariable() {
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
val expr = MstRing{ bindSymbol("x") }.compileToExpression(ByteRing)
assertEquals(1.toByte(), expr("x" to 1.toByte()))
}
@Test
fun testUndefinedVariableFails() {
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
assertFailsWith<NoSuchElementException> { expr() }
}
}

View File

@ -4,8 +4,9 @@ import space.kscience.kmath.asm.internal.AsmBuilder
import space.kscience.kmath.asm.internal.buildName
import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.MST.*
import space.kscience.kmath.ast.MstExpression
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.NumericAlgebra
@ -21,11 +22,7 @@ import space.kscience.kmath.operations.NumericAlgebra
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
is Symbolic -> {
val symbol = try {
algebra.bindSymbol(node.value)
} catch (ignored: IllegalStateException) {
null
}
val symbol = algebra.bindSymbolOrNull(node.value)
if (symbol != null)
loadObjectConstant(symbol as Any)
@ -70,18 +67,22 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
return AsmBuilder<T>(type, buildName(this)) { visit(this@compileWith) }.instance
}
/**
* Compiles an [MST] to ASM using given algebra.
*
* @author Alexander Nozik.
*/
public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<T> =
mst.compileWith(T::class.java, this)
/**
* Optimizes performance of an [MstExpression] using ASM codegen.
*
* @author Alexander Nozik.
* Create a compiled expression with given [MST] and given [algebra].
*/
public inline fun <reified T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
mst.compileWith(T::class.java, algebra)
public inline fun <reified T: Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> =
compileWith(T::class.java, algebra)
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, arguments: Map<Symbol, T>): T =
compileToExpression(algebra).invoke(arguments)
/**
* Compile given MST to expression and evaluate it against [arguments]
*/
public inline fun <reified T: Any> MST.compile(algebra: Algebra<T>, vararg arguments: Pair<Symbol,T>): T =
compileToExpression(algebra).invoke(*arguments)

View File

@ -21,7 +21,8 @@ import space.kscience.kmath.operations.RingOperations
/**
* better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4.
*
* @author Alexander Nozik and Iaroslav Postovalov
* @author Alexander Nozik
* @author Iaroslav Postovalov
*/
public object ArithmeticsEvaluator : Grammar<MST>() {
// TODO replace with "...".toRegex() when better-parse 0.4.1 is released

View File

@ -3,16 +3,19 @@ package space.kscience.kmath.asm
import space.kscience.kmath.ast.*
import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.complex.toComplex
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmConsistencyWithInterpreter {
@Test
fun mstSpace() {
val res1 = MstGroup.mstInGroup {
val mst = MstGroup {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale(
@ -23,27 +26,17 @@ internal class TestAsmConsistencyWithInterpreter {
number(1)
) + bindSymbol("x") + zero
}("x" to MST.Numeric(2))
}
val res2 = MstGroup.mstInGroup {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale(
add(number(1), number(1)),
2.0
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + bindSymbol("x") + zero
}.compile()("x" to MST.Numeric(2))
assertEquals(res1, res2)
assertEquals(
mst.interpret(MstGroup, x to MST.Numeric(2)),
mst.compile(MstGroup, x to MST.Numeric(2))
)
}
@Test
fun byteRing() {
val res1 = ByteRing.mstInRing {
val mst = MstRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol("x") - (2.toByte() + (scale(
@ -54,62 +47,43 @@ internal class TestAsmConsistencyWithInterpreter {
number(1)
) * number(2)
}("x" to 3.toByte())
}
val res2 = ByteRing.mstInRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol("x") - (2.toByte() + (scale(
add(number(1), number(1)),
2.0
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}.compile()("x" to 3.toByte())
assertEquals(res1, res2)
assertEquals(
mst.interpret(ByteRing, x to 3.toByte()),
mst.compile(ByteRing, x to 3.toByte())
)
}
@Test
fun realField() {
val res1 = DoubleField.mstInField {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}("x" to 2.0)
}
val res2 = DoubleField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0)
assertEquals(res1, res2)
assertEquals(
mst.interpret(DoubleField, x to 2.0),
mst.compile(DoubleField, x to 2.0)
)
}
@Test
fun complexField() {
val res1 = ComplexField.mstInField {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}("x" to 2.0.toComplex())
}
val res2 = ComplexField.mstInField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol("x") + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}.compile()("x" to 2.0.toComplex())
assertEquals(res1, res2)
assertEquals(
mst.interpret(ComplexField, x to 2.0.toComplex()),
mst.compile(ComplexField, x to 2.0.toComplex())
)
}
}

View File

@ -1,10 +1,11 @@
package space.kscience.kmath.asm
import space.kscience.kmath.ast.mstInExtendedField
import space.kscience.kmath.ast.mstInField
import space.kscience.kmath.ast.mstInGroup
import space.kscience.kmath.ast.MstExtendedField
import space.kscience.kmath.ast.MstField
import space.kscience.kmath.ast.MstGroup
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertEquals
@ -12,29 +13,29 @@ import kotlin.test.assertEquals
internal class TestAsmOperationsSupport {
@Test
fun testUnaryOperationInvocation() {
val expression = DoubleField.mstInGroup { -bindSymbol("x") }.compile()
val expression = MstGroup { -bindSymbol("x") }.compileToExpression(DoubleField)
val res = expression("x" to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testBinaryOperationInvocation() {
val expression = DoubleField.mstInGroup { -bindSymbol("x") + number(1.0) }.compile()
val expression = MstGroup { -bindSymbol("x") + number(1.0) }.compileToExpression(DoubleField)
val res = expression("x" to 2.0)
assertEquals(-1.0, res)
}
@Test
fun testConstProductInvocation() {
val res = DoubleField.mstInField { bindSymbol("x") * 2 }("x" to 2.0)
val res = MstField { bindSymbol("x") * 2 }.compileToExpression(DoubleField)("x" to 2.0)
assertEquals(4.0, res)
}
@Test
fun testMultipleCalls() {
val e =
DoubleField.mstInExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
.compile()
MstExtendedField { sin(bindSymbol("x")).pow(4) - 6 * bindSymbol("x") / tanh(bindSymbol("x")) }
.compileToExpression(DoubleField)
val r = Random(0)
var s = 0.0
repeat(1000000) { s += e("x" to r.nextDouble()) }

View File

@ -1,53 +1,63 @@
package space.kscience.kmath.asm
import space.kscience.kmath.ast.mstInField
import space.kscience.kmath.ast.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmSpecialization {
@Test
fun testUnaryPlus() {
val expr = DoubleField.mstInField { unaryOperationFunction("+")(bindSymbol("x")) }.compile()
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(2.0, expr("x" to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = DoubleField.mstInField { unaryOperationFunction("-")(bindSymbol("x")) }.compile()
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr("x" to 2.0))
}
@Test
fun testAdd() {
val expr = DoubleField.mstInField { binaryOperationFunction("+")(bindSymbol("x"), bindSymbol("x")) }.compile()
val expr = MstExtendedField {
binaryOperationFunction("+")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr("x" to 2.0))
}
@Test
fun testSine() {
val expr = DoubleField.mstInField { unaryOperationFunction("sin")(bindSymbol("x")) }.compile()
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol("x")) }.compileToExpression(DoubleField)
assertEquals(0.0, expr("x" to 0.0))
}
@Test
fun testMinus() {
val expr = DoubleField.mstInField { binaryOperationFunction("-")(bindSymbol("x"), bindSymbol("x")) }.compile()
val expr = MstExtendedField {
binaryOperationFunction("-")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(0.0, expr("x" to 2.0))
}
@Test
fun testDivide() {
val expr = DoubleField.mstInField { binaryOperationFunction("/")(bindSymbol("x"), bindSymbol("x")) }.compile()
val expr = MstExtendedField {
binaryOperationFunction("/")(bindSymbol("x"),
bindSymbol("x"))
}.compileToExpression(DoubleField)
assertEquals(1.0, expr("x" to 2.0))
}
@Test
fun testPower() {
val expr = DoubleField
.mstInField { binaryOperationFunction("pow")(bindSymbol("x"), number(2)) }
.compile()
val expr = MstExtendedField {
binaryOperationFunction("pow")(bindSymbol("x"), number(2))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr("x" to 2.0))
}

View File

@ -1,8 +1,9 @@
package space.kscience.kmath.asm
import space.kscience.kmath.ast.mstInRing
import space.kscience.kmath.ast.MstRing
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
@ -10,13 +11,13 @@ import kotlin.test.assertFailsWith
internal class TestAsmVariables {
@Test
fun testVariable() {
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
assertEquals(1.toByte(), expr("x" to 1.toByte()))
}
@Test
fun testUndefinedVariableFails() {
val expr = ByteRing.mstInRing { bindSymbol("x") }.compile()
val expr = MstRing { bindSymbol("x") }.compileToExpression(ByteRing)
assertFailsWith<NoSuchElementException> { expr() }
}
}

View File

@ -2,9 +2,9 @@ package space.kscience.kmath.ast
import space.kscience.kmath.complex.Complex
import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
@ -18,7 +18,7 @@ internal class ParserTest {
@Test
fun `evaluate MSTExpression`() {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
val res = MstField.invoke { number(2) + number(2) * (number(2) + number(2)) }.interpret(ComplexField)
assertEquals(Complex(10.0, 0.0), res)
}
@ -40,7 +40,7 @@ internal class ParserTest {
@Test
fun `evaluate MST with binary function`() {
val magicalAlgebra = object : Algebra<String> {
override fun bindSymbol(value: String): String = value
override fun bindSymbolOrNull(value: String): String = value
override fun unaryOperationFunction(operation: String): (arg: String) -> String {
throw NotImplementedError()

View File

@ -0,0 +1,95 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.MST.Numeric
import space.kscience.kmath.ast.rendering.TestUtils.testLatex
import kotlin.test.Test
internal class TestFeatures {
@Test
fun printSymbolic() = testLatex("x", "x")
@Test
fun printNumeric() {
val num = object : Number() {
override fun toByte(): Byte = throw UnsupportedOperationException()
override fun toChar(): Char = throw UnsupportedOperationException()
override fun toDouble(): Double = throw UnsupportedOperationException()
override fun toFloat(): Float = throw UnsupportedOperationException()
override fun toInt(): Int = throw UnsupportedOperationException()
override fun toLong(): Long = throw UnsupportedOperationException()
override fun toShort(): Short = throw UnsupportedOperationException()
override fun toString(): String = "foo"
}
testLatex(Numeric(num), "foo")
}
@Test
fun prettyPrintFloats() {
testLatex(Numeric(Double.NaN), "NaN")
testLatex(Numeric(Double.POSITIVE_INFINITY), "\\infty")
testLatex(Numeric(Double.NEGATIVE_INFINITY), "-\\infty")
testLatex(Numeric(1.0), "1")
testLatex(Numeric(-1.0), "-1")
testLatex(Numeric(1.42), "1.42")
testLatex(Numeric(-1.42), "-1.42")
testLatex(Numeric(1.1e10), "1.1\\times10^{10}")
testLatex(Numeric(1.1e-10), "1.1\\times10^{-10}")
testLatex(Numeric(-1.1e-10), "-1.1\\times10^{-10}")
testLatex(Numeric(-1.1e10), "-1.1\\times10^{10}")
}
@Test
fun prettyPrintIntegers() {
testLatex(Numeric(42), "42")
testLatex(Numeric(-42), "-42")
}
@Test
fun prettyPrintPi() {
testLatex("pi", "\\pi")
}
@Test
fun binaryPlus() = testLatex("2+2", "2+2")
@Test
fun binaryMinus() = testLatex("2-2", "2-2")
@Test
fun fraction() = testLatex("2/2", "\\frac{2}{2}")
@Test
fun binaryOperator() = testLatex("f(x, y)", "\\operatorname{f}\\left(x,y\\right)")
@Test
fun unaryOperator() = testLatex("f(x)", "\\operatorname{f}\\,\\left(x\\right)")
@Test
fun power() = testLatex("x^y", "x^{y}")
@Test
fun squareRoot() = testLatex("sqrt(x)", "\\sqrt{x}")
@Test
fun exponential() = testLatex("exp(x)", "e^{x}")
@Test
fun multiplication() = testLatex("x*1", "x\\times1")
@Test
fun inverseTrigonometry() {
testLatex("asin(x)", "\\operatorname{sin}^{-1}\\,\\left(x\\right)")
testLatex("asinh(x)", "\\operatorname{sinh}^{-1}\\,\\left(x\\right)")
testLatex("acos(x)", "\\operatorname{cos}^{-1}\\,\\left(x\\right)")
testLatex("acosh(x)", "\\operatorname{cosh}^{-1}\\,\\left(x\\right)")
testLatex("atan(x)", "\\operatorname{tan}^{-1}\\,\\left(x\\right)")
testLatex("atanh(x)", "\\operatorname{tanh}^{-1}\\,\\left(x\\right)")
}
// @Test
// fun unaryPlus() {
// testLatex("+1", "+1")
// testLatex("+1", "++1")
// }
}

View File

@ -0,0 +1,68 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.rendering.TestUtils.testLatex
import space.kscience.kmath.operations.GroupOperations
import kotlin.test.Test
internal class TestLatex {
@Test
fun number() = testLatex("42", "42")
@Test
fun symbol() = testLatex("x", "x")
@Test
fun operatorName() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)")
@Test
fun specialSymbol() {
testLatex(MST.Numeric(Double.POSITIVE_INFINITY), "\\infty")
testLatex("pi", "\\pi")
}
@Test
fun operand() {
testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)")
testLatex("1+1", "1+1")
}
@Test
fun unaryOperator() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)")
@Test
fun unaryPlus() = testLatex(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "+1")
@Test
fun unaryMinus() = testLatex("-x", "-x")
@Test
fun radical() = testLatex("sqrt(x)", "\\sqrt{x}")
@Test
fun superscript() = testLatex("x^y", "x^{y}")
@Test
fun subscript() = testLatex(SubscriptSyntax("", SymbolSyntax("x"), NumberSyntax("123")), "x_{123}")
@Test
fun binaryOperator() = testLatex("f(x, y)", "\\operatorname{f}\\left(x,y\\right)")
@Test
fun binaryPlus() = testLatex("x+x", "x+x")
@Test
fun binaryMinus() = testLatex("x-x", "x-x")
@Test
fun fraction() = testLatex("x/x", "\\frac{x}{x}")
@Test
fun radicalWithIndex() = testLatex(RadicalWithIndexSyntax("", SymbolSyntax("x"), SymbolSyntax("y")), "\\sqrt[x]{y}")
@Test
fun multiplication() {
testLatex("x*1", "x\\times1")
testLatex("1*x", "1\\,x")
}
}

View File

@ -0,0 +1,87 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.rendering.TestUtils.testMathML
import space.kscience.kmath.operations.GroupOperations
import kotlin.test.Test
internal class TestMathML {
@Test
fun number() = testMathML("42", "<mn>42</mn>")
@Test
fun symbol() = testMathML("x", "<mi>x</mi>")
@Test
fun operatorName() = testMathML(
"sin(1)",
"<mo>sin</mo><mspace width=\"0.167em\"></mspace><mfenced open=\"(\" close=\")\" separators=\"\"><mn>1</mn></mfenced>",
)
@Test
fun specialSymbol() {
testMathML(MST.Numeric(Double.POSITIVE_INFINITY), "<mo>&infin;</mo>")
testMathML("pi", "<mo>&pi;</mo>")
}
@Test
fun operand() {
testMathML(
"sin(1)",
"<mo>sin</mo><mspace width=\"0.167em\"></mspace><mfenced open=\"(\" close=\")\" separators=\"\"><mn>1</mn></mfenced>",
)
testMathML("1+1", "<mn>1</mn><mo>+</mo><mn>1</mn>")
}
@Test
fun unaryOperator() = testMathML(
"sin(1)",
"<mo>sin</mo><mspace width=\"0.167em\"></mspace><mfenced open=\"(\" close=\")\" separators=\"\"><mn>1</mn></mfenced>",
)
@Test
fun unaryPlus() =
testMathML(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "<mo>+</mo><mn>1</mn>")
@Test
fun unaryMinus() = testMathML("-x", "<mo>-</mo><mi>x</mi>")
@Test
fun radical() = testMathML("sqrt(x)", "<msqrt><mi>x</mi></msqrt>")
@Test
fun superscript() = testMathML("x^y", "<msup><mrow><mi>x</mi></mrow><mrow><mi>y</mi></mrow></msup>")
@Test
fun subscript() = testMathML(
SubscriptSyntax("", SymbolSyntax("x"), NumberSyntax("123")),
"<msub><mrow><mi>x</mi></mrow><mrow><mn>123</mn></mrow></msub>",
)
@Test
fun binaryOperator() = testMathML(
"f(x, y)",
"<mo>f</mo><mfenced open=\"(\" close=\")\" separators=\"\"><mi>x</mi><mo>,</mo><mi>y</mi></mfenced>",
)
@Test
fun binaryPlus() = testMathML("x+x", "<mi>x</mi><mo>+</mo><mi>x</mi>")
@Test
fun binaryMinus() = testMathML("x-x", "<mi>x</mi><mo>-</mo><mi>x</mi>")
@Test
fun fraction() = testMathML("x/x", "<mfrac><mrow><mi>x</mi></mrow><mrow><mi>x</mi></mrow></mfrac>")
@Test
fun radicalWithIndex() =
testMathML(RadicalWithIndexSyntax("", SymbolSyntax("x"), SymbolSyntax("y")),
"<mroot><mrow><mi>y</mi></mrow><mrow><mi>x</mi></mrow></mroot>")
@Test
fun multiplication() {
testMathML("x*1", "<mi>x</mi><mo>&times;</mo><mn>1</mn>")
testMathML("1*x", "<mn>1</mn><mspace width=\"0.167em\"></mspace><mi>x</mi>")
}
}

View File

@ -0,0 +1,28 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.rendering.TestUtils.testLatex
import kotlin.test.Test
internal class TestStages {
@Test
fun betterMultiplication() {
testLatex("a*1", "a\\times1")
testLatex("1*(2/3)", "1\\times\\left(\\frac{2}{3}\\right)")
testLatex("1*1", "1\\times1")
testLatex("2e10", "2\\times10^{10}")
testLatex("2*x", "2\\,x")
testLatex("2*(x+1)", "2\\,\\left(x+1\\right)")
testLatex("x*y", "x\\,y")
}
@Test
fun parentheses() {
testLatex("(x+1)", "x+1")
testLatex("x*x*x", "x\\,x\\,x")
testLatex("(x+x)*x", "\\left(x+x\\right)\\,x")
testLatex("x+x*x", "x+x\\,x")
testLatex("x+x^x*x+x", "x+x^{x}\\,x+x")
testLatex("(x+x)^x+x*x", "\\left(x+x\\right)^{x}+x\\,x")
testLatex("x^(x+x)", "x^{x+x}")
}
}

View File

@ -0,0 +1,41 @@
package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.parseMath
import kotlin.test.assertEquals
internal object TestUtils {
private fun mathSyntax(mst: MST) = FeaturedMathRendererWithPostProcess.Default.render(mst)
private fun latex(mst: MST) = LatexSyntaxRenderer.renderWithStringBuilder(mathSyntax(mst))
private fun mathML(mst: MST) = MathMLSyntaxRenderer.renderWithStringBuilder(mathSyntax(mst))
internal fun testLatex(mst: MST, expectedLatex: String) = assertEquals(
expected = expectedLatex,
actual = latex(mst),
)
internal fun testLatex(expression: String, expectedLatex: String) = assertEquals(
expected = expectedLatex,
actual = latex(expression.parseMath()),
)
internal fun testLatex(expression: MathSyntax, expectedLatex: String) = assertEquals(
expected = expectedLatex,
actual = LatexSyntaxRenderer.renderWithStringBuilder(expression),
)
internal fun testMathML(mst: MST, expectedMathML: String) = assertEquals(
expected = "<math xmlns=\"http://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
actual = mathML(mst),
)
internal fun testMathML(expression: String, expectedMathML: String) = assertEquals(
expected = "<math xmlns=\"http://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
actual = mathML(expression.parseMath()),
)
internal fun testMathML(expression: MathSyntax, expectedMathML: String) = assertEquals(
expected = "<math xmlns=\"http://www.w3.org/1998/Math/MathML\"><mrow>$expectedMathML</mrow></math>",
actual = MathMLSyntaxRenderer.renderWithStringBuilder(expression),
)
}

View File

@ -2,7 +2,6 @@ package space.kscience.kmath.commons.expressions
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.StringSymbol
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.ExtendedField
@ -25,7 +24,7 @@ public class DerivativeStructureField(
public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) }
override fun number(value: Number): DerivativeStructure = const(value.toDouble())
public override fun number(value: Number): DerivativeStructure = const(value.toDouble())
/**
* A class that implements both [DerivativeStructure] and a [Symbol]
@ -36,10 +35,10 @@ public class DerivativeStructureField(
symbol: Symbol,
value: Double,
) : DerivativeStructure(size, order, index, value), Symbol {
override val identity: String = symbol.identity
override fun toString(): String = identity
override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
override fun hashCode(): Int = identity.hashCode()
public override val identity: String = symbol.identity
public override fun toString(): String = identity
public override fun equals(other: Any?): Boolean = this.identity == (other as? Symbol)?.identity
public override fun hashCode(): Int = identity.hashCode()
}
/**
@ -49,13 +48,13 @@ public class DerivativeStructureField(
key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value)
}.toMap()
override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value)
public override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value)
public override fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
public override fun bindSymbolOrNull(value: String): DerivativeStructureSymbol? = variables[value]
public override fun bindSymbol(value: String): DerivativeStructureSymbol = variables.getValue(value)
public fun bind(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
override fun bindSymbol(value: String): DerivativeStructureSymbol = bind(StringSymbol(value))
public fun bindSymbolOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
public fun bindSymbol(symbol: Symbol): DerivativeStructureSymbol = variables.getValue(symbol.identity)
public fun DerivativeStructure.derivative(symbols: List<Symbol>): Double {
require(symbols.size <= order) { "The order of derivative ${symbols.size} exceeds computed order $order" }
@ -65,7 +64,7 @@ public class DerivativeStructureField(
public fun DerivativeStructure.derivative(vararg symbols: Symbol): Double = derivative(symbols.toList())
override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate()
public override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate()
public override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
@ -108,7 +107,6 @@ public class DerivativeStructureField(
}
}
/**
* A constructs that creates a derivative structure with required order on-demand
*/

View File

@ -27,7 +27,7 @@ internal class AutoDiffTest {
@Test
fun derivativeStructureFieldTest() {
diff(2, x to 1.0, y to 1.0) {
val x = bind(x)//by binding()
val x = bindSymbol(x)//by binding()
val y = bindSymbol("y")
val z = x * (-sin(x * y) + y) + 2.0
println(z.derivative(x))

View File

@ -2,10 +2,10 @@ package space.kscience.kmath.commons.optimization
import kotlinx.coroutines.runBlocking
import space.kscience.kmath.commons.expressions.DerivativeStructureExpression
import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.optimization.FunctionOptimization
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.distributions.NormalDistribution
import kotlin.math.pow
import kotlin.test.Test
@ -14,7 +14,8 @@ internal class OptimizeTest {
val y by symbol
val normal = DerivativeStructureExpression {
exp(-bind(x).pow(2) / 2) + exp(-bind(y).pow(2) / 2)
exp(-bindSymbol(x).pow(2) / 2) + exp(-bindSymbol(y)
.pow(2) / 2)
}
@Test
@ -58,7 +59,7 @@ internal class OptimizeTest {
val chi2 = FunctionOptimization.chiSquared(x, y, yErr) { x1 ->
val cWithDefault = bindSymbolOrNull(c) ?: one
bind(a) * x1.pow(2) + bind(b) * x1 + cWithDefault
bindSymbol(a) * x1.pow(2) + bindSymbol(b) * x1 + cWithDefault
}
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)

View File

@ -8,7 +8,7 @@ Complex and hypercomplex number systems in KMath.
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-dev-3`.
The Maven coordinates of this project are `space.kscience:kmath-complex:0.3.0-dev-4`.
**Gradle:**
```gradle
@ -19,7 +19,7 @@ repositories {
}
dependencies {
implementation 'space.kscience:kmath-complex:0.3.0-dev-3'
implementation 'space.kscience:kmath-complex:0.3.0-dev-4'
}
```
**Gradle Kotlin DSL:**
@ -31,6 +31,6 @@ repositories {
}
dependencies {
implementation("space.kscience:kmath-complex:0.3.0-dev-3")
implementation("space.kscience:kmath-complex:0.3.0-dev-4")
}
```

View File

@ -121,8 +121,8 @@ public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex>, Num
/**
* Adds complex number to real one.
*
* @receiver the addend.
* @param c the augend.
* @receiver the augend.
* @param c the addend.
* @return the sum.
*/
public operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c)
@ -139,8 +139,8 @@ public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex>, Num
/**
* Adds real number to complex one.
*
* @receiver the addend.
* @param d the augend.
* @receiver the augend.
* @param d the addend.
* @return the sum.
*/
public operator fun Complex.plus(d: Double): Complex = d + this
@ -165,8 +165,7 @@ public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex>, Num
public override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg)
public override fun bindSymbol(value: String): Complex =
if (value == "i") i else super<ExtendedField>.bindSymbol(value)
public override fun bindSymbolOrNull(value: String): Complex? = if (value == "i") i else null
}
/**

View File

@ -22,10 +22,10 @@ public class ComplexFieldND(
NumbersAddOperations<StructureND<Complex>>,
ExtendedField<StructureND<Complex>> {
override val zero: BufferND<Complex> by lazy { produce { zero } }
override val one: BufferND<Complex> by lazy { produce { one } }
public override val zero: BufferND<Complex> by lazy { produce { zero } }
public override val one: BufferND<Complex> by lazy { produce { one } }
override fun number(value: Number): BufferND<Complex> {
public override fun number(value: Number): BufferND<Complex> {
val d = value.toComplex() // minimize conversions
return produce { d }
}
@ -76,25 +76,25 @@ public class ComplexFieldND(
// return BufferedNDFieldElement(this, buffer)
// }
override fun power(arg: StructureND<Complex>, pow: Number): BufferND<Complex> = arg.map { power(it, pow) }
public override fun power(arg: StructureND<Complex>, pow: Number): BufferND<Complex> = arg.map { power(it, pow) }
override fun exp(arg: StructureND<Complex>): BufferND<Complex> = arg.map { exp(it) }
public override fun exp(arg: StructureND<Complex>): BufferND<Complex> = arg.map { exp(it) }
override fun ln(arg: StructureND<Complex>): BufferND<Complex> = arg.map { ln(it) }
public override fun ln(arg: StructureND<Complex>): BufferND<Complex> = arg.map { ln(it) }
override fun sin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sin(it) }
override fun cos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cos(it) }
override fun tan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tan(it) }
override fun asin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asin(it) }
override fun acos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acos(it) }
override fun atan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atan(it) }
public override fun sin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sin(it) }
public override fun cos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cos(it) }
public override fun tan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tan(it) }
public override fun asin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asin(it) }
public override fun acos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acos(it) }
public override fun atan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atan(it) }
override fun sinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sinh(it) }
override fun cosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cosh(it) }
override fun tanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tanh(it) }
override fun asinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asinh(it) }
override fun acosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acosh(it) }
override fun atanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atanh(it) }
public override fun sinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sinh(it) }
public override fun cosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cosh(it) }
public override fun tanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tanh(it) }
public override fun asinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asinh(it) }
public override fun acosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acosh(it) }
public override fun atanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atanh(it) }
}

View File

@ -165,11 +165,11 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
public override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z)
public override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg)
public override fun bindSymbol(value: String): Quaternion = when (value) {
public override fun bindSymbolOrNull(value: String): Quaternion? = when (value) {
"i" -> i
"j" -> j
"k" -> k
else -> super<Field>.bindSymbol(value)
else -> null
}
override fun number(value: Number): Quaternion = value.toQuaternion()

View File

@ -1,9 +1,9 @@
package space.kscience.kmath.complex
import space.kscience.kmath.expressions.FunctionalExpressionField
import space.kscience.kmath.expressions.bindSymbol
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.bindSymbol
import kotlin.test.Test
import kotlin.test.assertEquals

View File

@ -15,7 +15,7 @@ performance calculations to code generation.
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-3`.
The Maven coordinates of this project are `space.kscience:kmath-core:0.3.0-dev-4`.
**Gradle:**
```gradle
@ -26,7 +26,7 @@ repositories {
}
dependencies {
implementation 'space.kscience:kmath-core:0.3.0-dev-3'
implementation 'space.kscience:kmath-core:0.3.0-dev-4'
}
```
**Gradle Kotlin DSL:**
@ -38,6 +38,6 @@ repositories {
}
dependencies {
implementation("space.kscience:kmath-core:0.3.0-dev-3")
implementation("space.kscience:kmath-core:0.3.0-dev-4")
}
```

View File

@ -50,8 +50,6 @@ public abstract interface class space/kscience/kmath/expressions/Expression {
}
public abstract interface class space/kscience/kmath/expressions/ExpressionAlgebra : space/kscience/kmath/operations/Algebra {
public abstract fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public abstract fun bindSymbolOrNull (Lspace/kscience/kmath/misc/Symbol;)Ljava/lang/Object;
public abstract fun const (Ljava/lang/Object;)Ljava/lang/Object;
}
@ -59,6 +57,7 @@ public final class space/kscience/kmath/expressions/ExpressionAlgebra$DefaultImp
public static fun binaryOperation (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Ljava/lang/String;)Ljava/lang/Object;
public static fun unaryOperation (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object;
public static fun unaryOperationFunction (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function1;
}
@ -71,7 +70,6 @@ public final class space/kscience/kmath/expressions/ExpressionBuildersKt {
}
public final class space/kscience/kmath/expressions/ExpressionKt {
public static final fun bindSymbol (Lspace/kscience/kmath/expressions/ExpressionAlgebra;Lspace/kscience/kmath/misc/Symbol;)Ljava/lang/Object;
public static final fun binding (Lspace/kscience/kmath/expressions/ExpressionAlgebra;)Lkotlin/properties/ReadOnlyProperty;
public static final fun callByString (Lspace/kscience/kmath/expressions/Expression;[Lkotlin/Pair;)Ljava/lang/Object;
public static final fun callBySymbol (Lspace/kscience/kmath/expressions/Expression;[Lkotlin/Pair;)Ljava/lang/Object;
@ -91,8 +89,8 @@ public abstract class space/kscience/kmath/expressions/FunctionalExpressionAlgeb
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/expressions/Expression;
public synthetic fun bindSymbolOrNull (Lspace/kscience/kmath/misc/Symbol;)Ljava/lang/Object;
public fun bindSymbolOrNull (Lspace/kscience/kmath/misc/Symbol;)Lspace/kscience/kmath/expressions/Expression;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/expressions/Expression;
public synthetic fun const (Ljava/lang/Object;)Ljava/lang/Object;
public fun const (Ljava/lang/Object;)Lspace/kscience/kmath/expressions/Expression;
public final fun getAlgebra ()Lspace/kscience/kmath/operations/Algebra;
@ -123,6 +121,8 @@ public class space/kscience/kmath/expressions/FunctionalExpressionExtendedField
public synthetic fun atanh (Ljava/lang/Object;)Ljava/lang/Object;
public fun atanh (Lspace/kscience/kmath/expressions/Expression;)Lspace/kscience/kmath/expressions/Expression;
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/expressions/Expression;
public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object;
public fun cos (Lspace/kscience/kmath/expressions/Expression;)Lspace/kscience/kmath/expressions/Expression;
public synthetic fun cosh (Ljava/lang/Object;)Ljava/lang/Object;
@ -154,6 +154,8 @@ public class space/kscience/kmath/expressions/FunctionalExpressionExtendedField
public class space/kscience/kmath/expressions/FunctionalExpressionField : space/kscience/kmath/expressions/FunctionalExpressionRing, space/kscience/kmath/operations/Field, space/kscience/kmath/operations/ScaleOperations {
public fun <init> (Lspace/kscience/kmath/operations/Field;)V
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/expressions/Expression;
public synthetic fun div (Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object;
public synthetic fun div (Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public final fun div (Ljava/lang/Object;Lspace/kscience/kmath/expressions/Expression;)Lspace/kscience/kmath/expressions/Expression;
@ -237,6 +239,8 @@ public final class space/kscience/kmath/expressions/SimpleAutoDiffExtendedField
public fun atan (Lspace/kscience/kmath/expressions/AutoDiffValue;)Lspace/kscience/kmath/expressions/AutoDiffValue;
public synthetic fun atanh (Ljava/lang/Object;)Ljava/lang/Object;
public fun atanh (Lspace/kscience/kmath/expressions/AutoDiffValue;)Lspace/kscience/kmath/expressions/AutoDiffValue;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/expressions/AutoDiffValue;
public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object;
public fun cos (Lspace/kscience/kmath/expressions/AutoDiffValue;)Lspace/kscience/kmath/expressions/AutoDiffValue;
public synthetic fun cosh (Ljava/lang/Object;)Ljava/lang/Object;
@ -278,8 +282,8 @@ public class space/kscience/kmath/expressions/SimpleAutoDiffField : space/kscien
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/expressions/AutoDiffValue;
public synthetic fun bindSymbolOrNull (Lspace/kscience/kmath/misc/Symbol;)Ljava/lang/Object;
public fun bindSymbolOrNull (Lspace/kscience/kmath/misc/Symbol;)Lspace/kscience/kmath/expressions/AutoDiffValue;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/expressions/AutoDiffValue;
public synthetic fun const (Ljava/lang/Object;)Ljava/lang/Object;
public fun const (Ljava/lang/Object;)Lspace/kscience/kmath/expressions/AutoDiffValue;
public final fun const (Lkotlin/jvm/functions/Function1;)Lspace/kscience/kmath/expressions/AutoDiffValue;
@ -710,6 +714,8 @@ public final class space/kscience/kmath/nd/BufferND : space/kscience/kmath/nd/St
public class space/kscience/kmath/nd/BufferedFieldND : space/kscience/kmath/nd/BufferedRingND, space/kscience/kmath/nd/FieldND {
public fun <init> ([ILspace/kscience/kmath/operations/Field;Lkotlin/jvm/functions/Function2;)V
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND;
public synthetic fun div (Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object;
public synthetic fun div (Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public fun div (Ljava/lang/Object;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND;
@ -743,6 +749,8 @@ public class space/kscience/kmath/nd/BufferedGroupND : space/kscience/kmath/nd/B
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND;
public fun combine (Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;Lkotlin/jvm/functions/Function3;)Lspace/kscience/kmath/nd/BufferND;
public synthetic fun combine (Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;Lkotlin/jvm/functions/Function3;)Lspace/kscience/kmath/nd/StructureND;
public fun getBuffer (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/structures/Buffer;
@ -821,6 +829,8 @@ public final class space/kscience/kmath/nd/DoubleFieldND : space/kscience/kmath/
public fun atan (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/BufferND;
public synthetic fun atanh (Ljava/lang/Object;)Ljava/lang/Object;
public fun atanh (Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/BufferND;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND;
public fun combine (Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;Lkotlin/jvm/functions/Function3;)Lspace/kscience/kmath/nd/BufferND;
public synthetic fun combine (Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;Lkotlin/jvm/functions/Function3;)Lspace/kscience/kmath/nd/StructureND;
public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object;
@ -890,6 +900,7 @@ public final class space/kscience/kmath/nd/FieldND$DefaultImpls {
public static fun binaryOperation (Lspace/kscience/kmath/nd/FieldND;Ljava/lang/String;Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND;
public static fun binaryOperationFunction (Lspace/kscience/kmath/nd/FieldND;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/nd/FieldND;Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/nd/FieldND;Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND;
public static fun div (Lspace/kscience/kmath/nd/FieldND;Ljava/lang/Object;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND;
public static fun div (Lspace/kscience/kmath/nd/FieldND;Lspace/kscience/kmath/nd/StructureND;Ljava/lang/Number;)Lspace/kscience/kmath/nd/StructureND;
public static fun div (Lspace/kscience/kmath/nd/FieldND;Lspace/kscience/kmath/nd/StructureND;Ljava/lang/Object;)Lspace/kscience/kmath/nd/StructureND;
@ -935,6 +946,7 @@ public final class space/kscience/kmath/nd/GroupND$DefaultImpls {
public static fun binaryOperation (Lspace/kscience/kmath/nd/GroupND;Ljava/lang/String;Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND;
public static fun binaryOperationFunction (Lspace/kscience/kmath/nd/GroupND;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/nd/GroupND;Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/nd/GroupND;Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND;
public static fun invoke (Lspace/kscience/kmath/nd/GroupND;Lkotlin/jvm/functions/Function1;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND;
public static fun minus (Lspace/kscience/kmath/nd/GroupND;Ljava/lang/Object;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND;
public static fun minus (Lspace/kscience/kmath/nd/GroupND;Lspace/kscience/kmath/nd/StructureND;Ljava/lang/Object;)Lspace/kscience/kmath/nd/StructureND;
@ -996,6 +1008,7 @@ public final class space/kscience/kmath/nd/RingND$DefaultImpls {
public static fun binaryOperation (Lspace/kscience/kmath/nd/RingND;Ljava/lang/String;Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND;
public static fun binaryOperationFunction (Lspace/kscience/kmath/nd/RingND;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/nd/RingND;Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/nd/RingND;Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND;
public static fun invoke (Lspace/kscience/kmath/nd/RingND;Lkotlin/jvm/functions/Function1;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND;
public static fun minus (Lspace/kscience/kmath/nd/RingND;Ljava/lang/Object;Lspace/kscience/kmath/nd/StructureND;)Lspace/kscience/kmath/nd/StructureND;
public static fun minus (Lspace/kscience/kmath/nd/RingND;Lspace/kscience/kmath/nd/StructureND;Ljava/lang/Object;)Lspace/kscience/kmath/nd/StructureND;
@ -1020,6 +1033,8 @@ public final class space/kscience/kmath/nd/ShapeMismatchException : java/lang/Ru
public final class space/kscience/kmath/nd/ShortRingND : space/kscience/kmath/nd/BufferedRingND, space/kscience/kmath/operations/NumbersAddOperations {
public fun <init> ([I)V
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/nd/StructureND;
public synthetic fun getOne ()Ljava/lang/Object;
public fun getOne ()Lspace/kscience/kmath/nd/BufferND;
public synthetic fun getZero ()Ljava/lang/Object;
@ -1146,6 +1161,7 @@ public abstract interface class space/kscience/kmath/operations/Algebra {
public abstract fun binaryOperation (Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public abstract fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public abstract fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public abstract fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public abstract fun unaryOperation (Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object;
public abstract fun unaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function1;
}
@ -1154,6 +1170,7 @@ public final class space/kscience/kmath/operations/Algebra$DefaultImpls {
public static fun binaryOperation (Lspace/kscience/kmath/operations/Algebra;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/Algebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/Algebra;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/Algebra;Ljava/lang/String;)Ljava/lang/Object;
public static fun unaryOperation (Lspace/kscience/kmath/operations/Algebra;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object;
public static fun unaryOperationFunction (Lspace/kscience/kmath/operations/Algebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function1;
}
@ -1184,6 +1201,7 @@ public final class space/kscience/kmath/operations/AlgebraExtensionsKt {
public final class space/kscience/kmath/operations/AlgebraKt {
public static final fun bindSymbol (Lspace/kscience/kmath/operations/Algebra;Lspace/kscience/kmath/misc/Symbol;)Ljava/lang/Object;
public static final fun bindSymbolOrNull (Lspace/kscience/kmath/operations/Algebra;Lspace/kscience/kmath/misc/Symbol;)Ljava/lang/Object;
public static final fun invoke (Lspace/kscience/kmath/operations/Algebra;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
}
@ -1229,6 +1247,8 @@ public final class space/kscience/kmath/operations/BigIntField : space/kscience/
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/operations/BigInt;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/operations/BigInt;
public synthetic fun div (Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object;
public synthetic fun div (Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public fun div (Lspace/kscience/kmath/operations/BigInt;Ljava/lang/Number;)Lspace/kscience/kmath/operations/BigInt;
@ -1302,6 +1322,8 @@ public final class space/kscience/kmath/operations/ByteRing : space/kscience/kma
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public fun bindSymbol (Ljava/lang/String;)Ljava/lang/Byte;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Byte;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun getOne ()Ljava/lang/Byte;
public synthetic fun getOne ()Ljava/lang/Object;
public fun getZero ()Ljava/lang/Byte;
@ -1354,6 +1376,8 @@ public final class space/kscience/kmath/operations/DoubleField : space/kscience/
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public fun bindSymbol (Ljava/lang/String;)Ljava/lang/Double;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Double;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun cos (D)Ljava/lang/Double;
public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object;
public fun cosh (D)Ljava/lang/Double;
@ -1454,6 +1478,7 @@ public final class space/kscience/kmath/operations/ExponentialOperations$Default
public static fun binaryOperation (Lspace/kscience/kmath/operations/ExponentialOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/ExponentialOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/ExponentialOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/ExponentialOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun unaryOperation (Lspace/kscience/kmath/operations/ExponentialOperations;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object;
public static fun unaryOperationFunction (Lspace/kscience/kmath/operations/ExponentialOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function1;
}
@ -1462,6 +1487,7 @@ public abstract interface class space/kscience/kmath/operations/ExtendedField :
public abstract fun acosh (Ljava/lang/Object;)Ljava/lang/Object;
public abstract fun asinh (Ljava/lang/Object;)Ljava/lang/Object;
public abstract fun atanh (Ljava/lang/Object;)Ljava/lang/Object;
public abstract fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public abstract fun cosh (Ljava/lang/Object;)Ljava/lang/Object;
public abstract fun rightSideNumberOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public abstract fun sinh (Ljava/lang/Object;)Ljava/lang/Object;
@ -1475,6 +1501,7 @@ public final class space/kscience/kmath/operations/ExtendedField$DefaultImpls {
public static fun binaryOperation (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/String;)Ljava/lang/Object;
public static fun cosh (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/Object;)Ljava/lang/Object;
public static fun div (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object;
public static fun div (Lspace/kscience/kmath/operations/ExtendedField;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
@ -1508,6 +1535,7 @@ public final class space/kscience/kmath/operations/ExtendedFieldOperations$Defau
public static fun binaryOperation (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun div (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun minus (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun plus (Lspace/kscience/kmath/operations/ExtendedFieldOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
@ -1529,6 +1557,7 @@ public final class space/kscience/kmath/operations/Field$DefaultImpls {
public static fun binaryOperation (Lspace/kscience/kmath/operations/Field;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/Field;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/Field;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/Field;Ljava/lang/String;)Ljava/lang/Object;
public static fun div (Lspace/kscience/kmath/operations/Field;Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object;
public static fun div (Lspace/kscience/kmath/operations/Field;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun leftSideNumberOperation (Lspace/kscience/kmath/operations/Field;Ljava/lang/String;Ljava/lang/Number;Ljava/lang/Object;)Ljava/lang/Object;
@ -1562,6 +1591,7 @@ public final class space/kscience/kmath/operations/FieldOperations$DefaultImpls
public static fun binaryOperation (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun div (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun minus (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun plus (Lspace/kscience/kmath/operations/FieldOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
@ -1592,6 +1622,8 @@ public final class space/kscience/kmath/operations/FloatField : space/kscience/k
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public fun bindSymbol (Ljava/lang/String;)Ljava/lang/Float;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Float;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun cos (F)Ljava/lang/Float;
public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object;
public fun cosh (F)Ljava/lang/Float;
@ -1665,6 +1697,7 @@ public final class space/kscience/kmath/operations/Group$DefaultImpls {
public static fun binaryOperation (Lspace/kscience/kmath/operations/Group;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/Group;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/Group;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/Group;Ljava/lang/String;)Ljava/lang/Object;
public static fun minus (Lspace/kscience/kmath/operations/Group;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun plus (Lspace/kscience/kmath/operations/Group;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun unaryOperation (Lspace/kscience/kmath/operations/Group;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object;
@ -1694,6 +1727,7 @@ public final class space/kscience/kmath/operations/GroupOperations$DefaultImpls
public static fun binaryOperation (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun minus (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun plus (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun unaryOperation (Lspace/kscience/kmath/operations/GroupOperations;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object;
@ -1710,6 +1744,8 @@ public final class space/kscience/kmath/operations/IntRing : space/kscience/kmat
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public fun bindSymbol (Ljava/lang/String;)Ljava/lang/Integer;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Integer;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun getOne ()Ljava/lang/Integer;
public synthetic fun getOne ()Ljava/lang/Object;
public fun getZero ()Ljava/lang/Integer;
@ -1760,6 +1796,8 @@ public abstract class space/kscience/kmath/operations/JBigDecimalFieldBase : spa
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbol (Ljava/lang/String;)Ljava/math/BigDecimal;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/math/BigDecimal;
public synthetic fun div (Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object;
public synthetic fun div (Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public fun div (Ljava/math/BigDecimal;Ljava/lang/Number;)Ljava/math/BigDecimal;
@ -1816,6 +1854,8 @@ public final class space/kscience/kmath/operations/JBigIntegerField : space/ksci
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbol (Ljava/lang/String;)Ljava/math/BigInteger;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/math/BigInteger;
public synthetic fun getOne ()Ljava/lang/Object;
public fun getOne ()Ljava/math/BigInteger;
public synthetic fun getZero ()Ljava/lang/Object;
@ -1857,6 +1897,8 @@ public final class space/kscience/kmath/operations/LongRing : space/kscience/kma
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public fun bindSymbol (Ljava/lang/String;)Ljava/lang/Long;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Long;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun getOne ()Ljava/lang/Long;
public synthetic fun getOne ()Ljava/lang/Object;
public fun getZero ()Ljava/lang/Long;
@ -1896,6 +1938,7 @@ public final class space/kscience/kmath/operations/NumbersAddOperations$DefaultI
public static fun binaryOperation (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun leftSideNumberOperation (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/String;Ljava/lang/Number;Ljava/lang/Object;)Ljava/lang/Object;
public static fun leftSideNumberOperationFunction (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun minus (Lspace/kscience/kmath/operations/NumbersAddOperations;Ljava/lang/Number;Ljava/lang/Object;)Ljava/lang/Object;
@ -1912,6 +1955,7 @@ public final class space/kscience/kmath/operations/NumbersAddOperations$DefaultI
}
public abstract interface class space/kscience/kmath/operations/NumericAlgebra : space/kscience/kmath/operations/Algebra {
public abstract fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public abstract fun leftSideNumberOperation (Ljava/lang/String;Ljava/lang/Number;Ljava/lang/Object;)Ljava/lang/Object;
public abstract fun leftSideNumberOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public abstract fun number (Ljava/lang/Number;)Ljava/lang/Object;
@ -1923,6 +1967,7 @@ public final class space/kscience/kmath/operations/NumericAlgebra$DefaultImpls {
public static fun binaryOperation (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;)Ljava/lang/Object;
public static fun leftSideNumberOperation (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;Ljava/lang/Number;Ljava/lang/Object;)Ljava/lang/Object;
public static fun leftSideNumberOperationFunction (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun rightSideNumberOperation (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object;
@ -1931,6 +1976,11 @@ public final class space/kscience/kmath/operations/NumericAlgebra$DefaultImpls {
public static fun unaryOperationFunction (Lspace/kscience/kmath/operations/NumericAlgebra;Ljava/lang/String;)Lkotlin/jvm/functions/Function1;
}
public final class space/kscience/kmath/operations/NumericAlgebraKt {
public static final fun getE (Lspace/kscience/kmath/operations/NumericAlgebra;)Ljava/lang/Object;
public static final fun getPi (Lspace/kscience/kmath/operations/NumericAlgebra;)Ljava/lang/Object;
}
public final class space/kscience/kmath/operations/OptionalOperationsKt {
}
@ -1952,6 +2002,7 @@ public final class space/kscience/kmath/operations/PowerOperations$DefaultImpls
public static fun binaryOperation (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun pow (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object;
public static fun sqrt (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/Object;)Ljava/lang/Object;
public static fun unaryOperation (Lspace/kscience/kmath/operations/PowerOperations;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object;
@ -1966,6 +2017,7 @@ public final class space/kscience/kmath/operations/Ring$DefaultImpls {
public static fun binaryOperation (Lspace/kscience/kmath/operations/Ring;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/Ring;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/Ring;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/Ring;Ljava/lang/String;)Ljava/lang/Object;
public static fun minus (Lspace/kscience/kmath/operations/Ring;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun plus (Lspace/kscience/kmath/operations/Ring;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun times (Lspace/kscience/kmath/operations/Ring;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
@ -1990,6 +2042,7 @@ public final class space/kscience/kmath/operations/RingOperations$DefaultImpls {
public static fun binaryOperation (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun minus (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun plus (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun times (Lspace/kscience/kmath/operations/RingOperations;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
@ -2009,6 +2062,7 @@ public final class space/kscience/kmath/operations/ScaleOperations$DefaultImpls
public static fun binaryOperation (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun div (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object;
public static fun times (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/Number;Ljava/lang/Object;)Ljava/lang/Object;
public static fun times (Lspace/kscience/kmath/operations/ScaleOperations;Ljava/lang/Object;Ljava/lang/Number;)Ljava/lang/Object;
@ -2025,6 +2079,8 @@ public final class space/kscience/kmath/operations/ShortRing : space/kscience/km
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbol (Ljava/lang/String;)Ljava/lang/Short;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Short;
public synthetic fun getOne ()Ljava/lang/Object;
public fun getOne ()Ljava/lang/Short;
public synthetic fun getZero ()Ljava/lang/Object;
@ -2085,6 +2141,7 @@ public final class space/kscience/kmath/operations/TrigonometricOperations$Defau
public static fun binaryOperation (Lspace/kscience/kmath/operations/TrigonometricOperations;Ljava/lang/String;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;
public static fun binaryOperationFunction (Lspace/kscience/kmath/operations/TrigonometricOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public static fun bindSymbol (Lspace/kscience/kmath/operations/TrigonometricOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun bindSymbolOrNull (Lspace/kscience/kmath/operations/TrigonometricOperations;Ljava/lang/String;)Ljava/lang/Object;
public static fun unaryOperation (Lspace/kscience/kmath/operations/TrigonometricOperations;Ljava/lang/String;Ljava/lang/Object;)Ljava/lang/Object;
public static fun unaryOperationFunction (Lspace/kscience/kmath/operations/TrigonometricOperations;Ljava/lang/String;)Lkotlin/jvm/functions/Function1;
}
@ -2175,6 +2232,8 @@ public final class space/kscience/kmath/structures/DoubleBufferField : space/ksc
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/structures/Buffer;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/structures/Buffer;
public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object;
public fun cos-Udx-57Q (Lspace/kscience/kmath/structures/Buffer;)[D
public synthetic fun cosh (Ljava/lang/Object;)Ljava/lang/Object;
@ -2260,6 +2319,8 @@ public final class space/kscience/kmath/structures/DoubleBufferFieldOperations :
public fun binaryOperationFunction (Ljava/lang/String;)Lkotlin/jvm/functions/Function2;
public synthetic fun bindSymbol (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbol (Ljava/lang/String;)Lspace/kscience/kmath/structures/Buffer;
public synthetic fun bindSymbolOrNull (Ljava/lang/String;)Ljava/lang/Object;
public fun bindSymbolOrNull (Ljava/lang/String;)Lspace/kscience/kmath/structures/Buffer;
public synthetic fun cos (Ljava/lang/Object;)Ljava/lang/Object;
public fun cos-Udx-57Q (Lspace/kscience/kmath/structures/Buffer;)[D
public synthetic fun cosh (Ljava/lang/Object;)Ljava/lang/Object;

View File

@ -55,15 +55,6 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
* @param E type of the actual expression state
*/
public interface ExpressionAlgebra<in T, E> : Algebra<E> {
/**
* Bind a given [Symbol] to this context variable and produce context-specific object. Return null if symbol could not be bound in current context.
*/
public fun bindSymbolOrNull(symbol: Symbol): E?
/**
* Bind a string to a context using [StringSymbol]
*/
override fun bindSymbol(value: String): E = bindSymbol(StringSymbol(value))
/**
* A constant expression which does not depend on arguments
@ -71,15 +62,9 @@ public interface ExpressionAlgebra<in T, E> : Algebra<E> {
public fun const(value: T): E
}
/**
* Bind a given [Symbol] to this context variable and produce context-specific object.
*/
public fun <T, E> ExpressionAlgebra<T, E>.bindSymbol(symbol: Symbol): E =
bindSymbolOrNull(symbol) ?: error("Symbol $symbol could not be bound to $this")
/**
* Bind a symbol by name inside the [ExpressionAlgebra]
*/
public fun <T, E> ExpressionAlgebra<T, E>.binding(): ReadOnlyProperty<Any?, E> = ReadOnlyProperty { _, property ->
bindSymbol(StringSymbol(property.name)) ?: error("A variable with name ${property.name} does not exist")
bindSymbol(property.name) ?: error("A variable with name ${property.name} does not exist")
}

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.expressions
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.misc.StringSymbol
import space.kscience.kmath.operations.*
/**
@ -19,8 +19,10 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
/**
* Builds an Expression to access a variable.
*/
public override fun bindSymbolOrNull(symbol: Symbol): Expression<T>? = Expression { arguments ->
arguments[symbol] ?: error("Argument not found: $symbol")
public override fun bindSymbolOrNull(value: String): Expression<T>? = Expression { arguments ->
algebra.bindSymbolOrNull(value)
?: arguments[StringSymbol(value)]
?: error("Symbol '$value' is not supported in $this")
}
/**
@ -49,7 +51,7 @@ public open class FunctionalExpressionGroup<T, A : Group<T>>(
) : FunctionalExpressionAlgebra<T, A>(algebra), Group<Expression<T>> {
public override val zero: Expression<T> get() = const(algebra.zero)
override fun Expression<T>.unaryMinus(): Expression<T> =
public override fun Expression<T>.unaryMinus(): Expression<T> =
unaryOperation(GroupOperations.MINUS_OPERATION, this)
/**
@ -101,8 +103,7 @@ public open class FunctionalExpressionRing<T, A : Ring<T>>(
public open class FunctionalExpressionField<T, A : Field<T>>(
algebra: A,
) : FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>,
ScaleOperations<Expression<T>> {
) : FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>, ScaleOperations<Expression<T>> {
/**
* Builds an Expression of division an expression by another one.
*/
@ -118,16 +119,21 @@ public open class FunctionalExpressionField<T, A : Field<T>>(
public override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionRing>.binaryOperationFunction(operation)
override fun scale(a: Expression<T>, value: Double): Expression<T> = algebra {
public override fun scale(a: Expression<T>, value: Double): Expression<T> = algebra {
Expression { args -> a(args) * value }
}
public override fun bindSymbolOrNull(value: String): Expression<T>? =
super<FunctionalExpressionRing>.bindSymbolOrNull(value)
}
public open class FunctionalExpressionExtendedField<T, A : ExtendedField<T>>(
algebra: A,
) : FunctionalExpressionField<T, A>(algebra), ExtendedField<Expression<T>> {
public override fun number(value: Number): Expression<T> = const(algebra.number(value))
override fun number(value: Number): Expression<T> = const(algebra.number(value))
public override fun sqrt(arg: Expression<T>): Expression<T> =
unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg)
public override fun sin(arg: Expression<T>): Expression<T> =
unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg)
@ -158,6 +164,8 @@ public open class FunctionalExpressionExtendedField<T, A : ExtendedField<T>>(
public override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
super<FunctionalExpressionField>.binaryOperationFunction(operation)
public override fun bindSymbol(value: String): Expression<T> = super<FunctionalExpressionField>.bindSymbol(value)
}
public inline fun <T, A : Group<T>> A.expressionInSpace(block: FunctionalExpressionGroup<T, A>.() -> Expression<T>): Expression<T> =

View File

@ -85,7 +85,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
override fun hashCode(): Int = identity.hashCode()
}
public override fun bindSymbolOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
override fun bindSymbolOrNull(value: String): AutoDiffValue<T>? = bindings[value]
private fun getDerivative(variable: AutoDiffValue<T>): T =
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
@ -337,9 +337,11 @@ public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
) : ExtendedField<AutoDiffValue<T>>, ScaleOperations<AutoDiffValue<T>>,
SimpleAutoDiffField<T, F>(context, bindings) {
override fun number(value: Number): AutoDiffValue<T> = const { number(value) }
override fun bindSymbol(value: String): AutoDiffValue<T> = super<SimpleAutoDiffField>.bindSymbol(value)
override fun scale(a: AutoDiffValue<T>, value: Double): AutoDiffValue<T> = a * number(value)
public override fun number(value: Number): AutoDiffValue<T> = const { number(value) }
public override fun scale(a: AutoDiffValue<T>, value: Double): AutoDiffValue<T> = a * number(value)
// x ^ 2
public fun sqr(x: AutoDiffValue<T>): AutoDiffValue<T> =

View File

@ -120,8 +120,8 @@ public interface GroupND<T, S : Group<T>> : Group<StructureND<T>>, AlgebraND<T,
/**
* Element-wise addition.
*
* @param a the addend.
* @param b the augend.
* @param a the augend.
* @param b the addend.
* @return the sum.
*/
public override fun add(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
@ -141,8 +141,8 @@ public interface GroupND<T, S : Group<T>> : Group<StructureND<T>>, AlgebraND<T,
/**
* Adds an ND structure to an element of it.
*
* @receiver the addend.
* @param arg the augend.
* @receiver the augend.
* @param arg the addend.
* @return the sum.
*/
public operator fun StructureND<T>.plus(arg: T): StructureND<T> = this.map { value -> add(arg, value) }
@ -159,8 +159,8 @@ public interface GroupND<T, S : Group<T>> : Group<StructureND<T>>, AlgebraND<T,
/**
* Adds an element to ND structure of it.
*
* @receiver the addend.
* @param arg the augend.
* @receiver the augend.
* @param arg the addend.
* @return the sum.
*/
public operator fun T.plus(arg: StructureND<T>): StructureND<T> = arg.map { value -> add(this@plus, value) }

View File

@ -17,15 +17,15 @@ public class DoubleFieldND(
ScaleOperations<StructureND<Double>>,
ExtendedField<StructureND<Double>> {
override val zero: BufferND<Double> by lazy { produce { zero } }
override val one: BufferND<Double> by lazy { produce { one } }
public override val zero: BufferND<Double> by lazy { produce { zero } }
public override val one: BufferND<Double> by lazy { produce { one } }
override fun number(value: Number): BufferND<Double> {
public override fun number(value: Number): BufferND<Double> {
val d = value.toDouble() // minimize conversions
return produce { d }
}
override val StructureND<Double>.buffer: DoubleBuffer
public override val StructureND<Double>.buffer: DoubleBuffer
get() = when {
!shape.contentEquals(this@DoubleFieldND.shape) -> throw ShapeMismatchException(
this@DoubleFieldND.shape,
@ -36,7 +36,7 @@ public class DoubleFieldND(
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun StructureND<Double>.map(
public override inline fun StructureND<Double>.map(
transform: DoubleField.(Double) -> Double,
): BufferND<Double> {
val buffer = DoubleBuffer(strides.linearSize) { offset -> DoubleField.transform(buffer.array[offset]) }
@ -44,7 +44,7 @@ public class DoubleFieldND(
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
public override inline fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
val array = DoubleArray(strides.linearSize) { offset ->
val index = strides.index(offset)
DoubleField.initializer(index)
@ -53,7 +53,7 @@ public class DoubleFieldND(
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun StructureND<Double>.mapIndexed(
public override inline fun StructureND<Double>.mapIndexed(
transform: DoubleField.(index: IntArray, Double) -> Double,
): BufferND<Double> = BufferND(
strides,
@ -65,7 +65,7 @@ public class DoubleFieldND(
})
@Suppress("OVERRIDE_BY_INLINE")
override inline fun combine(
public override inline fun combine(
a: StructureND<Double>,
b: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double,
@ -76,27 +76,26 @@ public class DoubleFieldND(
return BufferND(strides, buffer)
}
override fun scale(a: StructureND<Double>, value: Double): StructureND<Double> = a.map { it * value }
public override fun scale(a: StructureND<Double>, value: Double): StructureND<Double> = a.map { it * value }
override fun power(arg: StructureND<Double>, pow: Number): BufferND<Double> = arg.map { power(it, pow) }
public override fun power(arg: StructureND<Double>, pow: Number): BufferND<Double> = arg.map { power(it, pow) }
override fun exp(arg: StructureND<Double>): BufferND<Double> = arg.map { exp(it) }
public override fun exp(arg: StructureND<Double>): BufferND<Double> = arg.map { exp(it) }
public override fun ln(arg: StructureND<Double>): BufferND<Double> = arg.map { ln(it) }
override fun ln(arg: StructureND<Double>): BufferND<Double> = arg.map { ln(it) }
public override fun sin(arg: StructureND<Double>): BufferND<Double> = arg.map { sin(it) }
public override fun cos(arg: StructureND<Double>): BufferND<Double> = arg.map { cos(it) }
public override fun tan(arg: StructureND<Double>): BufferND<Double> = arg.map { tan(it) }
public override fun asin(arg: StructureND<Double>): BufferND<Double> = arg.map { asin(it) }
public override fun acos(arg: StructureND<Double>): BufferND<Double> = arg.map { acos(it) }
public override fun atan(arg: StructureND<Double>): BufferND<Double> = arg.map { atan(it) }
override fun sin(arg: StructureND<Double>): BufferND<Double> = arg.map { sin(it) }
override fun cos(arg: StructureND<Double>): BufferND<Double> = arg.map { cos(it) }
override fun tan(arg: StructureND<Double>): BufferND<Double> = arg.map { tan(it) }
override fun asin(arg: StructureND<Double>): BufferND<Double> = arg.map { asin(it) }
override fun acos(arg: StructureND<Double>): BufferND<Double> = arg.map { acos(it) }
override fun atan(arg: StructureND<Double>): BufferND<Double> = arg.map { atan(it) }
override fun sinh(arg: StructureND<Double>): BufferND<Double> = arg.map { sinh(it) }
override fun cosh(arg: StructureND<Double>): BufferND<Double> = arg.map { cosh(it) }
override fun tanh(arg: StructureND<Double>): BufferND<Double> = arg.map { tanh(it) }
override fun asinh(arg: StructureND<Double>): BufferND<Double> = arg.map { asinh(it) }
override fun acosh(arg: StructureND<Double>): BufferND<Double> = arg.map { acosh(it) }
override fun atanh(arg: StructureND<Double>): BufferND<Double> = arg.map { atanh(it) }
public override fun sinh(arg: StructureND<Double>): BufferND<Double> = arg.map { sinh(it) }
public override fun cosh(arg: StructureND<Double>): BufferND<Double> = arg.map { cosh(it) }
public override fun tanh(arg: StructureND<Double>): BufferND<Double> = arg.map { tanh(it) }
public override fun asinh(arg: StructureND<Double>): BufferND<Double> = arg.map { asinh(it) }
public override fun acosh(arg: StructureND<Double>): BufferND<Double> = arg.map { acosh(it) }
public override fun atanh(arg: StructureND<Double>): BufferND<Double> = arg.map { atanh(it) }
}
public fun AlgebraND.Companion.real(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape)

View File

@ -23,10 +23,18 @@ public interface Algebra<T> {
*
* In case if algebra can't parse the string, this method must throw [kotlin.IllegalStateException].
*
* Returns `null` if symbol could not be bound to the context
*
* @param value the raw string.
* @return an object.
*/
public fun bindSymbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
public fun bindSymbolOrNull(value: String): T? = null
/**
* The same as [bindSymbolOrNull] but throws an error if symbol could not be bound
*/
public fun bindSymbol(value: String): T =
bindSymbolOrNull(value) ?: error("Symbol '$value' is not supported in $this")
/**
* Dynamically dispatches an unary operation with the certain name.
@ -91,7 +99,9 @@ public interface Algebra<T> {
binaryOperationFunction(operation)(left, right)
}
public fun <T : Any> Algebra<T>.bindSymbol(symbol: Symbol): T = bindSymbol(symbol.identity)
public fun <T> Algebra<T>.bindSymbolOrNull(symbol: Symbol): T? = bindSymbolOrNull(symbol.identity)
public fun <T> Algebra<T>.bindSymbol(symbol: Symbol): T = bindSymbol(symbol.identity)
/**
* Call a block with an [Algebra] as receiver.
@ -109,8 +119,8 @@ public interface GroupOperations<T> : Algebra<T> {
/**
* Addition of two elements.
*
* @param a the addend.
* @param b the augend.
* @param a the augend.
* @param b the addend.
* @return the sum.
*/
public fun add(a: T, b: T): T
@ -136,8 +146,8 @@ public interface GroupOperations<T> : Algebra<T> {
/**
* Addition of two elements.
*
* @receiver the addend.
* @param b the augend.
* @receiver the augend.
* @param b the addend.
* @return the sum.
*/
public operator fun T.plus(b: T): T = add(this, b)
@ -283,5 +293,5 @@ public interface FieldOperations<T> : RingOperations<T> {
* @param T the type of element of this field.
*/
public interface Field<T> : Ring<T>, FieldOperations<T>, ScaleOperations<T>, NumericAlgebra<T> {
override fun number(value: Number): T = scale(one, value.toDouble())
}
public override fun number(value: Number): T = scale(one, value.toDouble())
}

View File

@ -46,7 +46,8 @@ public operator fun <T : AlgebraElement<T, S>, S : NumbersAddOperations<T>> T.mi
/**
* Adds element to this one.
*
* @param b the augend.
* @receiver the augend.
* @param b the addend.
* @return the sum.
*/
public operator fun <T : AlgebraElement<T, S>, S : Group<T>> T.plus(b: T): T =
@ -58,11 +59,11 @@ public operator fun <T : AlgebraElement<T, S>, S : Group<T>> T.plus(b: T): T =
//public operator fun <T : AlgebraElement<T, S>, S : Space<T>> Number.times(element: T): T =
// element.times(this)
/**
* Multiplies this element by another one.
*
* @param b the multiplicand.
* @receiver the multiplicand.
* @param b the multiplier.
* @return the product.
*/
public operator fun <T : AlgebraElement<T, R>, R : Ring<T>> T.times(b: T): T =

View File

@ -18,7 +18,8 @@ private typealias TBase = ULong
/**
* Kotlin Multiplatform implementation of Big Integer numbers (KBigInteger).
*
* @author Robert Drynkin (https://github.com/robdrynkin) and Peter Klimai (https://github.com/pklimai)
* @author Robert Drynkin
* @author Peter Klimai
*/
@OptIn(UnstableKMathAPI::class)
public object BigIntField : Field<BigInt>, NumbersAddOperations<BigInt>, ScaleOperations<BigInt> {

View File

@ -1,6 +1,8 @@
package space.kscience.kmath.operations
import space.kscience.kmath.misc.UnstableKMathAPI
import kotlin.math.E
import kotlin.math.PI
/**
* An algebraic structure where elements can have numeric representation.
@ -79,8 +81,26 @@ public interface NumericAlgebra<T> : Algebra<T> {
*/
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
rightSideNumberOperationFunction(operation)(left, right)
public override fun bindSymbolOrNull(value: String): T? = when (value) {
"pi" -> number(PI)
"e" -> number(E)
else -> super.bindSymbolOrNull(value)
}
}
/**
* The &pi; mathematical constant.
*/
public val <T> NumericAlgebra<T>.pi: T
get() = bindSymbolOrNull("pi") ?: number(PI)
/**
* The *e* mathematical constant.
*/
public val <T> NumericAlgebra<T>.e: T
get() = number(E)
/**
* Scale by scalar operations
*/
@ -131,16 +151,16 @@ public interface NumbersAddOperations<T> : Group<T>, NumericAlgebra<T> {
/**
* Addition of element and scalar.
*
* @receiver the addend.
* @param b the augend.
* @receiver the augend.
* @param b the addend.
*/
public operator fun T.plus(b: Number): T = this + number(b)
/**
* Addition of scalar and element.
*
* @receiver the addend.
* @param b the augend.
* @receiver the augend.
* @param b the addend.
*/
public operator fun Number.plus(b: T): T = b + this

View File

@ -44,6 +44,12 @@ public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>, Numeri
public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2.0
public override fun bindSymbol(value: String): T = when (value) {
"pi" -> pi
"e" -> e
else -> super<ExtendedFieldOperations>.bindSymbol(value)
}
public override fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T =
when (operation) {
PowerOperations.POW_OPERATION -> ::power
@ -56,10 +62,10 @@ public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>, Numeri
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> {
public override val zero: Double = 0.0
public override val one: Double = 1.0
public override inline val zero: Double get() = 0.0
public override inline val one: Double get() = 1.0
override fun number(value: Number): Double = value.toDouble()
public override fun number(value: Number): Double = value.toDouble()
public override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double =
when (operation) {
@ -68,13 +74,11 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
}
public override inline fun add(a: Double, b: Double): Double = a + b
// public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
// override fun divide(a: Double, k: Number): Double = a / k.toDouble()
public override inline fun multiply(a: Double, b: Double): Double = a * b
public override inline fun divide(a: Double, b: Double): Double = a / b
override fun scale(a: Double, value: Double): Double = a * value
public override fun scale(a: Double, value: Double): Double = a * value
public override inline fun sin(arg: Double): Double = kotlin.math.sin(arg)
public override inline fun cos(arg: Double): Double = kotlin.math.cos(arg)
@ -108,10 +112,10 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
public override val zero: Float = 0.0f
public override val one: Float = 1.0f
public override inline val zero: Float get() = 0.0f
public override inline val one: Float get() = 1.0f
override fun number(value: Number): Float = value.toFloat()
public override fun number(value: Number): Float = value.toFloat()
public override fun binaryOperationFunction(operation: String): (left: Float, right: Float) -> Float =
when (operation) {
@ -120,7 +124,7 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
}
public override inline fun add(a: Float, b: Float): Float = a + b
override fun scale(a: Float, value: Double): Float = a * value.toFloat()
public override fun scale(a: Float, value: Double): Float = a * value.toFloat()
public override inline fun multiply(a: Float, b: Float): Float = a * b
@ -158,13 +162,13 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> {
public override val zero: Int
public override inline val zero: Int
get() = 0
public override val one: Int
public override inline val one: Int
get() = 1
override fun number(value: Number): Int = value.toInt()
public override fun number(value: Number): Int = value.toInt()
public override inline fun add(a: Int, b: Int): Int = a + b
public override inline fun multiply(a: Int, b: Int): Int = a * b
public override inline fun norm(arg: Int): Int = abs(arg)
@ -180,13 +184,13 @@ public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> {
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short> {
public override val zero: Short
public override inline val zero: Short
get() = 0
public override val one: Short
public override inline val one: Short
get() = 1
override fun number(value: Number): Short = value.toShort()
public override fun number(value: Number): Short = value.toShort()
public override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
public override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
public override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
@ -202,13 +206,13 @@ public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short>
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> {
public override val zero: Byte
public override inline val zero: Byte
get() = 0
public override val one: Byte
public override inline val one: Byte
get() = 1
override fun number(value: Number): Byte = value.toByte()
public override fun number(value: Number): Byte = value.toByte()
public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
public override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
public override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
@ -224,13 +228,13 @@ public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> {
*/
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object LongRing : Ring<Long>, Norm<Long, Long>, NumericAlgebra<Long> {
public override val zero: Long
public override inline val zero: Long
get() = 0L
public override val one: Long
public override inline val one: Long
get() = 1L
override fun number(value: Number): Long = value.toLong()
public override fun number(value: Number): Long = value.toLong()
public override inline fun add(a: Long, b: Long): Long = a + b
public override inline fun multiply(a: Long, b: Long): Long = a * b
public override fun norm(arg: Long): Long = abs(arg)

View File

@ -3,6 +3,7 @@ package space.kscience.kmath.expressions
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.asBuffer
import kotlin.math.E

View File

@ -0,0 +1,50 @@
package space.kscience.kmath.chains
import space.kscience.kmath.structures.Buffer
public interface BufferChain<out T> : Chain<T> {
public suspend fun nextBuffer(size: Int): Buffer<T>
override suspend fun fork(): BufferChain<T>
}
/**
* A chain with blocking generator that could be used without suspension
*/
public interface BlockingChain<out T> : Chain<T> {
/**
* Get the next value without concurrency support. Not guaranteed to be thread safe.
*/
public fun nextBlocking(): T
override suspend fun next(): T = nextBlocking()
override suspend fun fork(): BlockingChain<T>
}
public interface BlockingBufferChain<out T> : BlockingChain<T>, BufferChain<T> {
public fun nextBufferBlocking(size: Int): Buffer<T>
public override fun nextBlocking(): T = nextBufferBlocking(1)[0]
public override suspend fun nextBuffer(size: Int): Buffer<T> = nextBufferBlocking(size)
override suspend fun fork(): BlockingBufferChain<T>
}
public suspend inline fun <reified T : Any> Chain<T>.nextBuffer(size: Int): Buffer<T> = if (this is BufferChain) {
nextBuffer(size)
} else {
Buffer.auto(size) { next() }
}
public inline fun <reified T : Any> BlockingChain<T>.nextBufferBlocking(
size: Int,
): Buffer<T> = if (this is BlockingBufferChain) {
nextBufferBlocking(size)
} else {
Buffer.auto(size) { nextBlocking() }
}

View File

@ -1,13 +1,27 @@
package space.kscience.kmath.chains
import space.kscience.kmath.structures.DoubleBuffer
/**
* Chunked, specialized chain for real values.
* Chunked, specialized chain for double values, which supports blocking [nextBlocking] operation
*/
public interface BlockingDoubleChain : Chain<Double> {
public override suspend fun next(): Double
public interface BlockingDoubleChain : BlockingBufferChain<Double> {
/**
* Returns an [DoubleArray] chunk of [size] values of [next].
*/
public suspend fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { next() }
public override fun nextBufferBlocking(size: Int): DoubleBuffer
override suspend fun fork(): BlockingDoubleChain
public companion object
}
public fun BlockingDoubleChain.map(transform: (Double) -> Double): BlockingDoubleChain = object : BlockingDoubleChain {
override fun nextBufferBlocking(size: Int): DoubleBuffer {
val block = this@map.nextBufferBlocking(size)
return DoubleBuffer(size) { transform(block[it]) }
}
override suspend fun fork(): BlockingDoubleChain = this@map.fork().map(transform)
}

View File

@ -1,9 +1,12 @@
package space.kscience.kmath.chains
import space.kscience.kmath.structures.IntBuffer
/**
* Performance optimized chain for integer values
*/
public interface BlockingIntChain : Chain<Int> {
public override suspend fun next(): Int
public suspend fun nextBlock(size: Int): IntArray = IntArray(size) { next() }
}
public interface BlockingIntChain : BlockingBufferChain<Int> {
override fun nextBufferBlocking(size: Int): IntBuffer
override suspend fun fork(): BlockingIntChain
}

View File

@ -24,20 +24,20 @@ import kotlinx.coroutines.sync.withLock
/**
* A not-necessary-Markov chain of some type
* @param R - the chain element type
* @param T - the chain element type
*/
public interface Chain<out R> : Flow<R> {
public interface Chain<out T> : Flow<T> {
/**
* Generate next value, changing state if needed
*/
public suspend fun next(): R
public suspend fun next(): T
/**
* Create a copy of current chain state. Consuming resulting chain does not affect initial chain
*/
public fun fork(): Chain<R>
public suspend fun fork(): Chain<T>
override suspend fun collect(collector: FlowCollector<R>): Unit =
override suspend fun collect(collector: FlowCollector<T>): Unit =
flow { while (true) emit(next()) }.collect(collector)
public companion object
@ -51,7 +51,7 @@ public fun <T> Sequence<T>.asChain(): Chain<T> = iterator().asChain()
*/
public class SimpleChain<out R>(private val gen: suspend () -> R) : Chain<R> {
public override suspend fun next(): R = gen()
public override fun fork(): Chain<R> = this
public override suspend fun fork(): Chain<R> = this
}
/**
@ -69,7 +69,7 @@ public class MarkovChain<out R : Any>(private val seed: suspend () -> R, private
newValue
}
public override fun fork(): Chain<R> = MarkovChain(seed = { value ?: seed() }, gen = gen)
public override suspend fun fork(): Chain<R> = MarkovChain(seed = { value ?: seed() }, gen = gen)
}
/**
@ -94,7 +94,7 @@ public class StatefulChain<S, out R>(
newValue
}
public override fun fork(): Chain<R> = StatefulChain(forkState(state), seed, forkState, gen)
public override suspend fun fork(): Chain<R> = StatefulChain(forkState(state), seed, forkState, gen)
}
/**
@ -102,7 +102,7 @@ public class StatefulChain<S, out R>(
*/
public class ConstantChain<out T>(public val value: T) : Chain<T> {
public override suspend fun next(): T = value
public override fun fork(): Chain<T> = this
public override suspend fun fork(): Chain<T> = this
}
/**
@ -111,7 +111,7 @@ public class ConstantChain<out T>(public val value: T) : Chain<T> {
*/
public fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> = object : Chain<R> {
override suspend fun next(): R = func(this@map.next())
override fun fork(): Chain<R> = this@map.fork().map(func)
override suspend fun fork(): Chain<R> = this@map.fork().map(func)
}
/**
@ -127,7 +127,7 @@ public fun <T> Chain<T>.filter(block: (T) -> Boolean): Chain<T> = object : Chain
return next
}
override fun fork(): Chain<T> = this@filter.fork().filter(block)
override suspend fun fork(): Chain<T> = this@filter.fork().filter(block)
}
/**
@ -135,7 +135,7 @@ public fun <T> Chain<T>.filter(block: (T) -> Boolean): Chain<T> = object : Chain
*/
public fun <T, R> Chain<T>.collect(mapper: suspend (Chain<T>) -> R): Chain<R> = object : Chain<R> {
override suspend fun next(): R = mapper(this@collect)
override fun fork(): Chain<R> = this@collect.fork().collect(mapper)
override suspend fun fork(): Chain<R> = this@collect.fork().collect(mapper)
}
public fun <T, S, R> Chain<T>.collectWithState(
@ -145,7 +145,7 @@ public fun <T, S, R> Chain<T>.collectWithState(
): Chain<R> = object : Chain<R> {
override suspend fun next(): R = state.mapper(this@collectWithState)
override fun fork(): Chain<R> =
override suspend fun fork(): Chain<R> =
this@collectWithState.fork().collectWithState(stateFork(state), stateFork, mapper)
}
@ -154,5 +154,5 @@ public fun <T, S, R> Chain<T>.collectWithState(
*/
public fun <T, U, R> Chain<T>.zip(other: Chain<U>, block: suspend (T, U) -> R): Chain<R> = object : Chain<R> {
override suspend fun next(): R = block(this@zip.next(), other.next())
override fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block)
override suspend fun fork(): Chain<R> = this@zip.fork().zip(other.fork(), block)
}

View File

@ -6,7 +6,6 @@ import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.asBuffer
/**
* Create a [Flow] from buffer
@ -50,7 +49,7 @@ public fun Flow<Double>.chunked(bufferSize: Int): Flow<DoubleBuffer> = flow {
if (this@chunked is BlockingDoubleChain) {
// performance optimization for blocking primitive chain
while (true) emit(nextBlock(bufferSize).asBuffer())
while (true) emit(nextBufferBlocking(bufferSize))
} else {
val array = DoubleArray(bufferSize)
var counter = 0

View File

@ -9,7 +9,7 @@ EJML based linear algebra implementation.
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-3`.
The Maven coordinates of this project are `space.kscience:kmath-ejml:0.3.0-dev-4`.
**Gradle:**
```gradle
@ -20,7 +20,7 @@ repositories {
}
dependencies {
implementation 'space.kscience:kmath-ejml:0.3.0-dev-3'
implementation 'space.kscience:kmath-ejml:0.3.0-dev-4'
}
```
**Gradle Kotlin DSL:**
@ -32,6 +32,6 @@ repositories {
}
dependencies {
implementation("space.kscience:kmath-ejml:0.3.0-dev-3")
implementation("space.kscience:kmath-ejml:0.3.0-dev-4")
}
```

View File

@ -9,7 +9,7 @@ Specialization of KMath APIs for Double numbers.
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-dev-3`.
The Maven coordinates of this project are `space.kscience:kmath-for-real:0.3.0-dev-4`.
**Gradle:**
```gradle
@ -20,7 +20,7 @@ repositories {
}
dependencies {
implementation 'space.kscience:kmath-for-real:0.3.0-dev-3'
implementation 'space.kscience:kmath-for-real:0.3.0-dev-4'
}
```
**Gradle Kotlin DSL:**
@ -32,6 +32,6 @@ repositories {
}
dependencies {
implementation("space.kscience:kmath-for-real:0.3.0-dev-3")
implementation("space.kscience:kmath-for-real:0.3.0-dev-4")
}
```

View File

@ -15,4 +15,13 @@ class GridTest {
assertEquals(6, grid.size)
assertTrue { (grid - DoubleVector(0.0, 0.2, 0.4, 0.6, 0.8, 1.0)).norm < 1e-4 }
}
@Test
fun testIterateGrid(){
var res = 0.0
for(d in 0.0..1.0 step 0.2){
res = d
}
assertEquals(1.0, res)
}
}

View File

@ -10,7 +10,7 @@ Functions and interpolations.
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-dev-3`.
The Maven coordinates of this project are `space.kscience:kmath-functions:0.3.0-dev-4`.
**Gradle:**
```gradle
@ -21,7 +21,7 @@ repositories {
}
dependencies {
implementation 'space.kscience:kmath-functions:0.3.0-dev-3'
implementation 'space.kscience:kmath-functions:0.3.0-dev-4'
}
```
**Gradle Kotlin DSL:**
@ -33,6 +33,6 @@ repositories {
}
dependencies {
implementation("space.kscience:kmath-functions:0.3.0-dev-3")
implementation("space.kscience:kmath-functions:0.3.0-dev-4")
}
```

View File

@ -3,7 +3,7 @@ package space.kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.api.SFun
import space.kscience.kmath.ast.MST
import space.kscience.kmath.ast.MstAlgebra
import space.kscience.kmath.ast.MstExpression
import space.kscience.kmath.ast.interpret
import space.kscience.kmath.expressions.DifferentiableExpression
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.NumericAlgebra
@ -18,38 +18,26 @@ import space.kscience.kmath.operations.NumericAlgebra
* @param A the [NumericAlgebra] of [T].
* @property expr the underlying [MstExpression].
*/
public inline class DifferentiableMstExpression<T: Number, A>(
public val expr: MstExpression<T, A>,
) : DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T> {
public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>(
public val algebra: A,
public val mst: MST,
) : DifferentiableExpression<T, DifferentiableMstExpression<T, A>> {
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
/**
* The [MstExpression.algebra] of [expr].
*/
public val algebra: A
get() = expr.algebra
/**
* The [MstExpression.mst] of [expr].
*/
public val mst: MST
get() = expr.mst
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
public override fun derivativeOrNull(symbols: List<Symbol>): MstExpression<T, A> = MstExpression(
algebra,
symbols.map(Symbol::identity)
.map(MstAlgebra::bindSymbol)
.map { it.toSVar<KMathNumber<T, A>>() }
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
.toMst(),
)
public override fun derivativeOrNull(symbols: List<Symbol>): DifferentiableMstExpression<T, A> =
DifferentiableMstExpression(
algebra,
symbols.map(Symbol::identity)
.map(MstAlgebra::bindSymbol)
.map { it.toSVar<KMathNumber<T, A>>() }
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
.toMst(),
)
}
/**
* Wraps this [MstExpression] into [DifferentiableMstExpression].
* Wraps this [MST] into [DifferentiableMstExpression].
*/
public fun <T : Number, A : NumericAlgebra<T>> MstExpression<T, A>.differentiable(): DifferentiableMstExpression<T, A> =
DifferentiableMstExpression(this)
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): DifferentiableMstExpression<T, A> =
DifferentiableMstExpression(algebra, this)

View File

@ -1,9 +1,8 @@
package space.kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.api.*
import space.kscience.kmath.asm.compile
import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.ast.MstAlgebra
import space.kscience.kmath.ast.MstExpression
import space.kscience.kmath.ast.parseMath
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.operations.DoubleField
@ -43,8 +42,8 @@ internal class AdaptingTests {
fun simpleFunctionDerivative() {
val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>()
val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
val actualDerivative = MstExpression(DoubleField, quadratic.d(x).toMst()).compile()
val expectedDerivative = MstExpression(DoubleField, "2*x-4".parseMath()).compile()
val actualDerivative = quadratic.d(x).toMst().compileToExpression(DoubleField)
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
}
@ -52,12 +51,11 @@ internal class AdaptingTests {
fun moreComplexDerivative() {
val x = MstAlgebra.bindSymbol("x").toSVar<KMathNumber<Double, DoubleField>>()
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, DoubleField>>()
val actualDerivative = MstExpression(DoubleField, composition.d(x).toMst()).compile()
val actualDerivative = composition.d(x).toMst().compileToExpression(DoubleField)
val expectedDerivative =
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath().compileToExpression(DoubleField)
val expectedDerivative = MstExpression(
DoubleField,
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath()
).compile()
assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1))
}

View File

@ -9,7 +9,7 @@ ND4J based implementations of KMath abstractions.
## Artifact:
The Maven coordinates of this project are `space.kscience:kmath-nd4j:0.3.0-dev-3`.
The Maven coordinates of this project are `space.kscience:kmath-nd4j:0.3.0-dev-4`.
**Gradle:**
```gradle
@ -20,7 +20,7 @@ repositories {
}
dependencies {
implementation 'space.kscience:kmath-nd4j:0.3.0-dev-3'
implementation 'space.kscience:kmath-nd4j:0.3.0-dev-4'
}
```
**Gradle Kotlin DSL:**
@ -32,7 +32,7 @@ repositories {
}
dependencies {
implementation("space.kscience:kmath-nd4j:0.3.0-dev-3")
implementation("space.kscience:kmath-nd4j:0.3.0-dev-4")
}
```

View File

@ -2,6 +2,10 @@ plugins {
id("ru.mipt.npm.gradle.mpp")
}
kscience{
useAtomic()
}
kotlin.sourceSets {
commonMain {
dependencies {

View File

@ -0,0 +1,38 @@
package space.kscience.kmath.distributions
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
/**
* A distribution of typed objects.
*/
public interface Distribution<T : Any> : Sampler<T> {
/**
* A probability value for given argument [arg].
* For continuous distributions returns PDF
*/
public fun probability(arg: T): Double
public override fun sample(generator: RandomGenerator): Chain<T>
/**
* An empty companion. Distribution factories should be written as its extensions
*/
public companion object
}
public interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
/**
* Cumulative distribution for ordered parameter (CDF)
*/
public fun cumulative(arg: T): Double
}
/**
* Compute probability integral in an interval
*/
public fun <T : Comparable<T>> UnivariateDistribution<T>.integral(from: T, to: T): Double {
require(to > from)
return cumulative(to) - cumulative(from)
}

View File

@ -1,7 +1,8 @@
package space.kscience.kmath.stat
package space.kscience.kmath.distributions
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.chains.SimpleChain
import space.kscience.kmath.stat.RandomGenerator
/**
* A multivariate distribution which takes a map of parameters

View File

@ -1,12 +1,11 @@
package space.kscience.kmath.stat.distributions
package space.kscience.kmath.distributions
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.internal.InternalErf
import space.kscience.kmath.samplers.GaussianSampler
import space.kscience.kmath.samplers.NormalizedGaussianSampler
import space.kscience.kmath.samplers.ZigguratNormalizedGaussianSampler
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.UnivariateDistribution
import space.kscience.kmath.stat.internal.InternalErf
import space.kscience.kmath.stat.samplers.GaussianSampler
import space.kscience.kmath.stat.samplers.NormalizedGaussianSampler
import space.kscience.kmath.stat.samplers.ZigguratNormalizedGaussianSampler
import kotlin.math.*
/**
@ -16,8 +15,8 @@ public inline class NormalDistribution(public val sampler: GaussianSampler) : Un
public constructor(
mean: Double,
standardDeviation: Double,
normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler.of(),
) : this(GaussianSampler.of(mean, standardDeviation, normalized))
normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler,
) : this(GaussianSampler(mean, standardDeviation, normalized))
public override fun probability(arg: Double): Double {
val x1 = (arg - sampler.mean) / sampler.standardDeviation

View File

@ -1,4 +1,4 @@
package space.kscience.kmath.stat.internal
package space.kscience.kmath.internal
import kotlin.math.abs

View File

@ -1,4 +1,4 @@
package space.kscience.kmath.stat.internal
package space.kscience.kmath.internal
import kotlin.math.*

View File

@ -1,4 +1,4 @@
package space.kscience.kmath.stat.internal
package space.kscience.kmath.internal
import kotlin.math.ln
import kotlin.math.min

View File

@ -0,0 +1,72 @@
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.structures.DoubleBuffer
import kotlin.math.ln
import kotlin.math.pow
/**
* Sampling from an [exponential distribution](http://mathworld.wolfram.com/ExponentialDistribution.html).
*
* Based on Commons RNG implementation.
* See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.html].
*/
public class AhrensDieterExponentialSampler(public val mean: Double) : Sampler<Double> {
init {
require(mean > 0) { "mean is not strictly positive: $mean" }
}
public override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain {
override fun nextBlocking(): Double {
// Step 1:
var a = 0.0
var u = generator.nextDouble()
// Step 2 and 3:
while (u < 0.5) {
a += EXPONENTIAL_SA_QI[0]
u *= 2.0
}
// Step 4 (now u >= 0.5):
u += u - 1
// Step 5:
if (u <= EXPONENTIAL_SA_QI[0]) return mean * (a + u)
// Step 6:
var i = 0 // Should be 1, be we iterate before it in while using 0.
var u2 = generator.nextDouble()
var umin = u2
// Step 7 and 8:
do {
++i
u2 = generator.nextDouble()
if (u2 < umin) umin = u2
// Step 8:
} while (u > EXPONENTIAL_SA_QI[i]) // Ensured to exit since EXPONENTIAL_SA_QI[MAX] = 1.
return mean * (a + umin * EXPONENTIAL_SA_QI[0])
}
override fun nextBufferBlocking(size: Int): DoubleBuffer = DoubleBuffer(size) { nextBlocking() }
override suspend fun fork(): BlockingDoubleChain = sample(generator.fork())
}
public companion object {
private val EXPONENTIAL_SA_QI by lazy {
val ln2 = ln(2.0)
var qi = 0.0
DoubleArray(16) { i ->
qi += ln2.pow(i + 1.0) / space.kscience.kmath.internal.InternalUtils.factorial(i + 1)
qi
}
}
}
}

View File

@ -1,4 +1,4 @@
package space.kscience.kmath.stat.samplers
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.stat.RandomGenerator
@ -80,7 +80,7 @@ public class AhrensDieterMarsagliaTsangGammaSampler private constructor(
private val gaussian: NormalizedGaussianSampler
init {
gaussian = ZigguratNormalizedGaussianSampler.of()
gaussian = ZigguratNormalizedGaussianSampler
dOptim = alpha - ONE_THIRD
cOptim = ONE_THIRD / sqrt(dOptim)
}

View File

@ -1,10 +1,10 @@
package space.kscience.kmath.stat.samplers
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.internal.InternalUtils
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.stat.chain
import space.kscience.kmath.stat.internal.InternalUtils
import kotlin.math.ceil
import kotlin.math.max
import kotlin.math.min
@ -39,12 +39,12 @@ import kotlin.math.min
public open class AliasMethodDiscreteSampler private constructor(
// Deliberate direct storage of input arrays
protected val probability: LongArray,
protected val alias: IntArray
protected val alias: IntArray,
) : Sampler<Int> {
private class SmallTableAliasMethodDiscreteSampler(
probability: LongArray,
alias: IntArray
alias: IntArray,
) : AliasMethodDiscreteSampler(probability, alias) {
// Assume the table size is a power of 2 and create the mask
private val mask: Int = alias.size - 1
@ -111,110 +111,6 @@ public open class AliasMethodDiscreteSampler private constructor(
private const val CONVERT_TO_NUMERATOR: Double = ONE_AS_NUMERATOR.toDouble()
private const val MAX_SMALL_POWER_2_SIZE = 1 shl 11
public fun of(
probabilities: DoubleArray,
alpha: Int = DEFAULT_ALPHA
): Sampler<Int> {
// The Alias method balances N categories with counts around the mean into N sections,
// each allocated 'mean' observations.
//
// Consider 4 categories with counts 6,3,2,1. The histogram can be balanced into a
// 2D array as 4 sections with a height of the mean:
//
// 6
// 6
// 6
// 63 => 6366 --
// 632 6326 |-- mean
// 6321 6321 --
//
// section abcd
//
// Each section is divided as:
// a: 6=1/1
// b: 3=1/1
// c: 2=2/3; 6=1/3 (6 is the alias)
// d: 1=1/3; 6=2/3 (6 is the alias)
//
// The sample is obtained by randomly selecting a section, then choosing which category
// from the pair based on a uniform random deviate.
val sumProb = InternalUtils.validateProbabilities(probabilities)
// Allow zero-padding
val n = computeSize(probabilities.size, alpha)
// Partition into small and large by splitting on the average.
val mean = sumProb / n
// The cardinality of smallSize + largeSize = n.
// So fill the same array from either end.
val indices = IntArray(n)
var large = n
var small = 0
probabilities.indices.forEach { i ->
if (probabilities[i] >= mean) indices[--large] = i else indices[small++] = i
}
small = fillRemainingIndices(probabilities.size, indices, small)
// This may be smaller than the input length if the probabilities were already padded.
val nonZeroIndex = findLastNonZeroIndex(probabilities)
// The probabilities are modified so use a copy.
// Note: probabilities are required only up to last nonZeroIndex
val remainingProbabilities = probabilities.copyOf(nonZeroIndex + 1)
// Allocate the final tables.
// Probability table may be truncated (when zero padded).
// The alias table is full length.
val probability = LongArray(remainingProbabilities.size)
val alias = IntArray(n)
// This loop uses each large in turn to fill the alias table for small probabilities that
// do not reach the requirement to fill an entire section alone (i.e. p < mean).
// Since the sum of the small should be less than the sum of the large it should use up
// all the small first. However floating point round-off can result in
// misclassification of items as small or large. The Vose algorithm handles this using
// a while loop conditioned on the size of both sets and a subsequent loop to use
// unpaired items.
while (large != n && small != 0) {
// Index of the small and the large probabilities.
val j = indices[--small]
val k = indices[large++]
// Optimisation for zero-padded input:
// p(j) = 0 above the last nonZeroIndex
if (j > nonZeroIndex)
// The entire amount for the section is taken from the alias.
remainingProbabilities[k] -= mean
else {
val pj = remainingProbabilities[j]
// Item j is a small probability that is below the mean.
// Compute the weight of the section for item j: pj / mean.
// This is scaled by 2^53 and the ceiling function used to round-up
// the probability to a numerator of a fraction in the range [1,2^53].
// Ceiling ensures non-zero values.
probability[j] = ceil(CONVERT_TO_NUMERATOR * (pj / mean)).toLong()
// The remaining amount for the section is taken from the alias.
// Effectively: probabilities[k] -= (mean - pj)
remainingProbabilities[k] += pj - mean
}
// If not j then the alias is k
alias[j] = k
// Add the remaining probability from large to the appropriate list.
if (remainingProbabilities[k] >= mean) indices[--large] = k else indices[small++] = k
}
// Final loop conditions to consume unpaired items.
// Note: The large set should never be non-empty but this can occur due to round-off
// error so consume from both.
fillTable(probability, alias, indices, 0, small)
fillTable(probability, alias, indices, large, n)
// Change the algorithm for small power of 2 sized tables
return if (isSmallPowerOf2(n))
SmallTableAliasMethodDiscreteSampler(probability, alias)
else
AliasMethodDiscreteSampler(probability, alias)
}
private fun fillRemainingIndices(length: Int, indices: IntArray, small: Int): Int {
var updatedSmall = small
(length until indices.size).forEach { i -> indices[updatedSmall++] = i }
@ -246,7 +142,7 @@ public open class AliasMethodDiscreteSampler private constructor(
alias: IntArray,
indices: IntArray,
start: Int,
end: Int
end: Int,
) = (start until end).forEach { i ->
val index = indices[i]
probability[index] = ONE_AS_NUMERATOR
@ -283,4 +179,110 @@ public open class AliasMethodDiscreteSampler private constructor(
return n - (mutI ushr 1)
}
}
@Suppress("FunctionName")
public fun AliasMethodDiscreteSampler(
probabilities: DoubleArray,
alpha: Int = DEFAULT_ALPHA,
): Sampler<Int> {
// The Alias method balances N categories with counts around the mean into N sections,
// each allocated 'mean' observations.
//
// Consider 4 categories with counts 6,3,2,1. The histogram can be balanced into a
// 2D array as 4 sections with a height of the mean:
//
// 6
// 6
// 6
// 63 => 6366 --
// 632 6326 |-- mean
// 6321 6321 --
//
// section abcd
//
// Each section is divided as:
// a: 6=1/1
// b: 3=1/1
// c: 2=2/3; 6=1/3 (6 is the alias)
// d: 1=1/3; 6=2/3 (6 is the alias)
//
// The sample is obtained by randomly selecting a section, then choosing which category
// from the pair based on a uniform random deviate.
val sumProb = InternalUtils.validateProbabilities(probabilities)
// Allow zero-padding
val n = computeSize(probabilities.size, alpha)
// Partition into small and large by splitting on the average.
val mean = sumProb / n
// The cardinality of smallSize + largeSize = n.
// So fill the same array from either end.
val indices = IntArray(n)
var large = n
var small = 0
probabilities.indices.forEach { i ->
if (probabilities[i] >= mean) indices[--large] = i else indices[small++] = i
}
small = fillRemainingIndices(probabilities.size, indices, small)
// This may be smaller than the input length if the probabilities were already padded.
val nonZeroIndex = findLastNonZeroIndex(probabilities)
// The probabilities are modified so use a copy.
// Note: probabilities are required only up to last nonZeroIndex
val remainingProbabilities = probabilities.copyOf(nonZeroIndex + 1)
// Allocate the final tables.
// Probability table may be truncated (when zero padded).
// The alias table is full length.
val probability = LongArray(remainingProbabilities.size)
val alias = IntArray(n)
// This loop uses each large in turn to fill the alias table for small probabilities that
// do not reach the requirement to fill an entire section alone (i.e. p < mean).
// Since the sum of the small should be less than the sum of the large it should use up
// all the small first. However floating point round-off can result in
// misclassification of items as small or large. The Vose algorithm handles this using
// a while loop conditioned on the size of both sets and a subsequent loop to use
// unpaired items.
while (large != n && small != 0) {
// Index of the small and the large probabilities.
val j = indices[--small]
val k = indices[large++]
// Optimisation for zero-padded input:
// p(j) = 0 above the last nonZeroIndex
if (j > nonZeroIndex)
// The entire amount for the section is taken from the alias.
remainingProbabilities[k] -= mean
else {
val pj = remainingProbabilities[j]
// Item j is a small probability that is below the mean.
// Compute the weight of the section for item j: pj / mean.
// This is scaled by 2^53 and the ceiling function used to round-up
// the probability to a numerator of a fraction in the range [1,2^53].
// Ceiling ensures non-zero values.
probability[j] = ceil(CONVERT_TO_NUMERATOR * (pj / mean)).toLong()
// The remaining amount for the section is taken from the alias.
// Effectively: probabilities[k] -= (mean - pj)
remainingProbabilities[k] += pj - mean
}
// If not j then the alias is k
alias[j] = k
// Add the remaining probability from large to the appropriate list.
if (remainingProbabilities[k] >= mean) indices[--large] = k else indices[small++] = k
}
// Final loop conditions to consume unpaired items.
// Note: The large set should never be non-empty but this can occur due to round-off
// error so consume from both.
fillTable(probability, alias, indices, 0, small)
fillTable(probability, alias, indices, large, n)
// Change the algorithm for small power of 2 sized tables
return if (isSmallPowerOf2(n)) {
SmallTableAliasMethodDiscreteSampler(probability, alias)
} else {
AliasMethodDiscreteSampler(probability, alias)
}
}
}

View File

@ -0,0 +1,52 @@
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.structures.DoubleBuffer
import kotlin.math.*
/**
* [Box-Muller algorithm](https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform) for sampling from a Gaussian
* distribution.
*
* Based on Commons RNG implementation.
* See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.html].
*/
public object BoxMullerSampler : NormalizedGaussianSampler {
override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain {
var state = Double.NaN
override fun nextBufferBlocking(size: Int): DoubleBuffer {
val xs = generator.nextDoubleBuffer(size)
val ys = generator.nextDoubleBuffer(size)
return DoubleBuffer(size) { index ->
if (state.isNaN()) {
// Generate a pair of Gaussian numbers.
val x = xs[index]
val y = ys[index]
val alpha = 2 * PI * x
val r = sqrt(-2 * ln(y))
// Keep second element of the pair for next invocation.
state = r * sin(alpha)
// Return the first element of the generated pair.
r * cos(alpha)
} else {
// Use the second element of the pair (generated at the
// previous invocation).
state.also {
// Both elements of the pair have been used.
state = Double.NaN
}
}
}
}
override suspend fun fork(): BlockingDoubleChain = sample(generator.fork())
}
}

View File

@ -0,0 +1,13 @@
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingBufferChain
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.structures.Buffer
public class ConstantSampler<T : Any>(public val const: T) : Sampler<T> {
override fun sample(generator: RandomGenerator): BlockingBufferChain<T> = object : BlockingBufferChain<T> {
override fun nextBufferBlocking(size: Int): Buffer<T> = Buffer.boxing(size) { const }
override suspend fun fork(): BlockingBufferChain<T> = this
}
}

View File

@ -0,0 +1,34 @@
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.chains.map
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
/**
* Sampling from a Gaussian distribution with given mean and standard deviation.
*
* Based on Commons RNG implementation.
* See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/GaussianSampler.html].
*
* @property mean the mean of the distribution.
* @property standardDeviation the variance of the distribution.
*/
public class GaussianSampler(
public val mean: Double,
public val standardDeviation: Double,
private val normalized: NormalizedGaussianSampler = BoxMullerSampler
) : Sampler<Double> {
init {
require(standardDeviation > 0.0) { "standard deviation is not strictly positive: $standardDeviation" }
}
public override fun sample(generator: RandomGenerator): BlockingDoubleChain = normalized
.sample(generator)
.map { standardDeviation * it + mean }
override fun toString(): String = "N($mean, $standardDeviation)"
public companion object
}

View File

@ -0,0 +1,68 @@
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingIntChain
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.structures.IntBuffer
import kotlin.math.exp
/**
* Sampler for the Poisson distribution.
* - Kemp, A, W, (1981) Efficient Generation of Logarithmically Distributed Pseudo-Random Variables. Journal of the Royal Statistical Society. Vol. 30, No. 3, pp. 249-253.
* This sampler is suitable for mean < 40. For large means, LargeMeanPoissonSampler should be used instead.
*
* Note: The algorithm uses a recurrence relation to compute the Poisson probability and a rolling summation for the cumulative probability. When the mean is large the initial probability (Math.exp(-mean)) is zero and an exception is raised by the constructor.
*
* Sampling uses 1 call to UniformRandomProvider.nextDouble(). This method provides an alternative to the SmallMeanPoissonSampler for slow generators of double.
*
* Based on Commons RNG implementation.
* See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/KempSmallMeanPoissonSampler.html].
*/
public class KempSmallMeanPoissonSampler internal constructor(
private val p0: Double,
private val mean: Double,
) : Sampler<Int> {
public override fun sample(generator: RandomGenerator): BlockingIntChain = object : BlockingIntChain {
override fun nextBlocking(): Int {
//TODO move to nextBufferBlocking
// Note on the algorithm:
// - X is the unknown sample deviate (the output of the algorithm)
// - x is the current value from the distribution
// - p is the probability of the current value x, p(X=x)
// - u is effectively the cumulative probability that the sample X
// is equal or above the current value x, p(X>=x)
// So if p(X>=x) > p(X=x) the sample must be above x, otherwise it is x
var u = generator.nextDouble()
var x = 0
var p = p0
while (u > p) {
u -= p
// Compute the next probability using a recurrence relation.
// p(x+1) = p(x) * mean / (x+1)
p *= mean / ++x
// The algorithm listed in Kemp (1981) does not check that the rolling probability
// is positive. This check is added to ensure no errors when the limit of the summation
// 1 - sum(p(x)) is above 0 due to cumulative error in floating point arithmetic.
if (p == 0.0) return x
}
return x
}
override fun nextBufferBlocking(size: Int): IntBuffer = IntBuffer(size) { nextBlocking() }
override suspend fun fork(): BlockingIntChain = sample(generator.fork())
}
public override fun toString(): String = "Kemp Small Mean Poisson deviate"
}
public fun KempSmallMeanPoissonSampler(mean: Double): KempSmallMeanPoissonSampler {
require(mean > 0) { "Mean is not strictly positive: $mean" }
val p0 = exp(-mean)
// Probability must be positive. As mean increases then p(0) decreases.
require(p0 > 0) { "No probability for mean: $mean" }
return KempSmallMeanPoissonSampler(p0, mean)
}

View File

@ -0,0 +1,61 @@
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.structures.DoubleBuffer
import kotlin.math.ln
import kotlin.math.sqrt
/**
* [Marsaglia polar method](https://en.wikipedia.org/wiki/Marsaglia_polar_method) for sampling from a Gaussian
* distribution with mean 0 and standard deviation 1. This is a variation of the algorithm implemented in
* [BoxMullerNormalizedGaussianSampler].
*
* Based on Commons RNG implementation.
* See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/MarsagliaNormalizedGaussianSampler.html]
*/
public object MarsagliaNormalizedGaussianSampler : NormalizedGaussianSampler {
override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain {
var nextGaussian = Double.NaN
override fun nextBlocking(): Double {
return if (nextGaussian.isNaN()) {
val alpha: Double
var x: Double
// Rejection scheme for selecting a pair that lies within the unit circle.
while (true) {
// Generate a pair of numbers within [-1 , 1).
x = 2.0 * generator.nextDouble() - 1.0
val y = 2.0 * generator.nextDouble() - 1.0
val r2 = x * x + y * y
if (r2 < 1 && r2 > 0) {
// Pair (x, y) is within unit circle.
alpha = sqrt(-2 * ln(r2) / r2)
// Keep second element of the pair for next invocation.
nextGaussian = alpha * y
// Return the first element of the generated pair.
break
}
// Pair is not within the unit circle: Generate another one.
}
// Return the first element of the generated pair.
alpha * x
} else {
// Use the second element of the pair (generated at the
// previous invocation).
val r = nextGaussian
// Both elements of the pair have been used.
nextGaussian = Double.NaN
r
}
}
override fun nextBufferBlocking(size: Int): DoubleBuffer = DoubleBuffer(size) { nextBlocking() }
override suspend fun fork(): BlockingDoubleChain = sample(generator.fork())
}
}

View File

@ -0,0 +1,18 @@
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
public interface BlockingDoubleSampler: Sampler<Double>{
override fun sample(generator: RandomGenerator): BlockingDoubleChain
}
/**
* Marker interface for a sampler that generates values from an N(0,1)
* [Gaussian distribution](https://en.wikipedia.org/wiki/Normal_distribution).
*/
public fun interface NormalizedGaussianSampler : BlockingDoubleSampler{
public companion object
}

View File

@ -0,0 +1,203 @@
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingIntChain
import space.kscience.kmath.internal.InternalUtils
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.structures.IntBuffer
import kotlin.math.*
private const val PIVOT = 40.0
/**
* Sampler for the Poisson distribution.
* - For small means, a Poisson process is simulated using uniform deviates, as described in
* Knuth (1969). Seminumerical Algorithms. The Art of Computer Programming, Volume 2. Chapter 3.4.1.F.3
* Important integer-valued distributions: The Poisson distribution. Addison Wesley.
* The Poisson process (and hence, the returned value) is bounded by 1000 * mean.
* - For large means, we use the rejection algorithm described in
* Devroye, Luc. (1981). The Computer Generation of Poisson Random Variables Computing vol. 26 pp. 197-207.
*
* Based on Commons RNG implementation.
* See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/PoissonSampler.html].
*/
@Suppress("FunctionName")
public fun PoissonSampler(mean: Double): Sampler<Int> {
return if (mean < PIVOT) SmallMeanPoissonSampler(mean) else LargeMeanPoissonSampler(mean)
}
/**
* Sampler for the Poisson distribution.
* - For small means, a Poisson process is simulated using uniform deviates, as described in
* Knuth (1969). Seminumerical Algorithms. The Art of Computer Programming, Volume 2. Chapter 3.4.1.F.3 Important
* integer-valued distributions: The Poisson distribution. Addison Wesley.
* - The Poisson process (and hence, the returned value) is bounded by 1000 * mean.
* This sampler is suitable for mean < 40. For large means, [LargeMeanPoissonSampler] should be used instead.
*
* Based on Commons RNG implementation.
*
* See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/SmallMeanPoissonSampler.html].
*/
public class SmallMeanPoissonSampler(public val mean: Double) : Sampler<Int> {
init {
require(mean > 0) { "mean is not strictly positive: $mean" }
}
private val p0: Double = exp(-mean)
private val limit: Int = if (p0 > 0) {
ceil(1000 * mean)
} else {
throw IllegalArgumentException("No p(x=0) probability for mean: $mean")
}.toInt()
public override fun sample(generator: RandomGenerator): BlockingIntChain = object : BlockingIntChain {
override fun nextBlocking(): Int {
var n = 0
var r = 1.0
while (n < limit) {
r *= generator.nextDouble()
if (r >= p0) n++ else break
}
return n
}
override fun nextBufferBlocking(size: Int): IntBuffer = IntBuffer(size) { nextBlocking() }
override suspend fun fork(): BlockingIntChain = sample(generator.fork())
}
public override fun toString(): String = "Small Mean Poisson deviate"
}
/**
* Sampler for the Poisson distribution.
* - For large means, we use the rejection algorithm described in
* Devroye, Luc. (1981).The Computer Generation of Poisson Random Variables
* Computing vol. 26 pp. 197-207.
*
* This sampler is suitable for mean >= 40.
*
* Based on Commons RNG implementation.
* See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.html].
*/
public class LargeMeanPoissonSampler(public val mean: Double) : Sampler<Int> {
init {
require(mean >= 1) { "mean is not >= 1: $mean" }
// The algorithm is not valid if Math.floor(mean) is not an integer.
require(mean <= MAX_MEAN) { "mean $mean > $MAX_MEAN" }
}
private val factorialLog: InternalUtils.FactorialLog = NO_CACHE_FACTORIAL_LOG
private val lambda: Double = floor(mean)
private val logLambda: Double = ln(lambda)
private val logLambdaFactorial: Double = getFactorialLog(lambda.toInt())
private val delta: Double = sqrt(lambda * ln(32 * lambda / PI + 1))
private val halfDelta: Double = delta / 2
private val twolpd: Double = 2 * lambda + delta
private val c1: Double = 1 / (8 * lambda)
private val a1: Double = sqrt(PI * twolpd) * exp(c1)
private val a2: Double = twolpd / delta * exp(-delta * (1 + delta) / twolpd)
private val aSum: Double = a1 + a2 + 1
private val p1: Double = a1 / aSum
private val p2: Double = a2 / aSum
public override fun sample(generator: RandomGenerator): BlockingIntChain = object : BlockingIntChain {
override fun nextBlocking(): Int {
val exponential = AhrensDieterExponentialSampler(1.0).sample(generator)
val gaussian = ZigguratNormalizedGaussianSampler.sample(generator)
val smallMeanPoissonSampler = if (mean - lambda < Double.MIN_VALUE) {
null
} else {
KempSmallMeanPoissonSampler(mean - lambda).sample(generator)
}
val y2 = smallMeanPoissonSampler?.nextBlocking() ?: 0
var x: Double
var y: Double
var v: Double
var a: Int
var t: Double
var qr: Double
var qa: Double
while (true) {
// Step 1:
val u = generator.nextDouble()
if (u <= p1) {
// Step 2:
val n = gaussian.nextBlocking()
x = n * sqrt(lambda + halfDelta) - 0.5
if (x > delta || x < -lambda) continue
y = if (x < 0) floor(x) else ceil(x)
val e = exponential.nextBlocking()
v = -e - 0.5 * n * n + c1
} else {
// Step 3:
if (u > p1 + p2) {
y = lambda
break
}
x = delta + twolpd / delta * exponential.nextBlocking()
y = ceil(x)
v = -exponential.nextBlocking() - delta * (x + 1) / twolpd
}
// The Squeeze Principle
// Step 4.1:
a = if (x < 0) 1 else 0
t = y * (y + 1) / (2 * lambda)
// Step 4.2
if (v < -t && a == 0) {
y += lambda
break
}
// Step 4.3:
qr = t * ((2 * y + 1) / (6 * lambda) - 1)
qa = qr - t * t / (3 * (lambda + a * (y + 1)))
// Step 4.4:
if (v < qa) {
y += lambda
break
}
// Step 4.5:
if (v > qr) continue
// Step 4.6:
if (v < y * logLambda - getFactorialLog((y + lambda).toInt()) + logLambdaFactorial) {
y += lambda
break
}
}
return min(y2 + y.toLong(), Int.MAX_VALUE.toLong()).toInt()
}
override fun nextBufferBlocking(size: Int): IntBuffer = IntBuffer(size) { nextBlocking() }
override suspend fun fork(): BlockingIntChain = sample(generator.fork())
}
private fun getFactorialLog(n: Int): Double = factorialLog.value(n)
public override fun toString(): String = "Large Mean Poisson deviate"
public companion object {
private const val MAX_MEAN: Double = 0.5 * Int.MAX_VALUE
private val NO_CACHE_FACTORIAL_LOG: InternalUtils.FactorialLog = InternalUtils.FactorialLog.create()
}
}

View File

@ -0,0 +1,88 @@
package space.kscience.kmath.samplers
import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.structures.DoubleBuffer
import kotlin.math.*
/**
* [Marsaglia and Tsang "Ziggurat"](https://en.wikipedia.org/wiki/Ziggurat_algorithm) method for sampling from a
* Gaussian distribution with mean 0 and standard deviation 1. The algorithm is explained in this paper and this
* implementation has been adapted from the C code provided therein.
*
* Based on Commons RNG implementation.
* See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/ZigguratNormalizedGaussianSampler.html].
*/
public object ZigguratNormalizedGaussianSampler : NormalizedGaussianSampler {
private const val R: Double = 3.442619855899
private const val ONE_OVER_R: Double = 1 / R
private const val V: Double = 9.91256303526217e-3
private val MAX: Double = 2.0.pow(63.0)
private val ONE_OVER_MAX: Double = 1.0 / MAX
private const val LEN: Int = 128
private const val LAST: Int = LEN - 1
private val K: LongArray = LongArray(LEN)
private val W: DoubleArray = DoubleArray(LEN)
private val F: DoubleArray = DoubleArray(LEN)
init {
// Filling the tables.
var d = R
var t = d
var fd = gauss(d)
val q = V / fd
K[0] = (d / q * MAX).toLong()
K[1] = 0
W[0] = q * ONE_OVER_MAX
W[LAST] = d * ONE_OVER_MAX
F[0] = 1.0
F[LAST] = fd
(LAST - 1 downTo 1).forEach { i ->
d = sqrt(-2 * ln(V / d + fd))
fd = gauss(d)
K[i + 1] = (d / t * MAX).toLong()
t = d
F[i] = fd
W[i] = d * ONE_OVER_MAX
}
}
private fun gauss(x: Double): Double = exp(-0.5 * x * x)
private fun sampleOne(generator: RandomGenerator): Double {
val j = generator.nextLong()
val i = (j and LAST.toLong()).toInt()
return if (abs(j) < K[i]) j * W[i] else fix(generator, j, i)
}
override fun sample(generator: RandomGenerator): BlockingDoubleChain = object : BlockingDoubleChain {
override fun nextBufferBlocking(size: Int): DoubleBuffer = DoubleBuffer(size) { sampleOne(generator) }
override suspend fun fork(): BlockingDoubleChain = sample(generator.fork())
}
private fun fix(generator: RandomGenerator, hz: Long, iz: Int): Double {
var x = hz * W[iz]
return when {
iz == 0 -> {
var y: Double
do {
y = -ln(generator.nextDouble())
x = -ln(generator.nextDouble()) * ONE_OVER_R
} while (y + y < x * x)
val out = R + x
if (hz > 0) out else -out
}
F[iz] + generator.nextDouble() * (F[iz - 1] - F[iz]) < gauss(x) -> x
else -> sampleOne(generator)
}
}
}

View File

@ -1,8 +1,8 @@
package space.kscience.kmath.stat
import space.kscience.kmath.chains.BlockingDoubleChain
import space.kscience.kmath.chains.BlockingIntChain
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.structures.DoubleBuffer
/**
* A possibly stateful chain producing random values.
@ -11,12 +11,24 @@ import space.kscience.kmath.chains.Chain
*/
public class RandomChain<out R>(
public val generator: RandomGenerator,
private val gen: suspend RandomGenerator.() -> R
private val gen: suspend RandomGenerator.() -> R,
) : Chain<R> {
override suspend fun next(): R = generator.gen()
override fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
override suspend fun fork(): Chain<R> = RandomChain(generator.fork(), gen)
}
/**
* Create a generic random chain with provided [generator]
*/
public fun <R> RandomGenerator.chain(generator: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, generator)
/**
* A type-specific double chunk random chain
*/
public class UniformDoubleChain(public val generator: RandomGenerator) : BlockingDoubleChain {
public override fun nextBufferBlocking(size: Int): DoubleBuffer = generator.nextDoubleBuffer(size)
override suspend fun nextBuffer(size: Int): DoubleBuffer = nextBufferBlocking(size)
override suspend fun fork(): UniformDoubleChain = UniformDoubleChain(generator.fork())
}
public fun <R> RandomGenerator.chain(gen: suspend RandomGenerator.() -> R): RandomChain<R> = RandomChain(this, gen)
public fun Chain<Double>.blocking(): BlockingDoubleChain = object : Chain<Double> by this, BlockingDoubleChain {}
public fun Chain<Int>.blocking(): BlockingIntChain = object : Chain<Int> by this, BlockingIntChain {}

View File

@ -1,5 +1,6 @@
package space.kscience.kmath.stat
import space.kscience.kmath.structures.DoubleBuffer
import kotlin.random.Random
/**
@ -16,6 +17,11 @@ public interface RandomGenerator {
*/
public fun nextDouble(): Double
/**
* A chunk of doubles of given [size]
*/
public fun nextDoubleBuffer(size: Int): DoubleBuffer = DoubleBuffer(size) { nextDouble() }
/**
* Gets the next random `Int` from the random number generator.
*

View File

@ -3,16 +3,13 @@ package space.kscience.kmath.stat
import kotlinx.coroutines.flow.first
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.chains.collect
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.IntBuffer
import space.kscience.kmath.structures.MutableBuffer
import space.kscience.kmath.structures.*
import kotlin.jvm.JvmName
/**
* Sampler that generates chains of values of type [T].
* Sampler that generates chains of values of type [T] in a chain of type [C].
*/
public fun interface Sampler<T : Any> {
public fun interface Sampler<out T : Any> {
/**
* Generates a chain of samples.
*
@ -22,39 +19,6 @@ public fun interface Sampler<T : Any> {
public fun sample(generator: RandomGenerator): Chain<T>
}
/**
* A distribution of typed objects.
*/
public interface Distribution<T : Any> : Sampler<T> {
/**
* A probability value for given argument [arg].
* For continuous distributions returns PDF
*/
public fun probability(arg: T): Double
public override fun sample(generator: RandomGenerator): Chain<T>
/**
* An empty companion. Distribution factories should be written as its extensions
*/
public companion object
}
public interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
/**
* Cumulative distribution for ordered parameter (CDF)
*/
public fun cumulative(arg: T): Double
}
/**
* Compute probability integral in an interval
*/
public fun <T : Comparable<T>> UnivariateDistribution<T>.integral(from: T, to: T): Double {
require(to > from)
return cumulative(to) - cumulative(from)
}
/**
* Sample a bunch of values
*/
@ -71,7 +35,7 @@ public fun <T : Any> Sampler<T>.sampleBuffer(
//clear list from previous run
tmp.clear()
//Fill list
repeat(size) { tmp += chain.next() }
repeat(size) { tmp.add(chain.next()) }
//return new buffer with elements from tmp
bufferFactory(size) { tmp[it] }
}
@ -87,7 +51,7 @@ public suspend fun <T : Any> Sampler<T>.next(generator: RandomGenerator): T = sa
*/
@JvmName("sampleRealBuffer")
public fun Sampler<Double>.sampleBuffer(generator: RandomGenerator, size: Int): Chain<Buffer<Double>> =
sampleBuffer(generator, size, MutableBuffer.Companion::double)
sampleBuffer(generator, size, ::DoubleBuffer)
/**
* Generates [size] integer samples and chunks them into some buffers.

View File

@ -81,7 +81,7 @@ public class Mean<T>(
public companion object {
//TODO replace with optimized version which respects overflow
public val real: Mean<Double> = Mean(DoubleField) { sum, count -> sum / count }
public val double: Mean<Double> = Mean(DoubleField) { sum, count -> sum / count }
public val int: Mean<Int> = Mean(IntRing) { sum, count -> sum / count }
public val long: Mean<Long> = Mean(LongRing) { sum, count -> sum / count }
}

View File

@ -2,6 +2,8 @@ package space.kscience.kmath.stat
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.chains.SimpleChain
import space.kscience.kmath.distributions.Distribution
import space.kscience.kmath.distributions.UnivariateDistribution
public class UniformDistribution(public val range: ClosedFloatingPointRange<Double>) : UnivariateDistribution<Double> {
private val length: Double = range.endInclusive - range.start

View File

@ -1,73 +0,0 @@
package space.kscience.kmath.stat.samplers
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.stat.chain
import space.kscience.kmath.stat.internal.InternalUtils
import kotlin.math.ln
import kotlin.math.pow
/**
* Sampling from an [exponential distribution](http://mathworld.wolfram.com/ExponentialDistribution.html).
*
* Based on Commons RNG implementation.
* See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/AhrensDieterExponentialSampler.html].
*/
public class AhrensDieterExponentialSampler private constructor(public val mean: Double) : Sampler<Double> {
public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
// Step 1:
var a = 0.0
var u = nextDouble()
// Step 2 and 3:
while (u < 0.5) {
a += EXPONENTIAL_SA_QI[0]
u *= 2.0
}
// Step 4 (now u >= 0.5):
u += u - 1
// Step 5:
if (u <= EXPONENTIAL_SA_QI[0]) return@chain mean * (a + u)
// Step 6:
var i = 0 // Should be 1, be we iterate before it in while using 0.
var u2 = nextDouble()
var umin = u2
// Step 7 and 8:
do {
++i
u2 = nextDouble()
if (u2 < umin) umin = u2
// Step 8:
} while (u > EXPONENTIAL_SA_QI[i]) // Ensured to exit since EXPONENTIAL_SA_QI[MAX] = 1.
mean * (a + umin * EXPONENTIAL_SA_QI[0])
}
override fun toString(): String = "Ahrens-Dieter Exponential deviate"
public companion object {
private val EXPONENTIAL_SA_QI by lazy { DoubleArray(16) }
init {
/**
* Filling EXPONENTIAL_SA_QI table.
* Note that we don't want qi = 0 in the table.
*/
val ln2 = ln(2.0)
var qi = 0.0
EXPONENTIAL_SA_QI.indices.forEach { i ->
qi += ln2.pow(i + 1.0) / InternalUtils.factorial(i + 1)
EXPONENTIAL_SA_QI[i] = qi
}
}
public fun of(mean: Double): AhrensDieterExponentialSampler {
require(mean > 0) { "mean is not strictly positive: $mean" }
return AhrensDieterExponentialSampler(mean)
}
}
}

View File

@ -1,48 +0,0 @@
package space.kscience.kmath.stat.samplers
import space.kscience.kmath.chains.Chain
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.Sampler
import space.kscience.kmath.stat.chain
import kotlin.math.*
/**
* [Box-Muller algorithm](https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform) for sampling from a Gaussian
* distribution.
*
* Based on Commons RNG implementation.
* See [https://commons.apache.org/proper/commons-rng/commons-rng-sampling/apidocs/org/apache/commons/rng/sampling/distribution/BoxMullerNormalizedGaussianSampler.html].
*/
public class BoxMullerNormalizedGaussianSampler private constructor() : NormalizedGaussianSampler, Sampler<Double> {
private var nextGaussian: Double = Double.NaN
public override fun sample(generator: RandomGenerator): Chain<Double> = generator.chain {
val random: Double
if (nextGaussian.isNaN()) {
// Generate a pair of Gaussian numbers.
val x = nextDouble()
val y = nextDouble()
val alpha = 2 * PI * x
val r = sqrt(-2 * ln(y))
// Return the first element of the generated pair.
random = r * cos(alpha)
// Keep second element of the pair for next invocation.
nextGaussian = r * sin(alpha)
} else {
// Use the second element of the pair (generated at the
// previous invocation).
random = nextGaussian
// Both elements of the pair have been used.
nextGaussian = Double.NaN
}
random
}
public override fun toString(): String = "Box-Muller normalized Gaussian deviate"
public companion object {
public fun of(): BoxMullerNormalizedGaussianSampler = BoxMullerNormalizedGaussianSampler()
}
}

Some files were not shown because too many files have changed in this diff Show More