Add extension functions for DifferentiableMstExpression

This commit is contained in:
Iaroslav Postovalov 2020-10-30 16:40:43 +07:00
parent 29a670483b
commit bc4eb95ae7
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
2 changed files with 12 additions and 1 deletions

View File

@ -4,6 +4,7 @@ import kscience.kmath.asm.compile
import kscience.kmath.expressions.invoke
import kscience.kmath.expressions.symbol
import kscience.kmath.kotlingrad.DifferentiableMstExpression
import kscience.kmath.kotlingrad.derivative
import kscience.kmath.operations.RealField
/**
@ -14,7 +15,7 @@ fun main() {
val x by symbol
val actualDerivative = DifferentiableMstExpression(RealField, "x^2-4*x-44".parseMath())
.derivativeOrNull(listOf(x))
.derivative(x)
.compile()
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()

View File

@ -5,6 +5,7 @@ import kscience.kmath.ast.MST
import kscience.kmath.ast.MstAlgebra
import kscience.kmath.ast.MstExpression
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.StringSymbol
import kscience.kmath.expressions.Symbol
import kscience.kmath.operations.NumericAlgebra
@ -46,6 +47,15 @@ public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpres
)
}
public fun <T : Number, A : NumericAlgebra<T>> DifferentiableMstExpression<T, A>.derivative(symbols: List<Symbol>): MstExpression<T, A> =
derivativeOrNull(symbols)
public fun <T : Number, A : NumericAlgebra<T>> DifferentiableMstExpression<T, A>.derivative(vararg symbols: Symbol): MstExpression<T, A> =
derivative(symbols.toList())
public fun <T : Number, A : NumericAlgebra<T>> DifferentiableMstExpression<T, A>.derivative(name: String): MstExpression<T, A> =
derivative(StringSymbol(name))
/**
* Wraps this [MstExpression] into [DifferentiableMstExpression].
*/