Merge pull request #351 from mipt-npm/commandertvis/symja

Integration between MST and Symja IExpr
This commit is contained in:
Alexander Nozik 2021-05-26 18:01:25 +03:00 committed by GitHub
commit 8da331f476
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 195 additions and 18 deletions

View File

@ -2,18 +2,19 @@
## [Unreleased] ## [Unreleased]
### Added ### Added
- ScaleOperations interface - `ScaleOperations` interface
- Field extends ScaleOperations - `Field` extends `ScaleOperations`
- Basic integration API - Basic integration API
- Basic MPP distributions and samplers - Basic MPP distributions and samplers
- bindSymbolOrNull - `bindSymbolOrNull`
- Blocking chains and Statistics - Blocking chains and Statistics
- Multiplatform integration - Multiplatform integration
- Integration for any Field element - Integration for any Field element
- Extended operations for ND4J fields - Extended operations for ND4J fields
- Jupyter Notebook integration module (kmath-jupyter) - Jupyter Notebook integration module (kmath-jupyter)
- `@PerformancePitfall` annotation to mark possibly slow API - `@PerformancePitfall` annotation to mark possibly slow API
- BigInt operation performance improvement and fixes by @zhelenskiy (#328) - `BigInt` operation performance improvement and fixes by @zhelenskiy (#328)
- Integration between `MST` and Symja `IExpr`
### Changed ### Changed
- Exponential operations merged with hyperbolic functions - Exponential operations merged with hyperbolic functions

View File

@ -26,7 +26,7 @@ dependencies {
implementation(project(":kmath-ejml")) implementation(project(":kmath-ejml"))
implementation(project(":kmath-nd4j")) implementation(project(":kmath-nd4j"))
implementation(project(":kmath-tensors")) implementation(project(":kmath-tensors"))
implementation(project(":kmath-symja"))
implementation(project(":kmath-for-real")) implementation(project(":kmath-for-real"))
implementation("org.nd4j:nd4j-native:1.0.0-beta7") implementation("org.nd4j:nd4j-native:1.0.0-beta7")

View File

@ -5,25 +5,23 @@
package space.kscience.kmath.ast package space.kscience.kmath.ast
import space.kscience.kmath.asm.compileToExpression
import space.kscience.kmath.expressions.derivative import space.kscience.kmath.expressions.derivative
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.expressions.symbol import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.kotlingrad.toDiffExpression import space.kscience.kmath.expressions.toExpression
import space.kscience.kmath.kotlingrad.toKotlingradExpression
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
/** /**
* In this example, x^2-4*x-44 function is differentiated with Kotlin, and the autodiff result is compared with * In this example, x^2-4*x-44 function is differentiated with Kotlin, and the autodiff result is compared with
* valid derivative. * valid derivative in a certain point.
*/ */
fun main() { fun main() {
val x by symbol val actualDerivative = "x^2-4*x-44"
.parseMath()
val actualDerivative = "x^2-4*x-44".parseMath() .toKotlingradExpression(DoubleField)
.toDiffExpression(DoubleField)
.derivative(x) .derivative(x)
val expectedDerivative = "2*x-4".parseMath().toExpression(DoubleField)
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField) check(actualDerivative(x to 123.0) == expectedDerivative(x to 123.0))
assert(actualDerivative(x to 123.0) == expectedDerivative(x to 123.0))
} }

View File

@ -0,0 +1,27 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Symbol.Companion.x
import space.kscience.kmath.expressions.derivative
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.expressions.toExpression
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.symja.toSymjaExpression
/**
* In this example, x^2-4*x-44 function is differentiated with Symja, and the autodiff result is compared with
* valid derivative in a certain point.
*/
fun main() {
val actualDerivative = "x^2-4*x-44"
.parseMath()
.toSymjaExpression(DoubleField)
.derivative(x)
val expectedDerivative = "2*x-4".parseMath().toExpression(DoubleField)
check(actualDerivative(x to 123.0) == expectedDerivative(x to 123.0))
}

View File

@ -39,7 +39,7 @@ public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>(
} }
/** /**
* Wraps this [MST] into [KotlingradExpression]. * Wraps this [MST] into [KotlingradExpression] in the context of [algebra].
*/ */
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): KotlingradExpression<T, A> = public fun <T : Number, A : NumericAlgebra<T>> MST.toKotlingradExpression(algebra: A): KotlingradExpression<T, A> =
KotlingradExpression(algebra, this) KotlingradExpression(algebra, this)

View File

@ -0,0 +1,21 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
plugins {
kotlin("jvm")
id("ru.mipt.npm.gradle.common")
}
description = "Symja integration module"
dependencies {
api("org.matheclipse:matheclipse-core:2.0.0-SNAPSHOT")
api(project(":kmath-core"))
testImplementation("org.slf4j:slf4j-simple:1.7.30")
}
readme {
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
}

View File

