forked from kscience/kmath
Rename DifferentiableMstExpression, fix #349
This commit is contained in:
parent
958788bc91
commit
485e90fcd0
@ -34,6 +34,7 @@
|
|||||||
- Replace `MST.Symbolic` by `Symbol`, `Symbol` now implements MST
|
- Replace `MST.Symbolic` by `Symbol`, `Symbol` now implements MST
|
||||||
- Remove Any restriction on polynomials
|
- Remove Any restriction on polynomials
|
||||||
- Add `out` variance to type parameters of `StructureND` and its implementations where possible
|
- Add `out` variance to type parameters of `StructureND` and its implementations where possible
|
||||||
|
- Rename `DifferentiableMstExpression` to `KotlingradExpression`
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ allprojects {
|
|||||||
isAllowInsecureProtocol = true
|
isAllowInsecureProtocol = true
|
||||||
}
|
}
|
||||||
maven("https://maven.pkg.jetbrains.space/public/p/kotlinx-html/maven")
|
maven("https://maven.pkg.jetbrains.space/public/p/kotlinx-html/maven")
|
||||||
|
maven("https://oss.sonatype.org/content/repositories/snapshots")
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,14 +3,16 @@ plugins {
|
|||||||
id("ru.mipt.npm.gradle.common")
|
id("ru.mipt.npm.gradle.common")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
description = "Kotlin∇ integration module"
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api("com.github.breandan:kaliningraph:0.1.6")
|
api("com.github.breandan:kaliningraph:0.1.6")
|
||||||
api("com.github.breandan:kotlingrad:0.4.5")
|
api("com.github.breandan:kotlingrad:0.4.5")
|
||||||
api(project(":kmath-ast"))
|
api(project(":kmath-core"))
|
||||||
|
testImplementation(project(":kmath-ast"))
|
||||||
}
|
}
|
||||||
|
|
||||||
readme {
|
readme {
|
||||||
description = "Functions, integration and interpolation"
|
|
||||||
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
|
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
|
||||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||||
|
|
||||||
|
@ -11,24 +11,24 @@ import space.kscience.kmath.expressions.*
|
|||||||
import space.kscience.kmath.operations.NumericAlgebra
|
import space.kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents [MST] based [DifferentiableExpression].
|
* Represents [MST] based [DifferentiableExpression] relying on [Kotlin∇](https://github.com/breandan/kotlingrad).
|
||||||
*
|
*
|
||||||
* The principle of this API is converting the [mst] to an [SFun], differentiating it with
|
* The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin∇, then converting
|
||||||
* [Kotlin∇](https://github.com/breandan/kotlingrad), then converting [SFun] back to [MST].
|
* [SFun] back to [MST].
|
||||||
*
|
*
|
||||||
* @param T The type of number.
|
* @param T The type of number.
|
||||||
* @param A The [NumericAlgebra] of [T].
|
* @param A The [NumericAlgebra] of [T].
|
||||||
* @property algebra The [A] instance.
|
* @property algebra The [A] instance.
|
||||||
* @property mst The [MST] node.
|
* @property mst The [MST] node.
|
||||||
*/
|
*/
|
||||||
public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>(
|
public class KotlingradExpression<T : Number, A : NumericAlgebra<T>>(
|
||||||
public val algebra: A,
|
public val algebra: A,
|
||||||
public val mst: MST,
|
public val mst: MST,
|
||||||
) : DifferentiableExpression<T, DifferentiableMstExpression<T, A>> {
|
) : DifferentiableExpression<T, KotlingradExpression<T, A>> {
|
||||||
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
|
public override fun invoke(arguments: Map<Symbol, T>): T = mst.interpret(algebra, arguments)
|
||||||
|
|
||||||
public override fun derivativeOrNull(symbols: List<Symbol>): DifferentiableMstExpression<T, A> =
|
public override fun derivativeOrNull(symbols: List<Symbol>): KotlingradExpression<T, A> =
|
||||||
DifferentiableMstExpression(
|
KotlingradExpression(
|
||||||
algebra,
|
algebra,
|
||||||
symbols.map(Symbol::identity)
|
symbols.map(Symbol::identity)
|
||||||
.map(MstNumericAlgebra::bindSymbol)
|
.map(MstNumericAlgebra::bindSymbol)
|
||||||
@ -39,7 +39,7 @@ public class DifferentiableMstExpression<T : Number, A : NumericAlgebra<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wraps this [MST] into [DifferentiableMstExpression].
|
* Wraps this [MST] into [KotlingradExpression].
|
||||||
*/
|
*/
|
||||||
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): DifferentiableMstExpression<T, A> =
|
public fun <T : Number, A : NumericAlgebra<T>> MST.toDiffExpression(algebra: A): KotlingradExpression<T, A> =
|
||||||
DifferentiableMstExpression(algebra, this)
|
KotlingradExpression(algebra, this)
|
@ -14,7 +14,7 @@ import space.kscience.kmath.expressions.Symbol
|
|||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps [SVar] to [MST.Symbolic] directly.
|
* Maps [SVar] to [Symbol] directly.
|
||||||
*
|
*
|
||||||
* @receiver The variable.
|
* @receiver The variable.
|
||||||
* @returnAa node.
|
* @returnAa node.
|
||||||
@ -81,7 +81,7 @@ public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField {
|
|||||||
public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
|
public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps [MST.Symbolic] to [SVar] directly.
|
* Maps [Symbol] to [SVar] directly.
|
||||||
*
|
*
|
||||||
* @receiver The node.
|
* @receiver The node.
|
||||||
* @return A new variable.
|
* @return A new variable.
|
||||||
@ -94,7 +94,7 @@ internal fun <X : SFun<X>> Symbol.toSVar(): SVar<X> = SVar(identity)
|
|||||||
* Detailed mapping is:
|
* Detailed mapping is:
|
||||||
*
|
*
|
||||||
* - [MST.Numeric] -> [SConst];
|
* - [MST.Numeric] -> [SConst];
|
||||||
* - [MST.Symbolic] -> [SVar];
|
* - [Symbol] -> [SVar];
|
||||||
* - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log];
|
* - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log];
|
||||||
* - [MST.Binary] -> [Sum], [Prod], [Power].
|
* - [MST.Binary] -> [Sum], [Prod], [Power].
|
||||||
*
|
*
|
||||||
@ -114,6 +114,12 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
|||||||
PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun())
|
PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun())
|
||||||
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
|
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
|
||||||
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
|
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
|
||||||
|
ExponentialOperations.SINH_OPERATION -> MstExtendedField { (exp(value) - exp(-value)) / 2.0 }.toSFun()
|
||||||
|
ExponentialOperations.COSH_OPERATION -> MstExtendedField { (exp(value) + exp(-value)) / 2.0 }.toSFun()
|
||||||
|
ExponentialOperations.TANH_OPERATION -> MstExtendedField { (exp(value) - exp(-value)) / (exp(-value) + exp(value)) }.toSFun()
|
||||||
|
ExponentialOperations.ASINH_OPERATION -> MstExtendedField { ln(sqrt(value * value + one) + value) }.toSFun()
|
||||||
|
ExponentialOperations.ACOSH_OPERATION -> MstExtendedField { ln(value + sqrt((value - one) * (value + one))) }.toSFun()
|
||||||
|
ExponentialOperations.ATANH_OPERATION -> MstExtendedField { (ln(value + one) - ln(one - value)) / 2.0 }.toSFun()
|
||||||
else -> error("Unary operation $operation not defined in $this")
|
else -> error("Unary operation $operation not defined in $this")
|
||||||
}
|
}
|
||||||
|
|
@ -3,6 +3,8 @@ plugins {
|
|||||||
id("ru.mipt.npm.gradle.common")
|
id("ru.mipt.npm.gradle.common")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
description = "ND4J NDStructure implementation and according NDAlgebra classes"
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-tensors"))
|
api(project(":kmath-tensors"))
|
||||||
api("org.nd4j:nd4j-api:1.0.0-beta7")
|
api("org.nd4j:nd4j-api:1.0.0-beta7")
|
||||||
@ -12,7 +14,6 @@ dependencies {
|
|||||||
}
|
}
|
||||||
|
|
||||||
readme {
|
readme {
|
||||||
description = "ND4J NDStructure implementation and according NDAlgebra classes"
|
|
||||||
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
|
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
|
||||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||||
feature(id = "nd4jarraystructure") { "NDStructure wrapper for INDArray" }
|
feature(id = "nd4jarraystructure") { "NDStructure wrapper for INDArray" }
|
||||||
|
Loading…
Reference in New Issue
Block a user