Merge pull request #350 from mipt-npm/commandertvis/kotlingrad

Rename `DifferentiableMstExpression`, fix #349
This commit is contained in:
Alexander Nozik 2021-05-21 09:42:44 +03:00 committed by GitHub
commit bb228e0c68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 27 additions and 16 deletions

View File

@ -34,6 +34,7 @@
- Replace `MST.Symbolic` by `Symbol`, `Symbol` now implements MST - Replace `MST.Symbolic` by `Symbol`, `Symbol` now implements MST
- Remove Any restriction on polynomials - Remove Any restriction on polynomials
- Add `out` variance to type parameters of `StructureND` and its implementations where possible - Add `out` variance to type parameters of `StructureND` and its implementations where possible
- Rename `DifferentiableMstExpression` to `KotlingradExpression`
### Deprecated ### Deprecated

View File

@ -11,6 +11,7 @@ allprojects {
isAllowInsecureProtocol = true isAllowInsecureProtocol = true
} }
maven("https://maven.pkg.jetbrains.space/public/p/kotlinx-html/maven") maven("https://maven.pkg.jetbrains.space/public/p/kotlinx-html/maven")
maven("https://oss.sonatype.org/content/repositories/snapshots")
mavenCentral() mavenCentral()
} }

View File

@ -3,14 +3,16 @@ plugins {
id("ru.mipt.npm.gradle.common") id("ru.mipt.npm.gradle.common")
} }
description = "Kotlin∇ integration module"
dependencies { dependencies {
api("com.github.breandan:kaliningraph:0.1.6") api("com.github.breandan:kaliningraph:0.1.6")
api("com.github.breandan:kotlingrad:0.4.5") api("com.github.breandan:kotlingrad:0.4.5")
api(project(":kmath-ast")) api(project(":kmath-core"))
testImplementation(project(":kmath-ast"))
} }
readme { readme {
description = "Functions, integration and interpolation"
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))

View File

@ -11,24 +11,24 @@ import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.NumericAlgebra import space.kscience.kmath.operations.NumericAlgebra
/** /**
* Represents [MST] based [DifferentiableExpression]. * Represents [MST] based [DifferentiableExpression] relying on [Kotlin](https://github.com/breandan/kotlingrad).
* *
* The principle of this API is converting the [mst] to an [SFun], differentiating it with * The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin, then converting
* [Kotlin](https://github.com/breandan/kotlingrad), then converting [SFun] back to [MST]. * [SFun] back to [MST].
* *
* @param T The type of number. * @param T The type of number.
* @param A The [NumericAlgebra] of [T]. * @param A The [NumericAlgebra] of [T].
* @property algebra The [A] instance. * @property algebra The [A] instance.
* @property mst The [MST] node. * @property mst The [MST] node.
*/ */
public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>( public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>(
public val algebra: A, public val algebra: A,
public val mst: MST, public val mst: MST,
) : DifferentiableExpression<T, DifferentiableMstExpression<T, A>> { ) : DifferentiableExpression<T, KotlingradExpression<T, A>> {
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments) public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
public override fun derivativeOrNull(symbols: List<Symbol>): DifferentiableMstExpression<T, A> = public override fun derivativeOrNull(symbols: List<Symbol>): KotlingradExpression<T, A> =
DifferentiableMstExpression( KotlingradExpression(
algebra, algebra,
symbols.map(Symbol::identity) symbols.map(Symbol::identity)
.map(MstNumericAlgebra::bindSymbol) .map(MstNumericAlgebra::bindSymbol)
@ -39,7 +39,7 @@ public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>(
} }
/** /**
* Wraps this [MST] into [DifferentiableMstExpression]. * Wraps this [MST] into [KotlingradExpression].
*/ */
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): DifferentiableMstExpression<T, A> = public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): KotlingradExpression<T, A> =
DifferentiableMstExpression(algebra, this) KotlingradExpression(algebra, this)

View File

@ -14,7 +14,7 @@ import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
/** /**
* Maps [SVar] to [MST.Symbolic] directly. * Maps [SVar] to [Symbol] directly.
* *
* @receiver The variable. * @receiver The variable.
* @returnAa node. * @returnAa node.
@ -81,7 +81,7 @@ public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField {
public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value) public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
/** /**
* Maps [MST.Symbolic] to [SVar] directly. * Maps [Symbol] to [SVar] directly.
* *
* @receiver The node. * @receiver The node.
* @return A new variable. * @return A new variable.
@ -94,7 +94,7 @@ internal fun <X : SFun<X>> Symbol.toSVar(): SVar<X> = SVar(identity)
* Detailed mapping is: * Detailed mapping is:
* *
* - [MST.Numeric] -> [SConst]; * - [MST.Numeric] -> [SConst];
* - [MST.Symbolic] -> [SVar]; * - [Symbol] -> [SVar];
* - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log]; * - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log];
* - [MST.Binary] -> [Sum], [Prod], [Power]. * - [MST.Binary] -> [Sum], [Prod], [Power].
* *
@ -114,6 +114,12 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun()) PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun())
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun()) ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln() ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
ExponentialOperations.SINH_OPERATION -> MstExtendedField { (exp(value) - exp(-value)) / 2.0 }.toSFun()
ExponentialOperations.COSH_OPERATION -> MstExtendedField { (exp(value) + exp(-value)) / 2.0 }.toSFun()
ExponentialOperations.TANH_OPERATION -> MstExtendedField { (exp(value) - exp(-value)) / (exp(-value) + exp(value)) }.toSFun()
ExponentialOperations.ASINH_OPERATION -> MstExtendedField { ln(sqrt(value * value + one) + value) }.toSFun()
ExponentialOperations.ACOSH_OPERATION -> MstExtendedField { ln(value + sqrt((value - one) * (value + one))) }.toSFun()
ExponentialOperations.ATANH_OPERATION -> MstExtendedField { (ln(value + one) - ln(one - value)) / 2.0 }.toSFun()
else -> error("Unary operation $operation not defined in $this") else -> error("Unary operation $operation not defined in $this")
} }

View File

@ -3,6 +3,8 @@ plugins {
id("ru.mipt.npm.gradle.common") id("ru.mipt.npm.gradle.common")
} }
description = "ND4J NDStructure implementation and according NDAlgebra classes"
dependencies { dependencies {
api(project(":kmath-tensors")) api(project(":kmath-tensors"))
api("org.nd4j:nd4j-api:1.0.0-beta7") api("org.nd4j:nd4j-api:1.0.0-beta7")
@ -12,7 +14,6 @@ dependencies {
} }
readme { readme {
description = "ND4J NDStructure implementation and according NDAlgebra classes"
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
feature(id = "nd4jarraystructure") { "NDStructure wrapper for INDArray" } feature(id = "nd4jarraystructure") { "NDStructure wrapper for INDArray" }