forked from kscience/kmath
Merge pull request #350 from mipt-npm/commandertvis/kotlingrad
Rename `DifferentiableMstExpression`, fix #349
This commit is contained in:
commit
bb228e0c68
@ -34,6 +34,7 @@
|
||||
- Replace `MST.Symbolic` by `Symbol`, `Symbol` now implements MST
|
||||
- Remove Any restriction on polynomials
|
||||
- Add `out` variance to type parameters of `StructureND` and its implementations where possible
|
||||
- Rename `DifferentiableMstExpression` to `KotlingradExpression`
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
@ -11,6 +11,7 @@ allprojects {
|
||||
isAllowInsecureProtocol = true
|
||||
}
|
||||
maven("https://maven.pkg.jetbrains.space/public/p/kotlinx-html/maven")
|
||||
maven("https://oss.sonatype.org/content/repositories/snapshots")
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
|
@ -3,14 +3,16 @@ plugins {
|
||||
id("ru.mipt.npm.gradle.common")
|
||||
}
|
||||
|
||||
description = "Kotlin∇ integration module"
|
||||
|
||||
dependencies {
|
||||
api("com.github.breandan:kaliningraph:0.1.6")
|
||||
api("com.github.breandan:kotlingrad:0.4.5")
|
||||
api(project(":kmath-ast"))
|
||||
api(project(":kmath-core"))
|
||||
testImplementation(project(":kmath-ast"))
|
||||
}
|
||||
|
||||
readme {
|
||||
description = "Functions, integration and interpolation"
|
||||
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
|
||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||
|
||||
|
@ -11,24 +11,24 @@ import space.kscience.kmath.expressions.*
|
||||
import space.kscience.kmath.operations.NumericAlgebra
|
||||
|
||||
/**
|
||||
* Represents [MST] based [DifferentiableExpression].
|
||||
* Represents [MST] based [DifferentiableExpression] relying on [Kotlin∇](https://github.com/breandan/kotlingrad).
|
||||
*
|
||||
* The principle of this API is converting the [mst] to an [SFun], differentiating it with
|
||||
* [Kotlin∇](https://github.com/breandan/kotlingrad), then converting [SFun] back to [MST].
|
||||
* The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin∇, then converting
|
||||
* [SFun] back to [MST].
|
||||
*
|
||||
* @param T The type of number.
|
||||
* @param A The [NumericAlgebra] of [T].
|
||||
* @property algebra The [A] instance.
|
||||
* @property mst The [MST] node.
|
||||
*/
|
||||
public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>(
|
||||
public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>(
|
||||
public val algebra: A,
|
||||
public val mst: MST,
|
||||
) : DifferentiableExpression<T, DifferentiableMstExpression<T, A>> {
|
||||
) : DifferentiableExpression<T, KotlingradExpression<T, A>> {
|
||||
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
|
||||
|
||||
public override fun derivativeOrNull(symbols: List<Symbol>): DifferentiableMstExpression<T, A> =
|
||||
DifferentiableMstExpression(
|
||||
public override fun derivativeOrNull(symbols: List<Symbol>): KotlingradExpression<T, A> =
|
||||
KotlingradExpression(
|
||||
algebra,
|
||||
symbols.map(Symbol::identity)
|
||||
.map(MstNumericAlgebra::bindSymbol)
|
||||
@ -39,7 +39,7 @@ public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>(
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [MST] into [DifferentiableMstExpression].
|
||||
* Wraps this [MST] into [KotlingradExpression].
|
||||
*/
|
||||
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): DifferentiableMstExpression<T, A> =
|
||||
DifferentiableMstExpression(algebra, this)
|
||||
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): KotlingradExpression<T, A> =
|
||||
KotlingradExpression(algebra, this)
|
@ -14,7 +14,7 @@ import space.kscience.kmath.expressions.Symbol
|
||||
import space.kscience.kmath.operations.*
|
||||
|
||||
/**
|
||||
* Maps [SVar] to [MST.Symbolic] directly.
|
||||
* Maps [SVar] to [Symbol] directly.
|
||||
*
|
||||
* @receiver The variable.
|
||||
* @returnAa node.
|
||||
@ -81,7 +81,7 @@ public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField {
|
||||
public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
|
||||
|
||||
/**
|
||||
* Maps [MST.Symbolic] to [SVar] directly.
|
||||
* Maps [Symbol] to [SVar] directly.
|
||||
*
|
||||
* @receiver The node.
|
||||
* @return A new variable.
|
||||
@ -94,7 +94,7 @@ internal fun <X : SFun<X>> Symbol.toSVar(): SVar<X> = SVar(identity)
|
||||
* Detailed mapping is:
|
||||
*
|
||||
* - [MST.Numeric] -> [SConst];
|
||||
* - [MST.Symbolic] -> [SVar];
|
||||
* - [Symbol] -> [SVar];
|
||||
* - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log];
|
||||
* - [MST.Binary] -> [Sum], [Prod], [Power].
|
||||
*
|
||||
@ -114,6 +114,12 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
||||
PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun())
|
||||
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
|
||||
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
|
||||
ExponentialOperations.SINH_OPERATION -> MstExtendedField { (exp(value) - exp(-value)) / 2.0 }.toSFun()
|
||||
ExponentialOperations.COSH_OPERATION -> MstExtendedField { (exp(value) + exp(-value)) / 2.0 }.toSFun()
|
||||
ExponentialOperations.TANH_OPERATION -> MstExtendedField { (exp(value) - exp(-value)) / (exp(-value) + exp(value)) }.toSFun()
|
||||
ExponentialOperations.ASINH_OPERATION -> MstExtendedField { ln(sqrt(value * value + one) + value) }.toSFun()
|
||||
ExponentialOperations.ACOSH_OPERATION -> MstExtendedField { ln(value + sqrt((value - one) * (value + one))) }.toSFun()
|
||||
ExponentialOperations.ATANH_OPERATION -> MstExtendedField { (ln(value + one) - ln(one - value)) / 2.0 }.toSFun()
|
||||
else -> error("Unary operation $operation not defined in $this")
|
||||
}
|
||||
|
@ -3,6 +3,8 @@ plugins {
|
||||
id("ru.mipt.npm.gradle.common")
|
||||
}
|
||||
|
||||
description = "ND4J NDStructure implementation and according NDAlgebra classes"
|
||||
|
||||
dependencies {
|
||||
api(project(":kmath-tensors"))
|
||||
api("org.nd4j:nd4j-api:1.0.0-beta7")
|
||||
@ -12,7 +14,6 @@ dependencies {
|
||||
}
|
||||
|
||||
readme {
|
||||
description = "ND4J NDStructure implementation and according NDAlgebra classes"
|
||||
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
|
||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||
feature(id = "nd4jarraystructure") { "NDStructure wrapper for INDArray" }
|
||||
|
Loading…
Reference in New Issue
Block a user