forked from kscience/kmath
Add extension functions for DifferentiableMstExpression
This commit is contained in:
parent
29a670483b
commit
bc4eb95ae7
@ -4,6 +4,7 @@ import kscience.kmath.asm.compile
|
|||||||
import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.invoke
|
||||||
import kscience.kmath.expressions.symbol
|
import kscience.kmath.expressions.symbol
|
||||||
import kscience.kmath.kotlingrad.DifferentiableMstExpression
|
import kscience.kmath.kotlingrad.DifferentiableMstExpression
|
||||||
|
import kscience.kmath.kotlingrad.derivative
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -14,7 +15,7 @@ fun main() {
|
|||||||
val x by symbol
|
val x by symbol
|
||||||
|
|
||||||
val actualDerivative = DifferentiableMstExpression(RealField, "x^2-4*x-44".parseMath())
|
val actualDerivative = DifferentiableMstExpression(RealField, "x^2-4*x-44".parseMath())
|
||||||
.derivativeOrNull(listOf(x))
|
.derivative(x)
|
||||||
.compile()
|
.compile()
|
||||||
|
|
||||||
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||||
|
@ -5,6 +5,7 @@ import kscience.kmath.ast.MST
|
|||||||
import kscience.kmath.ast.MstAlgebra
|
import kscience.kmath.ast.MstAlgebra
|
||||||
import kscience.kmath.ast.MstExpression
|
import kscience.kmath.ast.MstExpression
|
||||||
import kscience.kmath.expressions.DifferentiableExpression
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import kscience.kmath.expressions.StringSymbol
|
||||||
import kscience.kmath.expressions.Symbol
|
import kscience.kmath.expressions.Symbol
|
||||||
import kscience.kmath.operations.NumericAlgebra
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
@ -46,6 +47,15 @@ public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpres
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun <T : Number, A : NumericAlgebra<T>> DifferentiableMstExpression<T, A>.derivative(symbols: List<Symbol>): MstExpression<T, A> =
|
||||||
|
derivativeOrNull(symbols)
|
||||||
|
|
||||||
|
public fun <T : Number, A : NumericAlgebra<T>> DifferentiableMstExpression<T, A>.derivative(vararg symbols: Symbol): MstExpression<T, A> =
|
||||||
|
derivative(symbols.toList())
|
||||||
|
|
||||||
|
public fun <T : Number, A : NumericAlgebra<T>> DifferentiableMstExpression<T, A>.derivative(name: String): MstExpression<T, A> =
|
||||||
|
derivative(StringSymbol(name))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wraps this [MstExpression] into [DifferentiableMstExpression].
|
* Wraps this [MstExpression] into [DifferentiableMstExpression].
|
||||||
*/
|
*/
|
||||||
|
Loading…
Reference in New Issue
Block a user