2021-07-19 03:12:41 +07:00

46 lines
1.7 KiB
Kotlin

/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.api.SFun
import edu.umontreal.kotlingrad.api.SVar
import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.NumericAlgebra
/**
 * An [MST]-backed [DifferentiableExpression] whose derivatives are computed by
 * [Kotlin∇](https://github.com/breandan/kotlingrad).
 *
 * Differentiation is a three-step round trip: the stored [mst] is converted to an [SFun], Kotlin∇ differentiates
 * that [SFun] once per requested symbol, and the resulting [SFun] is converted back to an [MST].
 *
 * @param T The type of number.
 * @param A The [NumericAlgebra] of [T].
 * @property algebra The [A] instance used to interpret [mst].
 * @property mst The [MST] node this expression evaluates and differentiates.
 */
public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>(
    public val algebra: A,
    public val mst: MST,
) : SpecialDifferentiableExpression<T, KotlingradExpression<T, A>> {
    /**
     * Evaluates the expression by interpreting [mst] in [algebra] with the given [arguments].
     */
    override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)

    /**
     * Differentiates this expression with respect to each of [symbols] in list order.
     *
     * @param symbols the differentiation variables; an empty list applies no differentiation step.
     * @return a new [KotlingradExpression] over the same [algebra] wrapping the differentiated tree.
     */
    override fun derivativeOrNull(symbols: List<Symbol>): KotlingradExpression<T, A> {
        // NOTE(review): the identity -> bindSymbol round trip presumably normalises arbitrary Symbol
        // implementations to the one MstNumericAlgebra produces — confirm before simplifying.
        val names = symbols.map { symbol -> symbol.identity }
        val boundSymbols = names.map { name -> MstNumericAlgebra.bindSymbol(name) }
        val variables = boundSymbols.map<Symbol, SVar<KMathNumber<T, A>>> { bound -> bound.toSVar() }

        // Start from the Kotlin∇ form of the tree and take one derivative per variable.
        val seed: SFun<KMathNumber<T, A>> = mst.toSFun()
        val derivative = variables.fold(seed) { function, variable -> function.d(variable) }

        return KotlingradExpression(algebra, derivative.toMst())
    }
}
/**
 * Creates a [KotlingradExpression] that holds this [MST] together with the given [algebra].
 *
 * @param T The type of number.
 * @param A The [NumericAlgebra] of [T].
 * @param algebra the algebra instance passed to the resulting expression.
 * @return a new [KotlingradExpression] built from [algebra] and this tree.
 */
public fun <T : Number, A : NumericAlgebra<T>> MST.toKotlingradExpression(algebra: A): KotlingradExpression<T, A> {
    val tree: MST = this
    return KotlingradExpression(algebra = algebra, mst = tree)
}