forked from kscience/kmath
Merge pull request #150 from mipt-npm/kotlingrad
Add adapters of scalar functions to MST and vice versa
This commit is contained in:
commit
abe68a4fb6
16
README.md
16
README.md
@ -211,7 +211,15 @@ Release artifacts are accessible from bintray with following configuration (see
|
|||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
repositories {
|
repositories {
|
||||||
|
jcenter()
|
||||||
|
maven("https://clojars.org/repo")
|
||||||
|
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||||
|
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
|
maven("https://jitpack.io")
|
||||||
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
@ -228,7 +236,15 @@ Development builds are uploaded to the separate repository:
|
|||||||
|
|
||||||
```kotlin
|
```kotlin
|
||||||
repositories {
|
repositories {
|
||||||
|
jcenter()
|
||||||
|
maven("https://clojars.org/repo")
|
||||||
|
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||||
|
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
|
maven("https://jitpack.io")
|
||||||
|
mavenCentral()
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import ru.mipt.npm.gradle.KSciencePublishPlugin
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("ru.mipt.npm.project")
|
id("ru.mipt.npm.project")
|
||||||
}
|
}
|
||||||
@ -9,9 +11,16 @@ internal val githubProject: String by extra("kmath")
|
|||||||
allprojects {
|
allprojects {
|
||||||
repositories {
|
repositories {
|
||||||
jcenter()
|
jcenter()
|
||||||
|
maven("https://clojars.org/repo")
|
||||||
|
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||||
|
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
|
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
|
maven("https://jitpack.io")
|
||||||
|
maven("http://logicrunch.research.it.uu.se/maven/")
|
||||||
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
group = "kscience.kmath"
|
group = "kscience.kmath"
|
||||||
@ -19,7 +28,7 @@ allprojects {
|
|||||||
}
|
}
|
||||||
|
|
||||||
subprojects {
|
subprojects {
|
||||||
if (name.startsWith("kmath")) apply<ru.mipt.npm.gradle.KSciencePublishPlugin>()
|
if (name.startsWith("kmath")) apply<KSciencePublishPlugin>()
|
||||||
}
|
}
|
||||||
|
|
||||||
readme {
|
readme {
|
||||||
|
@ -8,18 +8,25 @@ plugins {
|
|||||||
}
|
}
|
||||||
|
|
||||||
allOpen.annotation("org.openjdk.jmh.annotations.State")
|
allOpen.annotation("org.openjdk.jmh.annotations.State")
|
||||||
|
sourceSets.register("benchmarks")
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
jcenter()
|
||||||
|
maven("https://clojars.org/repo")
|
||||||
|
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||||
|
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
|
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
|
maven("https://jitpack.io")
|
||||||
|
maven("http://logicrunch.research.it.uu.se/maven/")
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
sourceSets.register("benchmarks")
|
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation(project(":kmath-ast"))
|
implementation(project(":kmath-ast"))
|
||||||
|
implementation(project(":kmath-kotlingrad"))
|
||||||
implementation(project(":kmath-core"))
|
implementation(project(":kmath-core"))
|
||||||
implementation(project(":kmath-coroutines"))
|
implementation(project(":kmath-coroutines"))
|
||||||
implementation(project(":kmath-commons"))
|
implementation(project(":kmath-commons"))
|
||||||
|
@ -9,7 +9,7 @@ import kscience.kmath.operations.RealField
|
|||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
class ExpressionsInterpretersBenchmark {
|
internal class ExpressionsInterpretersBenchmark {
|
||||||
private val algebra: Field<Double> = RealField
|
private val algebra: Field<Double> = RealField
|
||||||
fun functionalExpression() {
|
fun functionalExpression() {
|
||||||
val expr = algebra.expressionInField {
|
val expr = algebra.expressionInField {
|
||||||
@ -47,6 +47,16 @@ class ExpressionsInterpretersBenchmark {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and
|
||||||
|
* core FunctionalExpressions API.
|
||||||
|
*
|
||||||
|
* The expected rating is:
|
||||||
|
*
|
||||||
|
* 1. ASM.
|
||||||
|
* 2. MST.
|
||||||
|
* 3. FE.
|
||||||
|
*/
|
||||||
fun main() {
|
fun main() {
|
||||||
val benchmark = ExpressionsInterpretersBenchmark()
|
val benchmark = ExpressionsInterpretersBenchmark()
|
||||||
|
|
||||||
|
@ -0,0 +1,24 @@
|
|||||||
|
package kscience.kmath.ast
|
||||||
|
|
||||||
|
import kscience.kmath.asm.compile
|
||||||
|
import kscience.kmath.expressions.derivative
|
||||||
|
import kscience.kmath.expressions.invoke
|
||||||
|
import kscience.kmath.expressions.symbol
|
||||||
|
import kscience.kmath.kotlingrad.differentiable
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
|
||||||
|
/**
|
||||||
|
* In this example, x^2-4*x-44 function is differentiated with Kotlin∇, and the autodiff result is compared with
|
||||||
|
* valid derivative.
|
||||||
|
*/
|
||||||
|
fun main() {
|
||||||
|
val x by symbol
|
||||||
|
|
||||||
|
val actualDerivative = MstExpression(RealField, "x^2-4*x-44".parseMath())
|
||||||
|
.differentiable()
|
||||||
|
.derivative(x)
|
||||||
|
.compile()
|
||||||
|
|
||||||
|
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||||
|
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
||||||
|
}
|
@ -6,14 +6,14 @@ import kscience.kmath.operations.*
|
|||||||
* [Algebra] over [MST] nodes.
|
* [Algebra] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstAlgebra : NumericAlgebra<MST> {
|
public object MstAlgebra : NumericAlgebra<MST> {
|
||||||
override fun number(value: Number): MST = MST.Numeric(value)
|
override fun number(value: Number): MST.Numeric = MST.Numeric(value)
|
||||||
|
|
||||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST =
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary =
|
||||||
MST.Unary(operation, arg)
|
MST.Unary(operation, arg)
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
MST.Binary(operation, left, right)
|
MST.Binary(operation, left, right)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,97 +21,100 @@ public object MstAlgebra : NumericAlgebra<MST> {
|
|||||||
* [Space] over [MST] nodes.
|
* [Space] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||||
override val zero: MST = number(0.0)
|
override val zero: MST.Numeric by lazy { number(0.0) }
|
||||||
|
|
||||||
override fun number(value: Number): MST = MstAlgebra.number(value)
|
override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
|
||||||
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
|
||||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||||
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
MstAlgebra.binaryOperation(operation, left, right)
|
MstAlgebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [Ring] over [MST] nodes.
|
* [Ring] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||||
override val zero: MST
|
override val zero: MST.Numeric
|
||||||
get() = MstSpace.zero
|
get() = MstSpace.zero
|
||||||
override val one: MST = number(1.0)
|
|
||||||
|
|
||||||
override fun number(value: Number): MST = MstSpace.number(value)
|
override val one: MST.Numeric by lazy { number(1.0) }
|
||||||
override fun symbol(value: String): MST = MstSpace.symbol(value)
|
|
||||||
override fun add(a: MST, b: MST): MST = MstSpace.add(a, b)
|
|
||||||
|
|
||||||
override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k)
|
override fun number(value: Number): MST.Numeric = MstSpace.number(value)
|
||||||
|
override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
|
||||||
|
override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
|
||||||
|
override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
|
||||||
|
override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
|
||||||
MstSpace.binaryOperation(operation, left, right)
|
MstSpace.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [Field] over [MST] nodes.
|
* [Field] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstField : Field<MST> {
|
public object MstField : Field<MST> {
|
||||||
public override val zero: MST
|
public override val zero: MST.Numeric
|
||||||
get() = MstRing.zero
|
get() = MstRing.zero
|
||||||
|
|
||||||
public override val one: MST
|
public override val one: MST.Numeric
|
||||||
get() = MstRing.one
|
get() = MstRing.one
|
||||||
|
|
||||||
public override fun symbol(value: String): MST = MstRing.symbol(value)
|
public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value)
|
||||||
public override fun number(value: Number): MST = MstRing.number(value)
|
public override fun number(value: Number): MST.Numeric = MstRing.number(value)
|
||||||
public override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
|
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
||||||
public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
|
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
|
||||||
public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
|
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
||||||
public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||||
|
|
||||||
public override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
MstRing.binaryOperation(operation, left, right)
|
MstRing.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [ExtendedField] over [MST] nodes.
|
* [ExtendedField] over [MST] nodes.
|
||||||
*/
|
*/
|
||||||
public object MstExtendedField : ExtendedField<MST> {
|
public object MstExtendedField : ExtendedField<MST> {
|
||||||
override val zero: MST
|
override val zero: MST.Numeric
|
||||||
get() = MstField.zero
|
get() = MstField.zero
|
||||||
|
|
||||||
override val one: MST
|
override val one: MST.Numeric
|
||||||
get() = MstField.one
|
get() = MstField.one
|
||||||
|
|
||||||
override fun symbol(value: String): MST = MstField.symbol(value)
|
override fun symbol(value: String): MST.Symbolic = MstField.symbol(value)
|
||||||
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
override fun number(value: Number): MST.Numeric = MstField.number(value)
|
||||||
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||||
override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
|
override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||||
override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
|
||||||
override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
||||||
override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
||||||
override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
|
override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
||||||
override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
|
override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
|
||||||
override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
|
override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
|
||||||
override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
|
override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
|
||||||
override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
|
override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
|
||||||
override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
|
override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
|
||||||
override fun add(a: MST, b: MST): MST = MstField.add(a, b)
|
override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
|
||||||
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
|
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
|
||||||
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)
|
override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
|
||||||
override fun divide(a: MST, b: MST): MST = MstField.divide(a, b)
|
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
||||||
override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
||||||
override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
|
||||||
override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
|
||||||
|
|
||||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
override fun power(arg: MST, pow: Number): MST.Binary =
|
||||||
|
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||||
|
|
||||||
|
override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||||
|
override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||||
|
|
||||||
|
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||||
MstField.binaryOperation(operation, left, right)
|
MstField.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg)
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,7 @@ import kotlin.contracts.contract
|
|||||||
* @property mst the [MST] node.
|
* @property mst the [MST] node.
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
|
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
|
||||||
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||||
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
||||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||||
@ -21,8 +21,9 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
|||||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||||
algebra.binaryOperation(operation, left, right)
|
algebra.binaryOperation(operation, left, right)
|
||||||
|
|
||||||
override fun number(value: Number): T = if (algebra is NumericAlgebra)
|
@Suppress("UNCHECKED_CAST")
|
||||||
algebra.number(value)
|
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
||||||
|
(algebra as NumericAlgebra<T>).number(value)
|
||||||
else
|
else
|
||||||
error("Numeric nodes are not supported by $this")
|
error("Numeric nodes are not supported by $this")
|
||||||
}
|
}
|
||||||
@ -38,14 +39,14 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
|||||||
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||||
mstAlgebra: E,
|
mstAlgebra: E,
|
||||||
block: E.() -> MST,
|
block: E.() -> MST,
|
||||||
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
): MstExpression<T, A> = MstExpression(this, mstAlgebra.block())
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds [MstExpression] over [Space].
|
* Builds [MstExpression] over [Space].
|
||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : Space<T>> A.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return MstExpression(this, MstSpace.block())
|
return MstExpression(this, MstSpace.block())
|
||||||
}
|
}
|
||||||
@ -55,7 +56,7 @@ public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MS
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : Ring<T>> A.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return MstExpression(this, MstRing.block())
|
return MstExpression(this, MstRing.block())
|
||||||
}
|
}
|
||||||
@ -65,7 +66,7 @@ public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST):
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : Field<T>> A.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return MstExpression(this, MstField.block())
|
return MstExpression(this, MstField.block())
|
||||||
}
|
}
|
||||||
@ -75,7 +76,7 @@ public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MS
|
|||||||
*
|
*
|
||||||
* @author Iaroslav Postovalov
|
* @author Iaroslav Postovalov
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : ExtendedField<T>> A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return MstExpression(this, MstExtendedField.block())
|
return MstExpression(this, MstExtendedField.block())
|
||||||
}
|
}
|
||||||
@ -85,7 +86,7 @@ public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtend
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return algebra.mstInSpace(block)
|
return algebra.mstInSpace(block)
|
||||||
}
|
}
|
||||||
@ -95,7 +96,7 @@ public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return algebra.mstInRing(block)
|
return algebra.mstInRing(block)
|
||||||
}
|
}
|
||||||
@ -105,7 +106,7 @@ public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik
|
* @author Alexander Nozik
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> {
|
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return algebra.mstInField(block)
|
return algebra.mstInField(block)
|
||||||
}
|
}
|
||||||
@ -117,7 +118,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
|
|||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
||||||
block: MstExtendedField.() -> MST,
|
block: MstExtendedField.() -> MST,
|
||||||
): MstExpression<T> {
|
): MstExpression<T, A> {
|
||||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||||
return algebra.mstInExtendedField(block)
|
return algebra.mstInExtendedField(block)
|
||||||
}
|
}
|
||||||
|
@ -69,4 +69,5 @@ public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<
|
|||||||
*
|
*
|
||||||
* @author Alexander Nozik.
|
* @author Alexander Nozik.
|
||||||
*/
|
*/
|
||||||
public inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class.java, algebra)
|
public inline fun <reified T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
|
||||||
|
mst.compileWith(T::class.java, algebra)
|
||||||
|
@ -95,10 +95,10 @@ public class DerivativeStructureField(
|
|||||||
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
||||||
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
||||||
|
|
||||||
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
|
public companion object :
|
||||||
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> {
|
AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField, Expression<Double>> {
|
||||||
return DerivativeStructureExpression(function)
|
public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double, Expression<Double>> =
|
||||||
}
|
DerivativeStructureExpression(function)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ public class DerivativeStructureField(
|
|||||||
*/
|
*/
|
||||||
public class DerivativeStructureExpression(
|
public class DerivativeStructureExpression(
|
||||||
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
||||||
) : DifferentiableExpression<Double> {
|
) : DifferentiableExpression<Double, Expression<Double>> {
|
||||||
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||||
DerivativeStructureField(0, arguments).function().value
|
DerivativeStructureField(0, arguments).function().value
|
||||||
|
|
||||||
|
@ -19,9 +19,8 @@ import kotlin.reflect.KClass
|
|||||||
public operator fun PointValuePair.component1(): DoubleArray = point
|
public operator fun PointValuePair.component1(): DoubleArray = point
|
||||||
public operator fun PointValuePair.component2(): Double = value
|
public operator fun PointValuePair.component2(): Double = value
|
||||||
|
|
||||||
public class CMOptimizationProblem(
|
public class CMOptimizationProblem(override val symbols: List<Symbol>, ) :
|
||||||
override val symbols: List<Symbol>,
|
OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
|
||||||
) : OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
|
|
||||||
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
||||||
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
|
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
|
||||||
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
|
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
|
||||||
@ -49,7 +48,7 @@ public class CMOptimizationProblem(
|
|||||||
addOptimizationData(objectiveFunction)
|
addOptimizationData(objectiveFunction)
|
||||||
}
|
}
|
||||||
|
|
||||||
public override fun diffExpression(expression: DifferentiableExpression<Double>): Unit {
|
public override fun diffExpression(expression: DifferentiableExpression<Double, Expression<Double>>) {
|
||||||
expression(expression)
|
expression(expression)
|
||||||
val gradientFunction = ObjectiveFunctionGradient {
|
val gradientFunction = ObjectiveFunctionGradient {
|
||||||
val args = it.toMap()
|
val args = it.toMap()
|
||||||
|
@ -12,7 +12,6 @@ import kscience.kmath.structures.asBuffer
|
|||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
*/
|
*/
|
||||||
@ -21,7 +20,7 @@ public fun Fitting.chiSquared(
|
|||||||
y: Buffer<Double>,
|
y: Buffer<Double>,
|
||||||
yErr: Buffer<Double>,
|
yErr: Buffer<Double>,
|
||||||
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
||||||
): DifferentiableExpression<Double> = chiSquared(DerivativeStructureField, x, y, yErr, model)
|
): DifferentiableExpression<Double, Expression<Double>> = chiSquared(DerivativeStructureField, x, y, yErr, model)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
@ -31,7 +30,7 @@ public fun Fitting.chiSquared(
|
|||||||
y: Iterable<Double>,
|
y: Iterable<Double>,
|
||||||
yErr: Iterable<Double>,
|
yErr: Iterable<Double>,
|
||||||
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
||||||
): DifferentiableExpression<Double> = chiSquared(
|
): DifferentiableExpression<Double, Expression<Double>> = chiSquared(
|
||||||
DerivativeStructureField,
|
DerivativeStructureField,
|
||||||
x.toList().asBuffer(),
|
x.toList().asBuffer(),
|
||||||
y.toList().asBuffer(),
|
y.toList().asBuffer(),
|
||||||
@ -39,7 +38,6 @@ public fun Fitting.chiSquared(
|
|||||||
model
|
model
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize expression without derivatives
|
* Optimize expression without derivatives
|
||||||
*/
|
*/
|
||||||
@ -48,16 +46,15 @@ public fun Expression<Double>.optimize(
|
|||||||
configuration: CMOptimizationProblem.() -> Unit,
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize differentiable expression
|
* Optimize differentiable expression
|
||||||
*/
|
*/
|
||||||
public fun DifferentiableExpression<Double>.optimize(
|
public fun DifferentiableExpression<Double, Expression<Double>>.optimize(
|
||||||
vararg symbols: Symbol,
|
vararg symbols: Symbol,
|
||||||
configuration: CMOptimizationProblem.() -> Unit,
|
configuration: CMOptimizationProblem.() -> Unit,
|
||||||
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||||
|
|
||||||
public fun DifferentiableExpression<Double>.minimize(
|
public fun DifferentiableExpression<Double, Expression<Double>>.minimize(
|
||||||
vararg startPoint: Pair<Symbol, Double>,
|
vararg startPoint: Pair<Symbol, Double>,
|
||||||
configuration: CMOptimizationProblem.() -> Unit = {},
|
configuration: CMOptimizationProblem.() -> Unit = {},
|
||||||
): OptimizationResult<Double> {
|
): OptimizationResult<Double> {
|
||||||
|
@ -47,14 +47,17 @@ internal class OptimizeTest {
|
|||||||
val sigma = 1.0
|
val sigma = 1.0
|
||||||
val generator = Distribution.normal(0.0, sigma)
|
val generator = Distribution.normal(0.0, sigma)
|
||||||
val chain = generator.sample(RandomGenerator.default(112667))
|
val chain = generator.sample(RandomGenerator.default(112667))
|
||||||
val x = (1..100).map { it.toDouble() }
|
val x = (1..100).map(Int::toDouble)
|
||||||
val y = x.map { it ->
|
|
||||||
|
val y = x.map {
|
||||||
it.pow(2) + it + 1 + chain.nextDouble()
|
it.pow(2) + it + 1 + chain.nextDouble()
|
||||||
}
|
}
|
||||||
val yErr = x.map { sigma }
|
|
||||||
val chi2 = Fitting.chiSquared(x, y, yErr) { x ->
|
val yErr = List(x.size) { sigma }
|
||||||
|
|
||||||
|
val chi2 = Fitting.chiSquared(x, y, yErr) { x1 ->
|
||||||
val cWithDefault = bindOrNull(c) ?: one
|
val cWithDefault = bindOrNull(c) ?: one
|
||||||
bind(a) * x.pow(2) + bind(b) * x + cWithDefault
|
bind(a) * x1.pow(2) + bind(b) * x1 + cWithDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)
|
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)
|
||||||
|
@ -1,29 +1,40 @@
|
|||||||
package kscience.kmath.expressions
|
package kscience.kmath.expressions
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An expression that provides derivatives
|
* Represents expression which structure can be differentiated.
|
||||||
|
*
|
||||||
|
* @param T the type this expression takes as argument and returns.
|
||||||
|
* @param R the type of expression this expression can be differentiated to.
|
||||||
*/
|
*/
|
||||||
public interface DifferentiableExpression<T> : Expression<T> {
|
public interface DifferentiableExpression<T, out R : Expression<T>> : Expression<T> {
|
||||||
public fun derivativeOrNull(symbols: List<Symbol>): Expression<T>?
|
/**
|
||||||
|
* Differentiates this expression by ordered collection of [symbols].
|
||||||
|
*
|
||||||
|
* @param symbols the symbols.
|
||||||
|
* @return the derivative or `null`.
|
||||||
|
*/
|
||||||
|
public fun derivativeOrNull(symbols: List<Symbol>): R?
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(symbols: List<Symbol>): Expression<T> =
|
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(symbols: List<Symbol>): R =
|
||||||
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
|
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(vararg symbols: Symbol): Expression<T> =
|
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(vararg symbols: Symbol): R =
|
||||||
derivative(symbols.toList())
|
derivative(symbols.toList())
|
||||||
|
|
||||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(name: String): R =
|
||||||
derivative(StringSymbol(name))
|
derivative(StringSymbol(name))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A [DifferentiableExpression] that defines only first derivatives
|
* A [DifferentiableExpression] that defines only first derivatives
|
||||||
*/
|
*/
|
||||||
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
public abstract class FirstDerivativeExpression<T, R : Expression<T>> : DifferentiableExpression<T,R> {
|
||||||
|
/**
|
||||||
|
* Returns first derivative of this expression by given [symbol].
|
||||||
|
*/
|
||||||
|
public abstract fun derivativeOrNull(symbol: Symbol): R?
|
||||||
|
|
||||||
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
public final override fun derivativeOrNull(symbols: List<Symbol>): R? {
|
||||||
|
|
||||||
public override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? {
|
|
||||||
val dSymbol = symbols.firstOrNull() ?: return null
|
val dSymbol = symbols.firstOrNull() ?: return null
|
||||||
return derivativeOrNull(dSymbol)
|
return derivativeOrNull(dSymbol)
|
||||||
}
|
}
|
||||||
@ -32,6 +43,6 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
|
|||||||
/**
|
/**
|
||||||
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
||||||
*/
|
*/
|
||||||
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> {
|
public fun interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>, out R : Expression<T>> {
|
||||||
public fun process(function: A.() -> I): DifferentiableExpression<T>
|
public fun process(function: A.() -> I): DifferentiableExpression<T, R>
|
||||||
}
|
}
|
||||||
|
@ -22,7 +22,9 @@ public inline class StringSymbol(override val identity: String) : Symbol {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An elementary function that could be invoked on a map of arguments
|
* An elementary function that could be invoked on a map of arguments.
|
||||||
|
*
|
||||||
|
* @param T the type this expression takes as argument and returns.
|
||||||
*/
|
*/
|
||||||
public fun interface Expression<T> {
|
public fun interface Expression<T> {
|
||||||
/**
|
/**
|
||||||
|
@ -68,7 +68,7 @@ public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
|||||||
): DerivationResult<T> {
|
): DerivationResult<T> {
|
||||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
|
||||||
return SimpleAutoDiffField(this, bindings).derivate(body)
|
return SimpleAutoDiffField(this, bindings).differentiate(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||||
@ -83,12 +83,21 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
public val context: F,
|
public val context: F,
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
||||||
|
public override val zero: AutoDiffValue<T>
|
||||||
|
get() = const(context.zero)
|
||||||
|
|
||||||
|
public override val one: AutoDiffValue<T>
|
||||||
|
get() = const(context.one)
|
||||||
|
|
||||||
// this stack contains pairs of blocks and values to apply them to
|
// this stack contains pairs of blocks and values to apply them to
|
||||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||||
private var sp: Int = 0
|
private var sp: Int = 0
|
||||||
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
||||||
|
|
||||||
|
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
||||||
|
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
||||||
* with respect to this variable.
|
* with respect to this variable.
|
||||||
@ -106,11 +115,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
override fun hashCode(): Int = identity.hashCode()
|
override fun hashCode(): Int = identity.hashCode()
|
||||||
}
|
}
|
||||||
|
|
||||||
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
public override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
|
||||||
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
|
|
||||||
|
|
||||||
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
||||||
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
|
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
|
||||||
@ -119,7 +124,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
private fun runBackwardPass() {
|
private fun runBackwardPass() {
|
||||||
while (sp > 0) {
|
while (sp > 0) {
|
||||||
@ -129,9 +133,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
|
||||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
|
||||||
|
|
||||||
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -165,7 +166,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
internal fun differentiate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
||||||
val result = function()
|
val result = function()
|
||||||
result.d = context.one // computing derivative w.r.t result
|
result.d = context.one // computing derivative w.r.t result
|
||||||
runBackwardPass()
|
runBackwardPass()
|
||||||
@ -174,41 +175,41 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
|
|
||||||
// Overloads for Double constants
|
// Overloads for Double constants
|
||||||
|
|
||||||
override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { this@plus.toDouble() * one + b.value }) { z ->
|
derive(const { this@plus.toDouble() * one + b.value }) { z ->
|
||||||
b.d += z.d
|
b.d += z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
|
public override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
|
||||||
|
|
||||||
override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
||||||
|
|
||||||
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
public override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
||||||
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||||
|
|
||||||
|
|
||||||
// Basic math (+, -, *, /)
|
// Basic math (+, -, *, /)
|
||||||
|
|
||||||
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { a.value + b.value }) { z ->
|
derive(const { a.value + b.value }) { z ->
|
||||||
a.d += z.d
|
a.d += z.d
|
||||||
b.d += z.d
|
b.d += z.d
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { a.value * b.value }) { z ->
|
derive(const { a.value * b.value }) { z ->
|
||||||
a.d += z.d * b.value
|
a.d += z.d * b.value
|
||||||
b.d += z.d * a.value
|
b.d += z.d * a.value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
derive(const { a.value / b.value }) { z ->
|
derive(const { a.value / b.value }) { z ->
|
||||||
a.d += z.d / b.value
|
a.d += z.d / b.value
|
||||||
b.d -= z.d * a.value / (b.value * b.value)
|
b.d -= z.d * a.value / (b.value * b.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
|
public override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
|
||||||
derive(const { k.toDouble() * a.value }) { z ->
|
derive(const { k.toDouble() * a.value }) { z ->
|
||||||
a.d += z.d * k.toDouble()
|
a.d += z.d * k.toDouble()
|
||||||
}
|
}
|
||||||
@ -220,15 +221,15 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
|||||||
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||||
public val field: F,
|
public val field: F,
|
||||||
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
) : FirstDerivativeExpression<T>() {
|
) : FirstDerivativeExpression<T, Expression<T>>() {
|
||||||
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
return SimpleAutoDiffField(field, arguments).function().value
|
return SimpleAutoDiffField(field, arguments).function().value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
public override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function)
|
val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function)
|
||||||
derivationResult.derivative(symbol)
|
derivationResult.derivative(symbol)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -236,13 +237,10 @@ public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
|||||||
/**
|
/**
|
||||||
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
|
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>, Expression<T>> =
|
||||||
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
AutoDiffProcessor { function ->
|
||||||
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> {
|
SimpleAutoDiffExpression(field, function)
|
||||||
return SimpleAutoDiffExpression(field, function)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Extensions for differentiation of various basic mathematical functions
|
// Extensions for differentiation of various basic mathematical functions
|
||||||
|
|
||||||
@ -392,4 +390,4 @@ public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
|
|||||||
|
|
||||||
public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||||
(this as SimpleAutoDiffField<T, F>).atanh(arg)
|
(this as SimpleAutoDiffField<T, F>).atanh(arg)
|
||||||
}
|
}
|
||||||
|
9
kmath-kotlingrad/build.gradle.kts
Normal file
9
kmath-kotlingrad/build.gradle.kts
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
plugins {
|
||||||
|
id("ru.mipt.npm.jvm")
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation("com.github.breandan:kaliningraph:0.1.2")
|
||||||
|
implementation("com.github.breandan:kotlingrad:0.3.7")
|
||||||
|
api(project(":kmath-ast"))
|
||||||
|
}
|
@ -0,0 +1,53 @@
|
|||||||
|
package kscience.kmath.kotlingrad
|
||||||
|
|
||||||
|
import edu.umontreal.kotlingrad.experimental.SFun
|
||||||
|
import kscience.kmath.ast.MST
|
||||||
|
import kscience.kmath.ast.MstAlgebra
|
||||||
|
import kscience.kmath.ast.MstExpression
|
||||||
|
import kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import kscience.kmath.expressions.Symbol
|
||||||
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents wrapper of [MstExpression] implementing [DifferentiableExpression].
|
||||||
|
*
|
||||||
|
* The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin∇, then converting
|
||||||
|
* [SFun] back to [MST].
|
||||||
|
*
|
||||||
|
* @param T the type of number.
|
||||||
|
* @param A the [NumericAlgebra] of [T].
|
||||||
|
* @property expr the underlying [MstExpression].
|
||||||
|
*/
|
||||||
|
public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpression<T, A>) :
|
||||||
|
DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T>, T : Number {
|
||||||
|
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The [MstExpression.algebra] of [expr].
|
||||||
|
*/
|
||||||
|
public val algebra: A
|
||||||
|
get() = expr.algebra
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The [MstExpression.mst] of [expr].
|
||||||
|
*/
|
||||||
|
public val mst: MST
|
||||||
|
get() = expr.mst
|
||||||
|
|
||||||
|
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
|
||||||
|
|
||||||
|
public override fun derivativeOrNull(symbols: List<Symbol>): MstExpression<T, A> = MstExpression(
|
||||||
|
algebra,
|
||||||
|
symbols.map(Symbol::identity)
|
||||||
|
.map(MstAlgebra::symbol)
|
||||||
|
.map { it.toSVar<KMathNumber<T, A>>() }
|
||||||
|
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
|
||||||
|
.toMst(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps this [MstExpression] into [DifferentiableMstExpression].
|
||||||
|
*/
|
||||||
|
public fun <T : Number, A : NumericAlgebra<T>> MstExpression<T, A>.differentiable(): DifferentiableMstExpression<T, A> =
|
||||||
|
DifferentiableMstExpression(this)
|
@ -0,0 +1,18 @@
|
|||||||
|
package kscience.kmath.kotlingrad
|
||||||
|
|
||||||
|
import edu.umontreal.kotlingrad.experimental.RealNumber
|
||||||
|
import edu.umontreal.kotlingrad.experimental.SConst
|
||||||
|
import kscience.kmath.operations.NumericAlgebra
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implements [RealNumber] by delegating its functionality to [NumericAlgebra].
|
||||||
|
*
|
||||||
|
* @param T the type of number.
|
||||||
|
* @param A the [NumericAlgebra] of [T].
|
||||||
|
* @property algebra the algebra.
|
||||||
|
* @param value the value of this number.
|
||||||
|
*/
|
||||||
|
public class KMathNumber<T, A>(public val algebra: A, value: T) :
|
||||||
|
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
|
||||||
|
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
|
||||||
|
}
|
@ -0,0 +1,124 @@
|
|||||||
|
package kscience.kmath.kotlingrad
|
||||||
|
|
||||||
|
import edu.umontreal.kotlingrad.experimental.*
|
||||||
|
import kscience.kmath.ast.MST
|
||||||
|
import kscience.kmath.ast.MstAlgebra
|
||||||
|
import kscience.kmath.ast.MstExtendedField
|
||||||
|
import kscience.kmath.ast.MstExtendedField.unaryMinus
|
||||||
|
import kscience.kmath.operations.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps [SVar] to [MST.Symbolic] directly.
|
||||||
|
*
|
||||||
|
* @receiver the variable.
|
||||||
|
* @return a node.
|
||||||
|
*/
|
||||||
|
public fun <X : SFun<X>> SVar<X>.toMst(): MST.Symbolic = MstAlgebra.symbol(name)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps [SVar] to [MST.Numeric] directly.
|
||||||
|
*
|
||||||
|
* @receiver the constant.
|
||||||
|
* @return a node.
|
||||||
|
*/
|
||||||
|
public fun <X : SFun<X>> SConst<X>.toMst(): MST.Numeric = MstAlgebra.number(doubleValue)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then.
|
||||||
|
* [Power] operation is limited to constant right-hand side arguments.
|
||||||
|
*
|
||||||
|
* Detailed mapping is:
|
||||||
|
*
|
||||||
|
* - [SVar] -> [MstExtendedField.symbol];
|
||||||
|
* - [SConst] -> [MstExtendedField.number];
|
||||||
|
* - [Sum] -> [MstExtendedField.add];
|
||||||
|
* - [Prod] -> [MstExtendedField.multiply];
|
||||||
|
* - [Power] -> [MstExtendedField.power] (limited to constant exponents only);
|
||||||
|
* - [Negative] -> [MstExtendedField.unaryMinus];
|
||||||
|
* - [Log] -> [MstExtendedField.ln] (left) / [MstExtendedField.ln] (right);
|
||||||
|
* - [Sine] -> [MstExtendedField.sin];
|
||||||
|
* - [Cosine] -> [MstExtendedField.cos];
|
||||||
|
* - [Tangent] -> [MstExtendedField.tan];
|
||||||
|
* - [DProd] is vector operation, and it is requested to be evaluated;
|
||||||
|
* - [SComposition] is also requested to be evaluated eagerly;
|
||||||
|
* - [VSumAll] is requested to be evaluated;
|
||||||
|
* - [Derivative] is requested to be evaluated.
|
||||||
|
*
|
||||||
|
* @receiver the scalar function.
|
||||||
|
* @return a node.
|
||||||
|
*/
|
||||||
|
public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField {
|
||||||
|
when (this@toMst) {
|
||||||
|
is SVar -> toMst()
|
||||||
|
is SConst -> toMst()
|
||||||
|
is Sum -> left.toMst() + right.toMst()
|
||||||
|
is Prod -> left.toMst() * right.toMst()
|
||||||
|
is Power -> left.toMst() pow ((right as? SConst<*>)?.doubleValue ?: (right() as SConst<*>).doubleValue)
|
||||||
|
is Negative -> -input.toMst()
|
||||||
|
is Log -> ln(left.toMst()) / ln(right.toMst())
|
||||||
|
is Sine -> sin(input.toMst())
|
||||||
|
is Cosine -> cos(input.toMst())
|
||||||
|
is Tangent -> tan(input.toMst())
|
||||||
|
is DProd -> this@toMst().toMst()
|
||||||
|
is SComposition -> this@toMst().toMst()
|
||||||
|
is VSumAll<X, *> -> this@toMst().toMst()
|
||||||
|
is Derivative -> this@toMst().toMst()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps [MST.Numeric] to [SConst] directly.
|
||||||
|
*
|
||||||
|
* @receiver the node.
|
||||||
|
* @return a new constant.
|
||||||
|
*/
|
||||||
|
public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps [MST.Symbolic] to [SVar] directly.
|
||||||
|
*
|
||||||
|
* @receiver the node.
|
||||||
|
* @param proto the prototype instance.
|
||||||
|
* @return a new variable.
|
||||||
|
*/
|
||||||
|
internal fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
|
||||||
|
*
|
||||||
|
* Detailed mapping is:
|
||||||
|
*
|
||||||
|
* - [MST.Numeric] -> [SConst];
|
||||||
|
* - [MST.Symbolic] -> [SVar];
|
||||||
|
* - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log];
|
||||||
|
* - [MST.Binary] -> [Sum], [Prod], [Power].
|
||||||
|
*
|
||||||
|
* @receiver the node.
|
||||||
|
* @param proto the prototype instance.
|
||||||
|
* @return a scalar function.
|
||||||
|
*/
|
||||||
|
public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
||||||
|
is MST.Numeric -> toSConst()
|
||||||
|
is MST.Symbolic -> toSVar()
|
||||||
|
|
||||||
|
is MST.Unary -> when (operation) {
|
||||||
|
SpaceOperations.PLUS_OPERATION -> +value.toSFun<X>()
|
||||||
|
SpaceOperations.MINUS_OPERATION -> -value.toSFun<X>()
|
||||||
|
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
|
||||||
|
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
|
||||||
|
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
|
||||||
|
PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun())
|
||||||
|
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
|
||||||
|
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
|
||||||
|
else -> error("Unary operation $operation not defined in $this")
|
||||||
|
}
|
||||||
|
|
||||||
|
is MST.Binary -> when (operation) {
|
||||||
|
SpaceOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
|
||||||
|
SpaceOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
|
||||||
|
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
|
||||||
|
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
|
||||||
|
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
|
||||||
|
else -> error("Binary operation $operation not defined in $this")
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,64 @@
|
|||||||
|
package kscience.kmath.kotlingrad
|
||||||
|
|
||||||
|
import edu.umontreal.kotlingrad.experimental.*
|
||||||
|
import kscience.kmath.asm.compile
|
||||||
|
import kscience.kmath.ast.MstAlgebra
|
||||||
|
import kscience.kmath.ast.MstExpression
|
||||||
|
import kscience.kmath.ast.parseMath
|
||||||
|
import kscience.kmath.expressions.invoke
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertTrue
|
||||||
|
import kotlin.test.fail
|
||||||
|
|
||||||
|
internal class AdaptingTests {
|
||||||
|
@Test
|
||||||
|
fun symbol() {
|
||||||
|
val c1 = MstAlgebra.symbol("x")
|
||||||
|
assertTrue(c1.toSVar<KMathNumber<Double, RealField>>().name == "x")
|
||||||
|
val c2 = "kitten".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||||
|
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun number() {
|
||||||
|
val c1 = MstAlgebra.number(12354324)
|
||||||
|
assertTrue(c1.toSConst<DReal>().doubleValue == 12354324.0)
|
||||||
|
val c2 = "0.234".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||||
|
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
|
||||||
|
val c3 = "1e-3".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||||
|
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun simpleFunctionShape() {
|
||||||
|
val linear = "2*x+16".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||||
|
if (linear !is Sum) fail()
|
||||||
|
if (linear.left !is Prod) fail()
|
||||||
|
if (linear.right !is SConst) fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun simpleFunctionDerivative() {
|
||||||
|
val x = MstAlgebra.symbol("x").toSVar<KMathNumber<Double, RealField>>()
|
||||||
|
val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||||
|
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
|
||||||
|
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||||
|
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun moreComplexDerivative() {
|
||||||
|
val x = MstAlgebra.symbol("x").toSVar<KMathNumber<Double, RealField>>()
|
||||||
|
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||||
|
val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile()
|
||||||
|
|
||||||
|
val expectedDerivative = MstExpression(
|
||||||
|
RealField,
|
||||||
|
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath()
|
||||||
|
).compile()
|
||||||
|
|
||||||
|
assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1))
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,6 @@
|
|||||||
plugins { id("ru.mipt.npm.mpp") }
|
plugins {
|
||||||
|
id("ru.mipt.npm.mpp")
|
||||||
|
}
|
||||||
|
|
||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
commonMain {
|
commonMain {
|
||||||
|
@ -12,16 +12,18 @@ public object Fitting {
|
|||||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, I : Any, A> chiSquared(
|
public fun <T : Any, I : Any, A> chiSquared(
|
||||||
autoDiff: AutoDiffProcessor<T, I, A>,
|
autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
||||||
x: Buffer<T>,
|
x: Buffer<T>,
|
||||||
y: Buffer<T>,
|
y: Buffer<T>,
|
||||||
yErr: Buffer<T>,
|
yErr: Buffer<T>,
|
||||||
model: A.(I) -> I,
|
model: A.(I) -> I,
|
||||||
): DifferentiableExpression<T> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
): DifferentiableExpression<T, Expression<T>> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
||||||
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
||||||
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
||||||
|
|
||||||
return autoDiff.process {
|
return autoDiff.process {
|
||||||
var sum = zero
|
var sum = zero
|
||||||
|
|
||||||
x.indices.forEach {
|
x.indices.forEach {
|
||||||
val xValue = const(x[it])
|
val xValue = const(x[it])
|
||||||
val yValue = const(y[it])
|
val yValue = const(y[it])
|
||||||
@ -29,6 +31,7 @@ public object Fitting {
|
|||||||
val modelValue = model(xValue)
|
val modelValue = model(xValue)
|
||||||
sum += ((yValue - modelValue) / yErrValue).pow(2)
|
sum += ((yValue - modelValue) / yErrValue).pow(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
sum
|
sum
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -45,6 +48,7 @@ public object Fitting {
|
|||||||
): Expression<Double> {
|
): Expression<Double> {
|
||||||
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
||||||
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
||||||
|
|
||||||
return Expression { arguments ->
|
return Expression { arguments ->
|
||||||
x.indices.sumByDouble {
|
x.indices.sumByDouble {
|
||||||
val xValue = x[it]
|
val xValue = x[it]
|
||||||
@ -56,4 +60,4 @@ public object Fitting {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -27,17 +27,17 @@ public interface OptimizationProblem<T : Any> {
|
|||||||
/**
|
/**
|
||||||
* Define the initial guess for the optimization problem
|
* Define the initial guess for the optimization problem
|
||||||
*/
|
*/
|
||||||
public fun initialGuess(map: Map<Symbol, T>): Unit
|
public fun initialGuess(map: Map<Symbol, T>)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set an objective function expression
|
* Set an objective function expression
|
||||||
*/
|
*/
|
||||||
public fun expression(expression: Expression<T>): Unit
|
public fun expression(expression: Expression<T>)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set a differentiable expression as objective function as function and gradient provider
|
* Set a differentiable expression as objective function as function and gradient provider
|
||||||
*/
|
*/
|
||||||
public fun diffExpression(expression: DifferentiableExpression<T>): Unit
|
public fun diffExpression(expression: DifferentiableExpression<T, Expression<T>>)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update the problem from previous optimization run
|
* Update the problem from previous optimization run
|
||||||
@ -50,9 +50,8 @@ public interface OptimizationProblem<T : Any> {
|
|||||||
public fun optimize(): OptimizationResult<T>
|
public fun optimize(): OptimizationResult<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
public interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> {
|
public fun interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> {
|
||||||
public fun build(symbols: List<Symbol>): P
|
public fun build(symbols: List<Symbol>): P
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFactory<T, P>.invoke(
|
public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFactory<T, P>.invoke(
|
||||||
@ -60,7 +59,6 @@ public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFac
|
|||||||
block: P.() -> Unit,
|
block: P.() -> Unit,
|
||||||
): P = build(symbols).apply(block)
|
): P = build(symbols).apply(block)
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize expression without derivatives using specific [OptimizationProblemFactory]
|
* Optimize expression without derivatives using specific [OptimizationProblemFactory]
|
||||||
*/
|
*/
|
||||||
@ -78,7 +76,7 @@ public fun <T : Any, F : OptimizationProblem<T>> Expression<T>.optimizeWith(
|
|||||||
/**
|
/**
|
||||||
* Optimize differentiable expression using specific [OptimizationProblemFactory]
|
* Optimize differentiable expression using specific [OptimizationProblemFactory]
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.optimizeWith(
|
public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T, Expression<T>>.optimizeWith(
|
||||||
factory: OptimizationProblemFactory<T, F>,
|
factory: OptimizationProblemFactory<T, F>,
|
||||||
vararg symbols: Symbol,
|
vararg symbols: Symbol,
|
||||||
configuration: F.() -> Unit,
|
configuration: F.() -> Unit,
|
||||||
@ -88,4 +86,3 @@ public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.op
|
|||||||
problem.diffExpression(this)
|
problem.diffExpression(this)
|
||||||
return problem.optimize()
|
return problem.optimize()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
plugins { id("ru.mipt.npm.jvm") }
|
plugins {
|
||||||
|
id("ru.mipt.npm.jvm")
|
||||||
|
}
|
||||||
|
|
||||||
description = "Binding for https://github.com/JetBrains-Research/viktor"
|
description = "Binding for https://github.com/JetBrains-Research/viktor"
|
||||||
|
|
||||||
|
@ -1,13 +1,11 @@
|
|||||||
pluginManagement {
|
pluginManagement {
|
||||||
repositories {
|
repositories {
|
||||||
mavenLocal()
|
|
||||||
jcenter()
|
|
||||||
gradlePluginPortal()
|
gradlePluginPortal()
|
||||||
|
jcenter()
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val toolsVersion = "0.6.4-dev-1.4.20-M2"
|
val toolsVersion = "0.6.4-dev-1.4.20-M2"
|
||||||
@ -41,5 +39,6 @@ include(
|
|||||||
":kmath-geometry",
|
":kmath-geometry",
|
||||||
":kmath-ast",
|
":kmath-ast",
|
||||||
":kmath-ejml",
|
":kmath-ejml",
|
||||||
|
":kmath-kotlingrad",
|
||||||
":examples"
|
":examples"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user