diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c6b14b95..2ff5dd623 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,18 +2,19 @@ ## [Unreleased] ### Added -- ScaleOperations interface -- Field extends ScaleOperations +- `ScaleOperations` interface +- `Field` extends `ScaleOperations` - Basic integration API - Basic MPP distributions and samplers -- bindSymbolOrNull +- `bindSymbolOrNull` - Blocking chains and Statistics - Multiplatform integration - Integration for any Field element - Extended operations for ND4J fields - Jupyter Notebook integration module (kmath-jupyter) - `@PerformancePitfall` annotation to mark possibly slow API -- BigInt operation performance improvement and fixes by @zhelenskiy (#328) +- `BigInt` operation performance improvement and fixes by @zhelenskiy (#328) +- Integration between `MST` and Symja `IExpr` ### Changed - Exponential operations merged with hyperbolic functions diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index d095db1ba..568bde153 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -26,7 +26,7 @@ dependencies { implementation(project(":kmath-ejml")) implementation(project(":kmath-nd4j")) implementation(project(":kmath-tensors")) - + implementation(project(":kmath-symja")) implementation(project(":kmath-for-real")) implementation("org.nd4j:nd4j-native:1.0.0-beta7") diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/kotlingradSupport.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/kotlingradSupport.kt index 420b23f9f..6ceaa962a 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/ast/kotlingradSupport.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/kotlingradSupport.kt @@ -5,25 +5,23 @@ package space.kscience.kmath.ast -import space.kscience.kmath.asm.compileToExpression import space.kscience.kmath.expressions.derivative import space.kscience.kmath.expressions.invoke -import space.kscience.kmath.expressions.symbol -import space.kscience.kmath.kotlingrad.toDiffExpression +import space.kscience.kmath.expressions.Symbol.Companion.x +import space.kscience.kmath.expressions.toExpression +import space.kscience.kmath.kotlingrad.toKotlingradExpression import space.kscience.kmath.operations.DoubleField /** * In this example, x^2-4*x-44 function is differentiated with Kotlin∇, and the autodiff result is compared with - * valid derivative. + * valid derivative in a certain point. */ fun main() { - val x by symbol - - val actualDerivative = "x^2-4*x-44".parseMath() - .toDiffExpression(DoubleField) + val actualDerivative = "x^2-4*x-44" + .parseMath() + .toKotlingradExpression(DoubleField) .derivative(x) - - val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField) - assert(actualDerivative(x to 123.0) == expectedDerivative(x to 123.0)) + val expectedDerivative = "2*x-4".parseMath().toExpression(DoubleField) + check(actualDerivative(x to 123.0) == expectedDerivative(x to 123.0)) } diff --git a/examples/src/main/kotlin/space/kscience/kmath/ast/symjaSupport.kt b/examples/src/main/kotlin/space/kscience/kmath/ast/symjaSupport.kt new file mode 100644 index 000000000..a9eca0500 --- /dev/null +++ b/examples/src/main/kotlin/space/kscience/kmath/ast/symjaSupport.kt @@ -0,0 +1,27 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.ast + +import space.kscience.kmath.expressions.Symbol.Companion.x +import space.kscience.kmath.expressions.derivative +import space.kscience.kmath.expressions.invoke +import space.kscience.kmath.expressions.toExpression +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.symja.toSymjaExpression + +/** + * In this example, x^2-4*x-44 function is differentiated with Symja, and the autodiff result is compared with + * valid derivative in a certain point. + */ +fun main() { + val actualDerivative = "x^2-4*x-44" + .parseMath() + .toSymjaExpression(DoubleField) + .derivative(x) + + val expectedDerivative = "2*x-4".parseMath().toExpression(DoubleField) + check(actualDerivative(x to 123.0) == expectedDerivative(x to 123.0)) +} diff --git a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt index 72ecee4f1..8312eaff3 100644 --- a/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt +++ b/kmath-kotlingrad/src/main/kotlin/space/kscience/kmath/kotlingrad/KotlingradExpression.kt @@ -39,7 +39,7 @@ public class KotlingradExpression>( } /** - * Wraps this [MST] into [KotlingradExpression]. + * Wraps this [MST] into [KotlingradExpression] in the context of [algebra]. */ -public fun > MST.toDiffExpression(algebra: A): KotlingradExpression = +public fun > MST.toKotlingradExpression(algebra: A): KotlingradExpression = KotlingradExpression(algebra, this) diff --git a/kmath-symja/build.gradle.kts b/kmath-symja/build.gradle.kts new file mode 100644 index 000000000..c06020bb6 --- /dev/null +++ b/kmath-symja/build.gradle.kts @@ -0,0 +1,21 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +plugins { + kotlin("jvm") + id("ru.mipt.npm.gradle.common") +} + +description = "Symja integration module" + +dependencies { + api("org.matheclipse:matheclipse-core:2.0.0-SNAPSHOT") + api(project(":kmath-core")) + testImplementation("org.slf4j:slf4j-simple:1.7.30") +} + +readme { + maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE +} diff --git a/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/SymjaExpression.kt b/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/SymjaExpression.kt new file mode 100644 index 000000000..301678916 --- /dev/null +++ b/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/SymjaExpression.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.symja + +import org.matheclipse.core.eval.ExprEvaluator +import org.matheclipse.core.expression.F +import space.kscience.kmath.expressions.DifferentiableExpression +import space.kscience.kmath.expressions.MST +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.expressions.interpret +import space.kscience.kmath.operations.NumericAlgebra + +public class SymjaExpression>( + public val algebra: A, + public val mst: MST, + public val evaluator: ExprEvaluator = DEFAULT_EVALUATOR, +) : DifferentiableExpression> { + public override fun invoke(arguments: Map): T = mst.interpret(algebra, arguments) + + public override fun derivativeOrNull(symbols: List): SymjaExpression = SymjaExpression( + algebra, + symbols.map(Symbol::toIExpr).fold(mst.toIExpr(), F::D).toMst(evaluator), + evaluator, + ) +} + +/** + * Wraps this [MST] into [SymjaExpression] in the context of [algebra]. + */ +public fun > MST.toSymjaExpression(algebra: A): SymjaExpression = + SymjaExpression(algebra, this) diff --git a/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/adapters.kt b/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/adapters.kt new file mode 100644 index 000000000..95dd1ebbf --- /dev/null +++ b/kmath-symja/src/main/kotlin/space/kscience/kmath/symja/adapters.kt @@ -0,0 +1,95 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.symja + +import org.matheclipse.core.eval.ExprEvaluator +import org.matheclipse.core.expression.ComplexNum +import org.matheclipse.core.expression.F +import org.matheclipse.core.interfaces.IExpr +import org.matheclipse.core.interfaces.ISymbol +import space.kscience.kmath.expressions.MST +import space.kscience.kmath.expressions.MstExtendedField +import space.kscience.kmath.expressions.Symbol +import space.kscience.kmath.operations.* + +internal val DEFAULT_EVALUATOR = ExprEvaluator(false, 100) + +/** + * Matches the given [IExpr] instance to appropriate [MST] node or evaluates it with [evaluator]. + */ +public fun IExpr.toMst(evaluator: ExprEvaluator = DEFAULT_EVALUATOR): MST = MstExtendedField { + when { + isPlus -> first().toMst(evaluator) + second().toMst(evaluator) + isSin -> sin(first().toMst(evaluator)) + isSinh -> sinh(first().toMst(evaluator)) + isCos -> cos(first().toMst(evaluator)) + isCosh -> cosh(first().toMst(evaluator)) + isTan -> tan(first().toMst(evaluator)) + isTanh -> tanh(first().toMst(evaluator)) + isArcSin -> asin(first().toMst(evaluator)) + isArcCos -> acos(first().toMst(evaluator)) + isArcTan -> atan(first().toMst(evaluator)) + isArcTanh -> atanh(first().toMst(evaluator)) + isE -> bindSymbol("e") + isPi -> bindSymbol("pi") + isTimes -> first().toMst(evaluator) * second().toMst(evaluator) + isOne -> one + isZero -> zero + isImaginaryUnit -> bindSymbol("i") + isMinusOne -> -one + this@toMst is ISymbol -> bindSymbol(symbolName) + isPower -> power(first().toMst(evaluator), evaluator.evalf(second())) + isExp -> exp(first().toMst(evaluator)) + isNumber -> number(evaluator.evalf(this@toMst)) + this@toMst === F.NIL -> error("NIL cannot be converted to MST") + else -> evaluator.eval(this@toMst.toString()).toMst(evaluator) + } +} + +/** + * Matches the given [MST] instance to appropriate [IExpr] node, only standard operations and symbols (which are + * present in, say, [MstExtendedField]) are supported. + */ +public fun MST.toIExpr(): IExpr = when (this) { + is MST.Numeric -> F.symjify(value) + + is Symbol -> when (identity) { + "e" -> F.E + "pi" -> F.Pi + "i" -> ComplexNum.I + else -> F.Dummy(identity) + } + + is MST.Unary -> when (operation) { + GroupOperations.PLUS_OPERATION -> value.toIExpr() + GroupOperations.MINUS_OPERATION -> F.Negate(value.toIExpr()) + TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr()) + TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr()) + TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr()) + TrigonometricOperations.ASIN_OPERATION -> F.ArcSin(value.toIExpr()) + TrigonometricOperations.ACOS_OPERATION -> F.ArcCos(value.toIExpr()) + TrigonometricOperations.ATAN_OPERATION -> F.ArcTan(value.toIExpr()) + ExponentialOperations.SINH_OPERATION -> F.Sinh(value.toIExpr()) + ExponentialOperations.COSH_OPERATION -> F.Cosh(value.toIExpr()) + ExponentialOperations.TANH_OPERATION -> F.Tanh(value.toIExpr()) + ExponentialOperations.ASINH_OPERATION -> F.ArcSinh(value.toIExpr()) + ExponentialOperations.ACOSH_OPERATION -> F.ArcCosh(value.toIExpr()) + ExponentialOperations.ATANH_OPERATION -> F.ArcTanh(value.toIExpr()) + PowerOperations.SQRT_OPERATION -> F.Sqrt(value.toIExpr()) + ExponentialOperations.EXP_OPERATION -> F.Exp(value.toIExpr()) + ExponentialOperations.LN_OPERATION -> F.Log(value.toIExpr()) + else -> error("Unary operation $operation not defined in $this") + } + + is MST.Binary -> when (operation) { + GroupOperations.PLUS_OPERATION -> left.toIExpr() + right.toIExpr() + GroupOperations.MINUS_OPERATION -> left.toIExpr() - right.toIExpr() + RingOperations.TIMES_OPERATION -> left.toIExpr() * right.toIExpr() + FieldOperations.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr()) + PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value)) + else -> error("Binary operation $operation not defined in $this") + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index cc9eb4860..d47aea2a7 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -42,6 +42,7 @@ include( ":kmath-kotlingrad", ":kmath-tensors", ":kmath-jupyter", + ":kmath-symja", ":examples", ":benchmarks" )