@ -0,0 +1,34 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.symja
import org.matheclipse.core.eval.ExprEvaluator
import org.matheclipse.core.expression.F
import space.kscience.kmath.expressions.DifferentiableExpression
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.operations.NumericAlgebra
public class SymjaExpression<T : Number, A : NumericAlgebra<T>>(
public val algebra: A,
public val mst: MST,
public val evaluator: ExprEvaluator = DEFAULT_EVALUATOR,
) : DifferentiableExpression<T, SymjaExpression<T, A>> {
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
public override fun derivativeOrNull(symbols: List<Symbol>): SymjaExpression<T, A> = SymjaExpression(
algebra,
symbols.map(Symbol::toIExpr).fold(mst.toIExpr(), F::D).toMst(evaluator),
evaluator,
)
}
/**
* Wraps this [MST] into [SymjaExpression] in the context of [algebra].
*/
public fun <T : Number, A : NumericAlgebra<T>> MST.toSymjaExpression(algebra: A): SymjaExpression<T, A> =
SymjaExpression(algebra, this)

View File

@ -0,0 +1,95 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.symja
import org.matheclipse.core.eval.ExprEvaluator
import org.matheclipse.core.expression.ComplexNum
import org.matheclipse.core.expression.F
import org.matheclipse.core.interfaces.IExpr
import org.matheclipse.core.interfaces.ISymbol
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.MstExtendedField
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.*
internal val DEFAULT_EVALUATOR = ExprEvaluator(false, 100)
/**
* Matches the given [IExpr] instance to appropriate [MST] node or evaluates it with [evaluator].
*/
public fun IExpr.toMst(evaluator: ExprEvaluator = DEFAULT_EVALUATOR): MST = MstExtendedField {
when {
isPlus -> first().toMst(evaluator) + second().toMst(evaluator)
isSin -> sin(first().toMst(evaluator))
isSinh -> sinh(first().toMst(evaluator))
isCos -> cos(first().toMst(evaluator))
isCosh -> cosh(first().toMst(evaluator))
isTan -> tan(first().toMst(evaluator))
isTanh -> tanh(first().toMst(evaluator))
isArcSin -> asin(first().toMst(evaluator))
isArcCos -> acos(first().toMst(evaluator))
isArcTan -> atan(first().toMst(evaluator))
isArcTanh -> atanh(first().toMst(evaluator))
isE -> bindSymbol("e")
isPi -> bindSymbol("pi")
isTimes -> first().toMst(evaluator) * second().toMst(evaluator)
isOne -> one
isZero -> zero
isImaginaryUnit -> bindSymbol("i")
isMinusOne -> -one
this@toMst is ISymbol -> bindSymbol(symbolName)
isPower -> power(first().toMst(evaluator), evaluator.evalf(second()))
isExp -> exp(first().toMst(evaluator))
isNumber -> number(evaluator.evalf(this@toMst))
this@toMst === F.NIL -> error("NIL cannot be converted to MST")
else -> evaluator.eval(this@toMst.toString()).toMst(evaluator)
}
}
/**
* Matches the given [MST] instance to appropriate [IExpr] node, only standard operations and symbols (which are
* present in, say, [MstExtendedField]) are supported.
*/
public fun MST.toIExpr(): IExpr = when (this) {
is MST.Numeric -> F.symjify(value)
is Symbol -> when (identity) {
"e" -> F.E
"pi" -> F.Pi
"i" -> ComplexNum.I
else -> F.Dummy(identity)
}
is MST.Unary -> when (operation) {
GroupOperations.PLUS_OPERATION -> value.toIExpr()
GroupOperations.MINUS_OPERATION -> F.Negate(value.toIExpr())
TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr())
TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr())
TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr())
TrigonometricOperations.ASIN_OPERATION -> F.ArcSin(value.toIExpr())
TrigonometricOperations.ACOS_OPERATION -> F.ArcCos(value.toIExpr())
TrigonometricOperations.ATAN_OPERATION -> F.ArcTan(value.toIExpr())
ExponentialOperations.SINH_OPERATION -> F.Sinh(value.toIExpr())
ExponentialOperations.COSH_OPERATION -> F.Cosh(value.toIExpr())
ExponentialOperations.TANH_OPERATION -> F.Tanh(value.toIExpr())
ExponentialOperations.ASINH_OPERATION -> F.ArcSinh(value.toIExpr())
ExponentialOperations.ACOSH_OPERATION -> F.ArcCosh(value.toIExpr())
ExponentialOperations.ATANH_OPERATION -> F.ArcTanh(value.toIExpr())
PowerOperations.SQRT_OPERATION -> F.Sqrt(value.toIExpr())
ExponentialOperations.EXP_OPERATION -> F.Exp(value.toIExpr())
ExponentialOperations.LN_OPERATION -> F.Log(value.toIExpr())
else -> error("Unary operation $operation not defined in $this")
}
is MST.Binary -> when (operation) {
GroupOperations.PLUS_OPERATION -> left.toIExpr() + right.toIExpr()
GroupOperations.MINUS_OPERATION -> left.toIExpr() - right.toIExpr()
RingOperations.TIMES_OPERATION -> left.toIExpr() * right.toIExpr()
FieldOperations.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr())
PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value))
else -> error("Binary operation $operation not defined in $this")
}
}

View File

@ -42,6 +42,7 @@ include(
":kmath-kotlingrad", ":kmath-kotlingrad",
":kmath-tensors", ":kmath-tensors",
":kmath-jupyter", ":kmath-jupyter",
":kmath-symja",
":examples", ":examples",
":benchmarks" ":benchmarks"
) )