forked from kscience/kmath
Integration between MST and Symja IExpr
This commit is contained in:
parent
bb228e0c68
commit
a722cf0cdb
@ -2,18 +2,19 @@
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
### Added
|
### Added
|
||||||
- ScaleOperations interface
|
- `ScaleOperations` interface
|
||||||
- Field extends ScaleOperations
|
- `Field` extends `ScaleOperations`
|
||||||
- Basic integration API
|
- Basic integration API
|
||||||
- Basic MPP distributions and samplers
|
- Basic MPP distributions and samplers
|
||||||
- bindSymbolOrNull
|
- `bindSymbolOrNull`
|
||||||
- Blocking chains and Statistics
|
- Blocking chains and Statistics
|
||||||
- Multiplatform integration
|
- Multiplatform integration
|
||||||
- Integration for any Field element
|
- Integration for any Field element
|
||||||
- Extended operations for ND4J fields
|
- Extended operations for ND4J fields
|
||||||
- Jupyter Notebook integration module (kmath-jupyter)
|
- Jupyter Notebook integration module (kmath-jupyter)
|
||||||
- `@PerformancePitfall` annotation to mark possibly slow API
|
- `@PerformancePitfall` annotation to mark possibly slow API
|
||||||
- BigInt operation performance improvement and fixes by @zhelenskiy (#328)
|
- `BigInt` operation performance improvement and fixes by @zhelenskiy (#328)
|
||||||
|
- Integration between `MST` and Symja `IExpr`
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- Exponential operations merged with hyperbolic functions
|
- Exponential operations merged with hyperbolic functions
|
||||||
|
@ -26,7 +26,7 @@ dependencies {
|
|||||||
implementation(project(":kmath-ejml"))
|
implementation(project(":kmath-ejml"))
|
||||||
implementation(project(":kmath-nd4j"))
|
implementation(project(":kmath-nd4j"))
|
||||||
implementation(project(":kmath-tensors"))
|
implementation(project(":kmath-tensors"))
|
||||||
|
implementation(project(":kmath-symja"))
|
||||||
implementation(project(":kmath-for-real"))
|
implementation(project(":kmath-for-real"))
|
||||||
|
|
||||||
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
|
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
|
||||||
|
@ -5,25 +5,23 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.ast
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
import space.kscience.kmath.asm.compileToExpression
|
|
||||||
import space.kscience.kmath.expressions.derivative
|
import space.kscience.kmath.expressions.derivative
|
||||||
import space.kscience.kmath.expressions.invoke
|
import space.kscience.kmath.expressions.invoke
|
||||||
import space.kscience.kmath.expressions.symbol
|
import space.kscience.kmath.expressions.Symbol.Companion.x
|
||||||
import space.kscience.kmath.kotlingrad.toDiffExpression
|
import space.kscience.kmath.expressions.toExpression
|
||||||
|
import space.kscience.kmath.kotlingrad.toKotlingradExpression
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* In this example, x^2-4*x-44 function is differentiated with Kotlin∇, and the autodiff result is compared with
|
* In this example, x^2-4*x-44 function is differentiated with Kotlin∇, and the autodiff result is compared with
|
||||||
* valid derivative.
|
* valid derivative in a certain point.
|
||||||
*/
|
*/
|
||||||
fun main() {
|
fun main() {
|
||||||
val x by symbol
|
val actualDerivative = "x^2-4*x-44"
|
||||||
|
.parseMath()
|
||||||
val actualDerivative = "x^2-4*x-44".parseMath()
|
.toKotlingradExpression(DoubleField)
|
||||||
.toDiffExpression(DoubleField)
|
|
||||||
.derivative(x)
|
.derivative(x)
|
||||||
|
|
||||||
|
val expectedDerivative = "2*x-4".parseMath().toExpression(DoubleField)
|
||||||
val expectedDerivative = "2*x-4".parseMath().compileToExpression(DoubleField)
|
check(actualDerivative(x to 123.0) == expectedDerivative(x to 123.0))
|
||||||
assert(actualDerivative(x to 123.0) == expectedDerivative(x to 123.0))
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,27 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
|
import space.kscience.kmath.expressions.Symbol.Companion.x
|
||||||
|
import space.kscience.kmath.expressions.derivative
|
||||||
|
import space.kscience.kmath.expressions.invoke
|
||||||
|
import space.kscience.kmath.expressions.toExpression
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.symja.toSymjaExpression
|
||||||
|
|
||||||
|
/**
|
||||||
|
* In this example, x^2-4*x-44 function is differentiated with Symja, and the autodiff result is compared with
|
||||||
|
* valid derivative in a certain point.
|
||||||
|
*/
|
||||||
|
fun main() {
|
||||||
|
val actualDerivative = "x^2-4*x-44"
|
||||||
|
.parseMath()
|
||||||
|
.toSymjaExpression(DoubleField)
|
||||||
|
.derivative(x)
|
||||||
|
|
||||||
|
val expectedDerivative = "2*x-4".parseMath().toExpression(DoubleField)
|
||||||
|
check(actualDerivative(x to 123.0) == expectedDerivative(x to 123.0))
|
||||||
|
}
|
@ -39,7 +39,7 @@ public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wraps this [MST] into [KotlingradExpression].
|
* Wraps this [MST] into [KotlingradExpression] in the context of [algebra].
|
||||||
*/
|
*/
|
||||||
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): KotlingradExpression<T, A> =
|
public fun <T : Number, A : NumericAlgebra<T>> MST.toKotlingradExpression(algebra: A): KotlingradExpression<T, A> =
|
||||||
KotlingradExpression(algebra, this)
|
KotlingradExpression(algebra, this)
|
||||||
|
21
kmath-symja/build.gradle.kts
Normal file
21
kmath-symja/build.gradle.kts
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
plugins {
|
||||||
|
kotlin("jvm")
|
||||||
|
id("ru.mipt.npm.gradle.common")
|
||||||
|
}
|
||||||
|
|
||||||
|
description = "Symja integration module"
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
api("org.matheclipse:matheclipse-core:2.0.0-SNAPSHOT")
|
||||||
|
api(project(":kmath-core"))
|
||||||
|
testImplementation("org.slf4j:slf4j-simple:1.7.30")
|
||||||
|
}
|
||||||
|
|
||||||
|
readme {
|
||||||
|
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.symja
|
||||||
|
|
||||||
|
import org.matheclipse.core.eval.ExprEvaluator
|
||||||
|
import org.matheclipse.core.expression.F
|
||||||
|
import space.kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import space.kscience.kmath.expressions.MST
|
||||||
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
import space.kscience.kmath.expressions.interpret
|
||||||
|
import space.kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
|
public class SymjaExpression<T : Number, A : NumericAlgebra<T>>(
|
||||||
|
public val algebra: A,
|
||||||
|
public val mst: MST,
|
||||||
|
public val evaluator: ExprEvaluator = DEFAULT_EVALUATOR,
|
||||||
|
) : DifferentiableExpression<T, SymjaExpression<T, A>> {
|
||||||
|
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
|
||||||
|
|
||||||
|
public override fun derivativeOrNull(symbols: List<Symbol>): SymjaExpression<T, A> = SymjaExpression(
|
||||||
|
algebra,
|
||||||
|
symbols.map(Symbol::toIExpr).fold(mst.toIExpr(), F::D).toMst(evaluator),
|
||||||
|
evaluator,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps this [MST] into [SymjaExpression] in the context of [algebra].
|
||||||
|
*/
|
||||||
|
public fun <T : Number, A : NumericAlgebra<T>> MST.toSymjaExpression(algebra: A): SymjaExpression<T, A> =
|
||||||
|
SymjaExpression(algebra, this)
|
@ -0,0 +1,95 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.symja
|
||||||
|
|
||||||
|
import org.matheclipse.core.eval.ExprEvaluator
|
||||||
|
import org.matheclipse.core.expression.ComplexNum
|
||||||
|
import org.matheclipse.core.expression.F
|
||||||
|
import org.matheclipse.core.interfaces.IExpr
|
||||||
|
import org.matheclipse.core.interfaces.ISymbol
|
||||||
|
import space.kscience.kmath.expressions.MST
|
||||||
|
import space.kscience.kmath.expressions.MstExtendedField
|
||||||
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
import space.kscience.kmath.operations.*
|
||||||
|
|
||||||
|
internal val DEFAULT_EVALUATOR = ExprEvaluator(false, 100)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Matches the given [IExpr] instance to appropriate [MST] node or evaluates it with [evaluator].
|
||||||
|
*/
|
||||||
|
public fun IExpr.toMst(evaluator: ExprEvaluator = DEFAULT_EVALUATOR): MST = MstExtendedField {
|
||||||
|
when {
|
||||||
|
isPlus -> first().toMst(evaluator) + second().toMst(evaluator)
|
||||||
|
isSin -> sin(first().toMst(evaluator))
|
||||||
|
isSinh -> sinh(first().toMst(evaluator))
|
||||||
|
isCos -> cos(first().toMst(evaluator))
|
||||||
|
isCosh -> cosh(first().toMst(evaluator))
|
||||||
|
isTan -> tan(first().toMst(evaluator))
|
||||||
|
isTanh -> tanh(first().toMst(evaluator))
|
||||||
|
isArcSin -> asin(first().toMst(evaluator))
|
||||||
|
isArcCos -> acos(first().toMst(evaluator))
|
||||||
|
isArcTan -> atan(first().toMst(evaluator))
|
||||||
|
isArcTanh -> atanh(first().toMst(evaluator))
|
||||||
|
isE -> bindSymbol("e")
|
||||||
|
isPi -> bindSymbol("pi")
|
||||||
|
isTimes -> first().toMst(evaluator) * second().toMst(evaluator)
|
||||||
|
isOne -> one
|
||||||
|
isZero -> zero
|
||||||
|
isImaginaryUnit -> bindSymbol("i")
|
||||||
|
isMinusOne -> -one
|
||||||
|
this@toMst is ISymbol -> bindSymbol(symbolName)
|
||||||
|
isPower -> power(first().toMst(evaluator), evaluator.evalf(second()))
|
||||||
|
isExp -> exp(first().toMst(evaluator))
|
||||||
|
isNumber -> number(evaluator.evalf(this@toMst))
|
||||||
|
this@toMst === F.NIL -> error("NIL cannot be converted to MST")
|
||||||
|
else -> evaluator.eval(this@toMst.toString()).toMst(evaluator)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Matches the given [MST] instance to appropriate [IExpr] node, only standard operations and symbols (which are
|
||||||
|
* present in, say, [MstExtendedField]) are supported.
|
||||||
|
*/
|
||||||
|
public fun MST.toIExpr(): IExpr = when (this) {
|
||||||
|
is MST.Numeric -> F.symjify(value)
|
||||||
|
|
||||||
|
is Symbol -> when (identity) {
|
||||||
|
"e" -> F.E
|
||||||
|
"pi" -> F.Pi
|
||||||
|
"i" -> ComplexNum.I
|
||||||
|
else -> F.Dummy(identity)
|
||||||
|
}
|
||||||
|
|
||||||
|
is MST.Unary -> when (operation) {
|
||||||
|
GroupOperations.PLUS_OPERATION -> value.toIExpr()
|
||||||
|
GroupOperations.MINUS_OPERATION -> F.Negate(value.toIExpr())
|
||||||
|
TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr())
|
||||||
|
TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr())
|
||||||
|
TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr())
|
||||||
|
TrigonometricOperations.ASIN_OPERATION -> F.ArcSin(value.toIExpr())
|
||||||
|
TrigonometricOperations.ACOS_OPERATION -> F.ArcCos(value.toIExpr())
|
||||||
|
TrigonometricOperations.ATAN_OPERATION -> F.ArcTan(value.toIExpr())
|
||||||
|
ExponentialOperations.SINH_OPERATION -> F.Sinh(value.toIExpr())
|
||||||
|
ExponentialOperations.COSH_OPERATION -> F.Cosh(value.toIExpr())
|
||||||
|
ExponentialOperations.TANH_OPERATION -> F.Tanh(value.toIExpr())
|
||||||
|
ExponentialOperations.ASINH_OPERATION -> F.ArcSinh(value.toIExpr())
|
||||||
|
ExponentialOperations.ACOSH_OPERATION -> F.ArcCosh(value.toIExpr())
|
||||||
|
ExponentialOperations.ATANH_OPERATION -> F.ArcTanh(value.toIExpr())
|
||||||
|
PowerOperations.SQRT_OPERATION -> F.Sqrt(value.toIExpr())
|
||||||
|
ExponentialOperations.EXP_OPERATION -> F.Exp(value.toIExpr())
|
||||||
|
ExponentialOperations.LN_OPERATION -> F.Log(value.toIExpr())
|
||||||
|
else -> error("Unary operation $operation not defined in $this")
|
||||||
|
}
|
||||||
|
|
||||||
|
is MST.Binary -> when (operation) {
|
||||||
|
GroupOperations.PLUS_OPERATION -> left.toIExpr() + right.toIExpr()
|
||||||
|
GroupOperations.MINUS_OPERATION -> left.toIExpr() - right.toIExpr()
|
||||||
|
RingOperations.TIMES_OPERATION -> left.toIExpr() * right.toIExpr()
|
||||||
|
FieldOperations.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr())
|
||||||
|
PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value))
|
||||||
|
else -> error("Binary operation $operation not defined in $this")
|
||||||
|
}
|
||||||
|
}
|
@ -42,6 +42,7 @@ include(
|
|||||||
":kmath-kotlingrad",
|
":kmath-kotlingrad",
|
||||||
":kmath-tensors",
|
":kmath-tensors",
|
||||||
":kmath-jupyter",
|
":kmath-jupyter",
|
||||||
|
":kmath-symja",
|
||||||
":examples",
|
":examples",
|
||||||
":benchmarks"
|
":benchmarks"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user