Integration between MST and Symja IExpr
This commit is contained in:
parent
bb228e0c68
commit
a722cf0cdb
@ -2,18 +2,19 @@
|
||||
|
||||
## [Unreleased]
|
||||
### Added
|
||||
- ScaleOperations interface
|
||||
- Field extends ScaleOperations
|
||||
- `ScaleOperations` interface
|
||||
- `Field` extends `ScaleOperations`
|
||||
- Basic integration API
|
||||
- Basic MPP distributions and samplers
|
||||
- bindSymbolOrNull
|
||||
- `bindSymbolOrNull`
|
||||
- Blocking chains and Statistics
|
||||
- Multiplatform integration
|
||||
- Integration for any Field element
|
||||
- Extended operations for ND4J fields
|
||||
- Jupyter Notebook integration module (kmath-jupyter)
|
||||
- `@PerformancePitfall` annotation to mark possibly slow API
|
||||
- BigInt operation performance improvement and fixes by @zhelenskiy (#328)
|
||||
- `BigInt` operation performance improvement and fixes by @zhelenskiy (#328)
|
||||
- Integration between `MST` and Symja `IExpr`
|
||||
|
||||
### Changed
|
||||
- Exponential operations merged with hyperbolic functions
|
||||
|
@ -26,7 +26,7 @@ dependencies {
|
||||
implementation(project(":kmath-ejml"))
|
||||
implementation(project(":kmath-nd4j"))
|
||||
implementation(project(":kmath-tensors"))
|
||||
|
||||
implementation(project(":kmath-symja"))
|
||||
implementation(project(":kmath-for-real"))
|
||||
|
||||
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
|
||||
|
@ -5,25 +5,23 @@
|
||||
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.asm.compileToExpression
|
||||
import space.kscience.kmath.expressions.derivative
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.expressions.symbol
|
||||
import space.kscience.kmath.kotlingrad.toDiffExpression
|
||||
import space.kscience.kmath.expressions.Symbol.Companion.x
|
||||
import space.kscience.kmath.expressions.toExpression
|
||||
import space.kscience.kmath.kotlingrad.toKotlingradExpression
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
|
||||
/**
|
||||
* In this example, x^2-4*x-44 function is differentiated with Kotlin∇, and the autodiff result is compared with
|
||||
* valid derivative.
|
||||
* valid derivative in a certain point.
|
||||
*/
|
||||
fun main() {
|
||||
val x by symbol
|
||||
|
||||
val actualDerivative = "x^2-4*x-44".parseMath()
|
||||
.toDiffExpression(DoubleField)
|
||||
val actualDerivative = "x^2-4*x-44"
|
||||
.parseMath()
|
||||
.toKotlingradExpression(DoubleField)
|
||||
.derivative(x)
|
||||
|
||||
|
||||
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
|
||||
assert(actualDerivative(x to 123.0) == expectedDerivative(x to 123.0))
|
||||
val expectedDerivative = "2*x-4".parseMath().toExpression(DoubleField)
|
||||
check(actualDerivative(x to 123.0) == expectedDerivative(x to 123.0))
|
||||
}
|
||||
|
@ -0,0 +1,27 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.ast
|
||||
|
||||
import space.kscience.kmath.expressions.Symbol.Companion.x
|
||||
import space.kscience.kmath.expressions.derivative
|
||||
import space.kscience.kmath.expressions.invoke
|
||||
import space.kscience.kmath.expressions.toExpression
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.symja.toSymjaExpression
|
||||
|
||||
/**
|
||||
* In this example, x^2-4*x-44 function is differentiated with Symja, and the autodiff result is compared with
|
||||
* valid derivative in a certain point.
|
||||
*/
|
||||
fun main() {
|
||||
val actualDerivative = "x^2-4*x-44"
|
||||
.parseMath()
|
||||
.toSymjaExpression(DoubleField)
|
||||
.derivative(x)
|
||||
|
||||
val expectedDerivative = "2*x-4".parseMath().toExpression(DoubleField)
|
||||
check(actualDerivative(x to 123.0) == expectedDerivative(x to 123.0))
|
||||
}
|
@ -39,7 +39,7 @@ public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>(
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [MST] into [KotlingradExpression].
|
||||
* Wraps this [MST] into [KotlingradExpression] in the context of [algebra].
|
||||
*/
|
||||
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): KotlingradExpression<T, A> =
|
||||
public fun <T : Number, A : NumericAlgebra<T>> MST.toKotlingradExpression(algebra: A): KotlingradExpression<T, A> =
|
||||
KotlingradExpression(algebra, this)
|
||||
|
21
kmath-symja/build.gradle.kts
Normal file
21
kmath-symja/build.gradle.kts
Normal file
@ -0,0 +1,21 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
plugins {
|
||||
kotlin("jvm")
|
||||
id("ru.mipt.npm.gradle.common")
|
||||
}
|
||||
|
||||
description = "Symja integration module"
|
||||
|
||||
dependencies {
|
||||
api("org.matheclipse:matheclipse-core:2.0.0-SNAPSHOT")
|
||||
api(project(":kmath-core"))
|
||||
testImplementation("org.slf4j:slf4j-simple:1.7.30")
|
||||
}
|
||||
|
||||
readme {
|
||||
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.symja
|
||||
|
||||
import org.matheclipse.core.eval.ExprEvaluator
|
||||
import org.matheclipse.core.expression.F
|
||||
import space.kscience.kmath.expressions.DifferentiableExpression
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.expressions.Symbol
|
||||
import space.kscience.kmath.expressions.interpret
|
||||
import space.kscience.kmath.operations.NumericAlgebra
|
||||
|
||||
public class SymjaExpression<T : Number, A : NumericAlgebra<T>>(
|
||||
public val algebra: A,
|
||||
public val mst: MST,
|
||||
public val evaluator: ExprEvaluator = DEFAULT_EVALUATOR,
|
||||
) : DifferentiableExpression<T, SymjaExpression<T, A>> {
|
||||
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
|
||||
|
||||
public override fun derivativeOrNull(symbols: List<Symbol>): SymjaExpression<T, A> = SymjaExpression(
|
||||
algebra,
|
||||
symbols.map(Symbol::toIExpr).fold(mst.toIExpr(), F::D).toMst(evaluator),
|
||||
evaluator,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [MST] into [SymjaExpression] in the context of [algebra].
|
||||
*/
|
||||
public fun <T : Number, A : NumericAlgebra<T>> MST.toSymjaExpression(algebra: A): SymjaExpression<T, A> =
|
||||
SymjaExpression(algebra, this)
|
@ -0,0 +1,95 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.symja
|
||||
|
||||
import org.matheclipse.core.eval.ExprEvaluator
|
||||
import org.matheclipse.core.expression.ComplexNum
|
||||
import org.matheclipse.core.expression.F
|
||||
import org.matheclipse.core.interfaces.IExpr
|
||||
import org.matheclipse.core.interfaces.ISymbol
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.expressions.MstExtendedField
|
||||
import space.kscience.kmath.expressions.Symbol
|
||||
import space.kscience.kmath.operations.*
|
||||
|
||||
internal val DEFAULT_EVALUATOR = ExprEvaluator(false, 100)
|
||||
|
||||
/**
|
||||
* Matches the given [IExpr] instance to appropriate [MST] node or evaluates it with [evaluator].
|
||||
*/
|
||||
public fun IExpr.toMst(evaluator: ExprEvaluator = DEFAULT_EVALUATOR): MST = MstExtendedField {
|
||||
when {
|
||||
isPlus -> first().toMst(evaluator) + second().toMst(evaluator)
|
||||
isSin -> sin(first().toMst(evaluator))
|
||||
isSinh -> sinh(first().toMst(evaluator))
|
||||
isCos -> cos(first().toMst(evaluator))
|
||||
isCosh -> cosh(first().toMst(evaluator))
|
||||
isTan -> tan(first().toMst(evaluator))
|
||||
isTanh -> tanh(first().toMst(evaluator))
|
||||
isArcSin -> asin(first().toMst(evaluator))
|
||||
isArcCos -> acos(first().toMst(evaluator))
|
||||
isArcTan -> atan(first().toMst(evaluator))
|
||||
isArcTanh -> atanh(first().toMst(evaluator))
|
||||
isE -> bindSymbol("e")
|
||||
isPi -> bindSymbol("pi")
|
||||
isTimes -> first().toMst(evaluator) * second().toMst(evaluator)
|
||||
isOne -> one
|
||||
isZero -> zero
|
||||
isImaginaryUnit -> bindSymbol("i")
|
||||
isMinusOne -> -one
|
||||
this@toMst is ISymbol -> bindSymbol(symbolName)
|
||||
isPower -> power(first().toMst(evaluator), evaluator.evalf(second()))
|
||||
isExp -> exp(first().toMst(evaluator))
|
||||
isNumber -> number(evaluator.evalf(this@toMst))
|
||||
this@toMst === F.NIL -> error("NIL cannot be converted to MST")
|
||||
else -> evaluator.eval(this@toMst.toString()).toMst(evaluator)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Matches the given [MST] instance to appropriate [IExpr] node, only standard operations and symbols (which are
|
||||
* present in, say, [MstExtendedField]) are supported.
|
||||
*/
|
||||
public fun MST.toIExpr(): IExpr = when (this) {
|
||||
is MST.Numeric -> F.symjify(value)
|
||||
|
||||
is Symbol -> when (identity) {
|
||||
"e" -> F.E
|
||||
"pi" -> F.Pi
|
||||
"i" -> ComplexNum.I
|
||||
else -> F.Dummy(identity)
|
||||
}
|
||||
|
||||
is MST.Unary -> when (operation) {
|
||||
GroupOperations.PLUS_OPERATION -> value.toIExpr()
|
||||
GroupOperations.MINUS_OPERATION -> F.Negate(value.toIExpr())
|
||||
TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr())
|
||||
TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr())
|
||||
TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr())
|
||||
TrigonometricOperations.ASIN_OPERATION -> F.ArcSin(value.toIExpr())
|
||||
TrigonometricOperations.ACOS_OPERATION -> F.ArcCos(value.toIExpr())
|
||||
TrigonometricOperations.ATAN_OPERATION -> F.ArcTan(value.toIExpr())
|
||||
ExponentialOperations.SINH_OPERATION -> F.Sinh(value.toIExpr())
|
||||
ExponentialOperations.COSH_OPERATION -> F.Cosh(value.toIExpr())
|
||||
ExponentialOperations.TANH_OPERATION -> F.Tanh(value.toIExpr())
|
||||
ExponentialOperations.ASINH_OPERATION -> F.ArcSinh(value.toIExpr())
|
||||
ExponentialOperations.ACOSH_OPERATION -> F.ArcCosh(value.toIExpr())
|
||||
ExponentialOperations.ATANH_OPERATION -> F.ArcTanh(value.toIExpr())
|
||||
PowerOperations.SQRT_OPERATION -> F.Sqrt(value.toIExpr())
|
||||
ExponentialOperations.EXP_OPERATION -> F.Exp(value.toIExpr())
|
||||
ExponentialOperations.LN_OPERATION -> F.Log(value.toIExpr())
|
||||
else -> error("Unary operation $operation not defined in $this")
|
||||
}
|
||||
|
||||
is MST.Binary -> when (operation) {
|
||||
GroupOperations.PLUS_OPERATION -> left.toIExpr() + right.toIExpr()
|
||||
GroupOperations.MINUS_OPERATION -> left.toIExpr() - right.toIExpr()
|
||||
RingOperations.TIMES_OPERATION -> left.toIExpr() * right.toIExpr()
|
||||
FieldOperations.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr())
|
||||
PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value))
|
||||
else -> error("Binary operation $operation not defined in $this")
|
||||
}
|
||||
}
|
@ -42,6 +42,7 @@ include(
|
||||
":kmath-kotlingrad",
|
||||
":kmath-tensors",
|
||||
":kmath-jupyter",
|
||||
":kmath-symja",
|
||||
":examples",
|
||||
":benchmarks"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user