Merge pull request #350 from mipt-npm/commandertvis/kotlingrad

Rename `DifferentiableMstExpression`, fix #349
This commit is contained in:
Alexander Nozik 2021-05-21 09:42:44 +03:00 committed by GitHub
commit bb228e0c68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 27 additions and 16 deletions

View File

@ -34,6 +34,7 @@
- Replace `MST.Symbolic` by `Symbol`, `Symbol` now implements MST
- Remove Any restriction on polynomials
- Add `out` variance to type parameters of `StructureND` and its implementations where possible
- Rename `DifferentiableMstExpression` to `KotlingradExpression`
### Deprecated

View File

@ -11,6 +11,7 @@ allprojects {
isAllowInsecureProtocol = true
}
maven("https://maven.pkg.jetbrains.space/public/p/kotlinx-html/maven")
maven("https://oss.sonatype.org/content/repositories/snapshots")
mavenCentral()
}

View File

@ -3,14 +3,16 @@ plugins {
id("ru.mipt.npm.gradle.common")
}
description = "Kotlin∇ integration module"
dependencies {
api("com.github.breandan:kaliningraph:0.1.6")
api("com.github.breandan:kotlingrad:0.4.5")
api(project(":kmath-ast"))
api(project(":kmath-core"))
testImplementation(project(":kmath-ast"))
}
readme {
description = "Functions, integration and interpolation"
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))

View File

@ -11,24 +11,24 @@ import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.NumericAlgebra
/**
* Represents [MST] based [DifferentiableExpression].
* Represents [MST] based [DifferentiableExpression] relying on [Kotlin](https://github.com/breandan/kotlingrad).
*
* The principle of this API is converting the [mst] to an [SFun], differentiating it with
* [Kotlin](https://github.com/breandan/kotlingrad), then converting [SFun] back to [MST].
* The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin, then converting
* [SFun] back to [MST].
*
* @param T The type of number.
* @param A The [NumericAlgebra] of [T].
* @property algebra The [A] instance.
* @property mst The [MST] node.
*/
public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>(
public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>(
public val algebra: A,
public val mst: MST,
) : DifferentiableExpression<T, DifferentiableMstExpression<T, A>> {
) : DifferentiableExpression<T, KotlingradExpression<T, A>> {
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
public override fun derivativeOrNull(symbols: List<Symbol>): DifferentiableMstExpression<T, A> =
DifferentiableMstExpression(
public override fun derivativeOrNull(symbols: List<Symbol>): KotlingradExpression<T, A> =
KotlingradExpression(
algebra,
symbols.map(Symbol::identity)
.map(MstNumericAlgebra::bindSymbol)
@ -39,7 +39,7 @@ public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>(
}
/**
* Wraps this [MST] into [DifferentiableMstExpression].
* Wraps this [MST] into [KotlingradExpression].
*/
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): DifferentiableMstExpression<T, A> =
DifferentiableMstExpression(algebra, this)
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): KotlingradExpression<T, A> =
KotlingradExpression(algebra, this)

View File

@ -14,7 +14,7 @@ import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.*
/**
* Maps [SVar] to [MST.Symbolic] directly.
* Maps [SVar] to [Symbol] directly.
*
* @receiver The variable.
* @returnAa node.
@ -81,7 +81,7 @@ public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField {
public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
/**
* Maps [MST.Symbolic] to [SVar] directly.
* Maps [Symbol] to [SVar] directly.
*
* @receiver The node.
* @return A new variable.
@ -94,7 +94,7 @@ internal fun <X : SFun<X>> Symbol.toSVar(): SVar<X> = SVar(identity)
* Detailed mapping is:
*
* - [MST.Numeric] -> [SConst];
* - [MST.Symbolic] -> [SVar];
* - [Symbol] -> [SVar];
* - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log];
* - [MST.Binary] -> [Sum], [Prod], [Power].
*
@ -114,6 +114,12 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun())
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
ExponentialOperations.SINH_OPERATION -> MstExtendedField { (exp(value) - exp(-value)) / 2.0 }.toSFun()
ExponentialOperations.COSH_OPERATION -> MstExtendedField { (exp(value) + exp(-value)) / 2.0 }.toSFun()
ExponentialOperations.TANH_OPERATION -> MstExtendedField { (exp(value) - exp(-value)) / (exp(-value) + exp(value)) }.toSFun()
ExponentialOperations.ASINH_OPERATION -> MstExtendedField { ln(sqrt(value * value + one) + value) }.toSFun()
ExponentialOperations.ACOSH_OPERATION -> MstExtendedField { ln(value + sqrt((value - one) * (value + one))) }.toSFun()
ExponentialOperations.ATANH_OPERATION -> MstExtendedField { (ln(value + one) - ln(one - value)) / 2.0 }.toSFun()
else -> error("Unary operation $operation not defined in $this")
}

View File

@ -3,6 +3,8 @@ plugins {
id("ru.mipt.npm.gradle.common")
}
description = "ND4J NDStructure implementation and according NDAlgebra classes"
dependencies {
api(project(":kmath-tensors"))
api("org.nd4j:nd4j-api:1.0.0-beta7")
@ -12,7 +14,6 @@ dependencies {
}
readme {
description = "ND4J NDStructure implementation and according NDAlgebra classes"
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
feature(id = "nd4jarraystructure") { "NDStructure wrapper for INDArray" }