Add adapters of scalar functions to MST and vice versa #150
16
README.md
16
README.md
@ -211,7 +211,15 @@ Release artifacts are accessible from bintray with following configuration (see
|
||||
|
||||
```kotlin
|
||||
repositories {
|
||||
jcenter()
|
||||
maven("https://clojars.org/repo")
|
||||
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||
maven("https://jitpack.io")
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
dependencies {
|
||||
@ -228,7 +236,15 @@ Development builds are uploaded to the separate repository:
|
||||
|
||||
```kotlin
|
||||
repositories {
|
||||
jcenter()
|
||||
maven("https://clojars.org/repo")
|
||||
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
maven("https://jitpack.io")
|
||||
mavenCentral()
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
import ru.mipt.npm.gradle.KSciencePublishPlugin
|
||||
|
||||
plugins {
|
||||
id("ru.mipt.npm.project")
|
||||
}
|
||||
@ -9,9 +11,16 @@ internal val githubProject: String by extra("kmath")
|
||||
allprojects {
|
||||
repositories {
|
||||
jcenter()
|
||||
maven("https://clojars.org/repo")
|
||||
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||
maven("https://jitpack.io")
|
||||
maven("http://logicrunch.research.it.uu.se/maven/")
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
group = "kscience.kmath"
|
||||
|
||||
@ -19,7 +28,7 @@ allprojects {
|
||||
}
|
||||
|
||||
subprojects {
|
||||
if (name.startsWith("kmath")) apply<ru.mipt.npm.gradle.KSciencePublishPlugin>()
|
||||
if (name.startsWith("kmath")) apply<KSciencePublishPlugin>()
|
||||
}
|
||||
|
||||
readme {
|
||||
|
@ -8,18 +8,25 @@ plugins {
|
||||
}
|
||||
|
||||
allOpen.annotation("org.openjdk.jmh.annotations.State")
|
||||
sourceSets.register("benchmarks")
|
||||
|
||||
repositories {
|
||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||
jcenter()
|
||||
See the comment above. I think it makes sense to leave them here for people to see which repositories to use. See the comment above. I think it makes sense to leave them here for people to see which repositories to use.
|
||||
maven("https://clojars.org/repo")
|
||||
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
|
||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||
maven("https://jitpack.io")
|
||||
maven("http://logicrunch.research.it.uu.se/maven/")
|
||||
mavenCentral()
|
||||
}
|
||||
|
||||
sourceSets.register("benchmarks")
|
||||
|
||||
dependencies {
|
||||
implementation(project(":kmath-ast"))
|
||||
implementation(project(":kmath-kotlingrad"))
|
||||
implementation(project(":kmath-core"))
|
||||
implementation(project(":kmath-coroutines"))
|
||||
implementation(project(":kmath-commons"))
|
||||
|
@ -9,7 +9,7 @@ import kscience.kmath.operations.RealField
|
||||
import kotlin.random.Random
|
||||
import kotlin.system.measureTimeMillis
|
||||
|
||||
class ExpressionsInterpretersBenchmark {
|
||||
internal class ExpressionsInterpretersBenchmark {
|
||||
private val algebra: Field<Double> = RealField
|
||||
fun functionalExpression() {
|
||||
val expr = algebra.expressionInField {
|
||||
@ -47,6 +47,16 @@ class ExpressionsInterpretersBenchmark {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and
|
||||
* core FunctionalExpressions API.
|
||||
*
|
||||
* The expected rating is:
|
||||
*
|
||||
* 1. ASM.
|
||||
* 2. MST.
|
||||
* 3. FE.
|
||||
*/
|
||||
fun main() {
|
||||
val benchmark = ExpressionsInterpretersBenchmark()
|
||||
|
||||
|
@ -0,0 +1,24 @@
|
||||
package kscience.kmath.ast
|
||||
|
||||
import kscience.kmath.asm.compile
|
||||
import kscience.kmath.expressions.derivative
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.expressions.symbol
|
||||
import kscience.kmath.kotlingrad.differentiable
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
/**
|
||||
* In this example, x^2-4*x-44 function is differentiated with Kotlin∇, and the autodiff result is compared with
|
||||
* valid derivative.
|
||||
*/
|
||||
fun main() {
|
||||
val x by symbol
|
||||
|
||||
val actualDerivative = MstExpression(RealField, "x^2-4*x-44".parseMath())
|
||||
.differentiable()
|
||||
.derivative(x)
|
||||
.compile()
|
||||
|
||||
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
|
||||
}
|
@ -6,14 +6,14 @@ import kscience.kmath.operations.*
|
||||
* [Algebra] over [MST] nodes.
|
||||
*/
|
||||
public object MstAlgebra : NumericAlgebra<MST> {
|
||||
override fun number(value: Number): MST = MST.Numeric(value)
|
||||
override fun number(value: Number): MST.Numeric = MST.Numeric(value)
|
||||
|
||||
override fun symbol(value: String): MST = MST.Symbolic(value)
|
||||
override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST =
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary =
|
||||
MST.Unary(operation, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MST.Binary(operation, left, right)
|
||||
}
|
||||
|
||||
@ -21,97 +21,100 @@ public object MstAlgebra : NumericAlgebra<MST> {
|
||||
* [Space] over [MST] nodes.
|
||||
*/
|
||||
public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||
override val zero: MST = number(0.0)
|
||||
override val zero: MST.Numeric by lazy { number(0.0) }
|
||||
|
||||
override fun number(value: Number): MST = MstAlgebra.number(value)
|
||||
override fun symbol(value: String): MST = MstAlgebra.symbol(value)
|
||||
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||
override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
|
||||
override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
|
||||
override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
|
||||
override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MstAlgebra.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg)
|
||||
}
|
||||
|
||||
/**
|
||||
* [Ring] over [MST] nodes.
|
||||
*/
|
||||
public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||
override val zero: MST
|
||||
override val zero: MST.Numeric
|
||||
get() = MstSpace.zero
|
||||
override val one: MST = number(1.0)
|
||||
|
||||
override fun number(value: Number): MST = MstSpace.number(value)
|
||||
override fun symbol(value: String): MST = MstSpace.symbol(value)
|
||||
override fun add(a: MST, b: MST): MST = MstSpace.add(a, b)
|
||||
override val one: MST.Numeric by lazy { number(1.0) }
|
||||
|
||||
override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k)
|
||||
override fun number(value: Number): MST.Numeric = MstSpace.number(value)
|
||||
override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
|
||||
override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
|
||||
override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
|
||||
override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||
|
||||
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MstSpace.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg)
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg)
|
||||
}
|
||||
|
||||
/**
|
||||
* [Field] over [MST] nodes.
|
||||
*/
|
||||
public object MstField : Field<MST> {
|
||||
public override val zero: MST
|
||||
public override val zero: MST.Numeric
|
||||
get() = MstRing.zero
|
||||
|
||||
public override val one: MST
|
||||
public override val one: MST.Numeric
|
||||
get() = MstRing.one
|
||||
|
||||
public override fun symbol(value: String): MST = MstRing.symbol(value)
|
||||
public override fun number(value: Number): MST = MstRing.number(value)
|
||||
public override fun add(a: MST, b: MST): MST = MstRing.add(a, b)
|
||||
public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k)
|
||||
public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b)
|
||||
public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||
public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value)
|
||||
public override fun number(value: Number): MST.Numeric = MstRing.number(value)
|
||||
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
||||
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
|
||||
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
||||
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
|
||||
|
||||
public override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MstRing.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg)
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg)
|
||||
}
|
||||
|
||||
/**
|
||||
* [ExtendedField] over [MST] nodes.
|
||||
*/
|
||||
public object MstExtendedField : ExtendedField<MST> {
|
||||
override val zero: MST
|
||||
override val zero: MST.Numeric
|
||||
get() = MstField.zero
|
||||
|
||||
override val one: MST
|
||||
override val one: MST.Numeric
|
||||
get() = MstField.one
|
||||
|
||||
override fun symbol(value: String): MST = MstField.symbol(value)
|
||||
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||
override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
|
||||
override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
||||
override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
||||
override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
||||
override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
|
||||
override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
|
||||
override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
|
||||
override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
|
||||
override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
|
||||
override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
|
||||
override fun add(a: MST, b: MST): MST = MstField.add(a, b)
|
||||
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
|
||||
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)
|
||||
override fun divide(a: MST, b: MST): MST = MstField.divide(a, b)
|
||||
override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||
override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||
override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||
override fun symbol(value: String): MST.Symbolic = MstField.symbol(value)
|
||||
override fun number(value: Number): MST.Numeric = MstField.number(value)
|
||||
override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
|
||||
override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
|
||||
override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
|
||||
override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
|
||||
override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
|
||||
override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
|
||||
override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
|
||||
override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
|
||||
override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
|
||||
override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
|
||||
override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
|
||||
override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
|
||||
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
|
||||
override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
|
||||
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
||||
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
|
||||
override fun power(arg: MST, pow: Number): MST.Binary =
|
||||
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
|
||||
|
||||
override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
|
||||
override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
|
||||
|
||||
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
|
||||
MstField.binaryOperation(operation, left, right)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg)
|
||||
override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg)
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ import kotlin.contracts.contract
|
||||
* @property mst the [MST] node.
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> {
|
||||
public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
|
||||
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
|
||||
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
|
||||
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
|
||||
@ -21,8 +21,9 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
||||
override fun binaryOperation(operation: String, left: T, right: T): T =
|
||||
algebra.binaryOperation(operation, left, right)
|
||||
|
||||
override fun number(value: Number): T = if (algebra is NumericAlgebra)
|
||||
algebra.number(value)
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
|
||||
(algebra as NumericAlgebra<T>).number(value)
|
||||
else
|
||||
error("Numeric nodes are not supported by $this")
|
||||
}
|
||||
@ -38,14 +39,14 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
|
||||
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
|
||||
mstAlgebra: E,
|
||||
block: E.() -> MST,
|
||||
): MstExpression<T> = MstExpression(this, mstAlgebra.block())
|
||||
): MstExpression<T, A> = MstExpression(this, mstAlgebra.block())
|
||||
|
||||
/**
|
||||
* Builds [MstExpression] over [Space].
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : Space<T>> A.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return MstExpression(this, MstSpace.block())
|
||||
}
|
||||
@ -55,7 +56,7 @@ public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MS
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : Ring<T>> A.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return MstExpression(this, MstRing.block())
|
||||
}
|
||||
@ -65,7 +66,7 @@ public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST):
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : Field<T>> A.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return MstExpression(this, MstField.block())
|
||||
}
|
||||
@ -75,7 +76,7 @@ public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MS
|
||||
*
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : ExtendedField<T>> A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return MstExpression(this, MstExtendedField.block())
|
||||
}
|
||||
@ -85,7 +86,7 @@ public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtend
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInSpace(block)
|
||||
}
|
||||
@ -95,7 +96,7 @@ public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInRing(block)
|
||||
}
|
||||
@ -105,7 +106,7 @@ public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> {
|
||||
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInField(block)
|
||||
}
|
||||
@ -117,7 +118,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
|
||||
*/
|
||||
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
|
||||
block: MstExtendedField.() -> MST,
|
||||
): MstExpression<T> {
|
||||
): MstExpression<T, A> {
|
||||
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
|
||||
return algebra.mstInExtendedField(block)
|
||||
}
|
||||
|
@ -69,4 +69,5 @@ public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<
|
||||
*
|
||||
* @author Alexander Nozik.
|
||||
*/
|
||||
public inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class.java, algebra)
|
||||
public inline fun <reified T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
|
||||
mst.compileWith(T::class.java, algebra)
|
||||
|
@ -95,10 +95,10 @@ public class DerivativeStructureField(
|
||||
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
||||
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
||||
|
||||
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> {
|
||||
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> {
|
||||
return DerivativeStructureExpression(function)
|
||||
}
|
||||
public companion object :
|
||||
AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField, Expression<Double>> {
|
||||
public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double, Expression<Double>> =
|
||||
DerivativeStructureExpression(function)
|
||||
}
|
||||
}
|
||||
|
||||
@ -108,7 +108,7 @@ public class DerivativeStructureField(
|
||||
*/
|
||||
public class DerivativeStructureExpression(
|
||||
public val function: DerivativeStructureField.() -> DerivativeStructure,
|
||||
) : DifferentiableExpression<Double> {
|
||||
) : DifferentiableExpression<Double, Expression<Double>> {
|
||||
public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||
DerivativeStructureField(0, arguments).function().value
|
||||
|
||||
|
@ -19,9 +19,8 @@ import kotlin.reflect.KClass
|
||||
public operator fun PointValuePair.component1(): DoubleArray = point
|
||||
public operator fun PointValuePair.component2(): Double = value
|
||||
|
||||
public class CMOptimizationProblem(
|
||||
override val symbols: List<Symbol>,
|
||||
) : OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
|
||||
public class CMOptimizationProblem(override val symbols: List<Symbol>, ) :
|
||||
OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
|
||||
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
||||
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
|
||||
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
|
||||
@ -49,7 +48,7 @@ public class CMOptimizationProblem(
|
||||
addOptimizationData(objectiveFunction)
|
||||
}
|
||||
|
||||
public override fun diffExpression(expression: DifferentiableExpression<Double>): Unit {
|
||||
public override fun diffExpression(expression: DifferentiableExpression<Double, Expression<Double>>) {
|
||||
expression(expression)
|
||||
val gradientFunction = ObjectiveFunctionGradient {
|
||||
val args = it.toMap()
|
||||
|
@ -12,7 +12,6 @@ import kscience.kmath.structures.asBuffer
|
||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
||||
|
||||
|
||||
/**
|
||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||
*/
|
||||
@ -21,7 +20,7 @@ public fun Fitting.chiSquared(
|
||||
y: Buffer<Double>,
|
||||
yErr: Buffer<Double>,
|
||||
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
||||
): DifferentiableExpression<Double> = chiSquared(DerivativeStructureField, x, y, yErr, model)
|
||||
): DifferentiableExpression<Double, Expression<Double>> = chiSquared(DerivativeStructureField, x, y, yErr, model)
|
||||
|
||||
/**
|
||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||
@ -31,7 +30,7 @@ public fun Fitting.chiSquared(
|
||||
y: Iterable<Double>,
|
||||
yErr: Iterable<Double>,
|
||||
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
|
||||
): DifferentiableExpression<Double> = chiSquared(
|
||||
): DifferentiableExpression<Double, Expression<Double>> = chiSquared(
|
||||
DerivativeStructureField,
|
||||
x.toList().asBuffer(),
|
||||
y.toList().asBuffer(),
|
||||
@ -39,7 +38,6 @@ public fun Fitting.chiSquared(
|
||||
model
|
||||
)
|
||||
|
||||
|
||||
/**
|
||||
* Optimize expression without derivatives
|
||||
*/
|
||||
@ -48,16 +46,15 @@ public fun Expression<Double>.optimize(
|
||||
configuration: CMOptimizationProblem.() -> Unit,
|
||||
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||
|
||||
|
||||
/**
|
||||
* Optimize differentiable expression
|
||||
*/
|
||||
public fun DifferentiableExpression<Double>.optimize(
|
||||
public fun DifferentiableExpression<Double, Expression<Double>>.optimize(
|
||||
vararg symbols: Symbol,
|
||||
configuration: CMOptimizationProblem.() -> Unit,
|
||||
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
|
||||
|
||||
public fun DifferentiableExpression<Double>.minimize(
|
||||
public fun DifferentiableExpression<Double, Expression<Double>>.minimize(
|
||||
vararg startPoint: Pair<Symbol, Double>,
|
||||
configuration: CMOptimizationProblem.() -> Unit = {},
|
||||
): OptimizationResult<Double> {
|
||||
|
@ -47,14 +47,17 @@ internal class OptimizeTest {
|
||||
val sigma = 1.0
|
||||
val generator = Distribution.normal(0.0, sigma)
|
||||
val chain = generator.sample(RandomGenerator.default(112667))
|
||||
val x = (1..100).map { it.toDouble() }
|
||||
val y = x.map { it ->
|
||||
val x = (1..100).map(Int::toDouble)
|
||||
|
||||
val y = x.map {
|
||||
it.pow(2) + it + 1 + chain.nextDouble()
|
||||
}
|
||||
val yErr = x.map { sigma }
|
||||
val chi2 = Fitting.chiSquared(x, y, yErr) { x ->
|
||||
|
||||
val yErr = List(x.size) { sigma }
|
||||
|
||||
val chi2 = Fitting.chiSquared(x, y, yErr) { x1 ->
|
||||
val cWithDefault = bindOrNull(c) ?: one
|
||||
bind(a) * x.pow(2) + bind(b) * x + cWithDefault
|
||||
bind(a) * x1.pow(2) + bind(b) * x1 + cWithDefault
|
||||
}
|
||||
|
||||
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)
|
||||
|
@ -1,29 +1,40 @@
|
||||
package kscience.kmath.expressions
|
||||
|
||||
/**
|
||||
* An expression that provides derivatives
|
||||
* Represents expression which structure can be differentiated.
|
||||
*
|
||||
* @param T the type this expression takes as argument and returns.
|
||||
* @param R the type of expression this expression can be differentiated to.
|
||||
*/
|
||||
public interface DifferentiableExpression<T> : Expression<T> {
|
||||
public fun derivativeOrNull(symbols: List<Symbol>): Expression<T>?
|
||||
public interface DifferentiableExpression<T, out R : Expression<T>> : Expression<T> {
|
||||
/**
|
||||
* Differentiates this expression by ordered collection of [symbols].
|
||||
*
|
||||
* @param symbols the symbols.
|
||||
* @return the derivative or `null`.
|
||||
*/
|
||||
public fun derivativeOrNull(symbols: List<Symbol>): R?
|
||||
}
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(symbols: List<Symbol>): Expression<T> =
|
||||
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(symbols: List<Symbol>): R =
|
||||
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(vararg symbols: Symbol): Expression<T> =
|
||||
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(vararg symbols: Symbol): R =
|
||||
derivative(symbols.toList())
|
||||
|
||||
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> =
|
||||
public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(name: String): R =
|
||||
derivative(StringSymbol(name))
|
||||
|
||||
/**
|
||||
* A [DifferentiableExpression] that defines only first derivatives
|
||||
*/
|
||||
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> {
|
||||
public abstract class FirstDerivativeExpression<T, R : Expression<T>> : DifferentiableExpression<T,R> {
|
||||
/**
|
||||
* Returns first derivative of this expression by given [symbol].
|
||||
*/
|
||||
public abstract fun derivativeOrNull(symbol: Symbol): R?
|
||||
|
||||
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>?
|
||||
|
||||
public override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? {
|
||||
public final override fun derivativeOrNull(symbols: List<Symbol>): R? {
|
||||
val dSymbol = symbols.firstOrNull() ?: return null
|
||||
return derivativeOrNull(dSymbol)
|
||||
}
|
||||
@ -32,6 +43,6 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
|
||||
/**
|
||||
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
|
||||
*/
|
||||
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> {
|
||||
public fun process(function: A.() -> I): DifferentiableExpression<T>
|
||||
public fun interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>, out R : Expression<T>> {
|
||||
public fun process(function: A.() -> I): DifferentiableExpression<T, R>
|
||||
}
|
@ -22,7 +22,9 @@ public inline class StringSymbol(override val identity: String) : Symbol {
|
||||
}
|
||||
|
||||
/**
|
||||
* An elementary function that could be invoked on a map of arguments
|
||||
* An elementary function that could be invoked on a map of arguments.
|
||||
*
|
||||
* @param T the type this expression takes as argument and returns.
|
||||
*/
|
||||
public fun interface Expression<T> {
|
||||
/**
|
||||
|
@ -68,7 +68,7 @@ public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||
): DerivationResult<T> {
|
||||
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
|
||||
|
||||
return SimpleAutoDiffField(this, bindings).derivate(body)
|
||||
return SimpleAutoDiffField(this, bindings).differentiate(body)
|
||||
}
|
||||
|
||||
public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||
@ -83,12 +83,21 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||
public val context: F,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
||||
public override val zero: AutoDiffValue<T>
|
||||
get() = const(context.zero)
|
||||
|
||||
public override val one: AutoDiffValue<T>
|
||||
get() = const(context.one)
|
||||
|
||||
// this stack contains pairs of blocks and values to apply them to
|
||||
private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
|
||||
private var sp: Int = 0
|
||||
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
|
||||
|
||||
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
||||
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
||||
}
|
||||
|
||||
/**
|
||||
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
|
||||
* with respect to this variable.
|
||||
@ -106,11 +115,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||
override fun hashCode(): Int = identity.hashCode()
|
||||
}
|
||||
|
||||
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
|
||||
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
|
||||
}
|
||||
|
||||
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
|
||||
public override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
|
||||
|
||||
private fun getDerivative(variable: AutoDiffValue<T>): T =
|
||||
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
|
||||
@ -119,7 +124,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
|
||||
}
|
||||
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
private fun runBackwardPass() {
|
||||
while (sp > 0) {
|
||||
@ -129,9 +133,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||
}
|
||||
}
|
||||
|
||||
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
||||
|
||||
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
|
||||
|
||||
/**
|
||||
@ -165,7 +166,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||
}
|
||||
|
||||
|
||||
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
||||
internal fun differentiate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
|
||||
val result = function()
|
||||
result.d = context.one // computing derivative w.r.t result
|
||||
runBackwardPass()
|
||||
@ -174,41 +175,41 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||
|
||||
// Overloads for Double constants
|
||||
|
||||
override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { this@plus.toDouble() * one + b.value }) { z ->
|
||||
b.d += z.d
|
||||
}
|
||||
|
||||
override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
|
||||
public override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
|
||||
|
||||
override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
|
||||
|
||||
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
||||
public override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
|
||||
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
|
||||
|
||||
|
||||
// Basic math (+, -, *, /)
|
||||
|
||||
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { a.value + b.value }) { z ->
|
||||
a.d += z.d
|
||||
b.d += z.d
|
||||
}
|
||||
|
||||
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { a.value * b.value }) { z ->
|
||||
a.d += z.d * b.value
|
||||
b.d += z.d * a.value
|
||||
}
|
||||
|
||||
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
public override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { a.value / b.value }) { z ->
|
||||
a.d += z.d / b.value
|
||||
b.d -= z.d * a.value / (b.value * b.value)
|
||||
}
|
||||
|
||||
override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
|
||||
public override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
|
||||
derive(const { k.toDouble() * a.value }) { z ->
|
||||
a.d += z.d * k.toDouble()
|
||||
}
|
||||
@ -220,15 +221,15 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||
public val field: F,
|
||||
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||
) : FirstDerivativeExpression<T>() {
|
||||
) : FirstDerivativeExpression<T, Expression<T>>() {
|
||||
public override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||
return SimpleAutoDiffField(field, arguments).function().value
|
||||
}
|
||||
|
||||
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
||||
public override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||
val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function)
|
||||
val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function)
|
||||
derivationResult.derivative(symbol)
|
||||
}
|
||||
}
|
||||
@ -236,12 +237,9 @@ public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
||||
/**
|
||||
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
|
||||
*/
|
||||
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
||||
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> {
|
||||
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> {
|
||||
return SimpleAutoDiffExpression(field, function)
|
||||
}
|
||||
}
|
||||
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>, Expression<T>> =
|
||||
AutoDiffProcessor { function ->
|
||||
SimpleAutoDiffExpression(field, function)
|
||||
}
|
||||
|
||||
// Extensions for differentiation of various basic mathematical functions
|
||||
|
9
kmath-kotlingrad/build.gradle.kts
Normal file
9
kmath-kotlingrad/build.gradle.kts
Normal file
@ -0,0 +1,9 @@
|
||||
plugins {
|
||||
id("ru.mipt.npm.jvm")
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation("com.github.breandan:kaliningraph:0.1.2")
|
||||
implementation("com.github.breandan:kotlingrad:0.3.7")
|
||||
api(project(":kmath-ast"))
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
package kscience.kmath.kotlingrad
|
||||
|
||||
import edu.umontreal.kotlingrad.experimental.SFun
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.ast.MstAlgebra
|
||||
import kscience.kmath.ast.MstExpression
|
||||
import kscience.kmath.expressions.DifferentiableExpression
|
||||
import kscience.kmath.expressions.Symbol
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
|
||||
/**
|
||||
* Represents wrapper of [MstExpression] implementing [DifferentiableExpression].
|
||||
*
|
||||
* The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin∇, then converting
|
||||
* [SFun] back to [MST].
|
||||
*
|
||||
* @param T the type of number.
|
||||
* @param A the [NumericAlgebra] of [T].
|
||||
* @property expr the underlying [MstExpression].
|
||||
*/
|
||||
public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpression<T, A>) :
|
||||
DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T>, T : Number {
|
||||
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
|
||||
|
||||
/**
|
||||
* The [MstExpression.algebra] of [expr].
|
||||
*/
|
||||
public val algebra: A
|
||||
get() = expr.algebra
|
||||
|
||||
/**
|
||||
* The [MstExpression.mst] of [expr].
|
||||
*/
|
||||
public val mst: MST
|
||||
get() = expr.mst
|
||||
|
||||
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
|
||||
|
||||
public override fun derivativeOrNull(symbols: List<Symbol>): MstExpression<T, A> = MstExpression(
|
||||
algebra,
|
||||
symbols.map(Symbol::identity)
|
||||
.map(MstAlgebra::symbol)
|
||||
.map { it.toSVar<KMathNumber<T, A>>() }
|
||||
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
|
||||
.toMst(),
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [MstExpression] into [DifferentiableMstExpression].
|
||||
*/
|
||||
public fun <T : Number, A : NumericAlgebra<T>> MstExpression<T, A>.differentiable(): DifferentiableMstExpression<T, A> =
|
||||
DifferentiableMstExpression(this)
|
@ -0,0 +1,18 @@
|
||||
package kscience.kmath.kotlingrad
|
||||
|
||||
import edu.umontreal.kotlingrad.experimental.RealNumber
|
||||
import edu.umontreal.kotlingrad.experimental.SConst
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
|
||||
/**
|
||||
* Implements [RealNumber] by delegating its functionality to [NumericAlgebra].
|
||||
*
|
||||
* @param T the type of number.
|
||||
* @param A the [NumericAlgebra] of [T].
|
||||
* @property algebra the algebra.
|
||||
* @param value the value of this number.
|
||||
*/
|
||||
public class KMathNumber<T, A>(public val algebra: A, value: T) :
|
||||
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
|
||||
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
|
||||
}
|
@ -0,0 +1,124 @@
|
||||
package kscience.kmath.kotlingrad
|
||||
|
||||
import edu.umontreal.kotlingrad.experimental.*
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.ast.MstAlgebra
|
||||
import kscience.kmath.ast.MstExtendedField
|
||||
import kscience.kmath.ast.MstExtendedField.unaryMinus
|
||||
import kscience.kmath.operations.*
|
||||
|
||||
/**
|
||||
* Maps [SVar] to [MST.Symbolic] directly.
|
||||
*
|
||||
* @receiver the variable.
|
||||
* @return a node.
|
||||
*/
|
||||
public fun <X : SFun<X>> SVar<X>.toMst(): MST.Symbolic = MstAlgebra.symbol(name)
|
||||
|
||||
/**
|
||||
* Maps [SVar] to [MST.Numeric] directly.
|
||||
*
|
||||
* @receiver the constant.
|
||||
* @return a node.
|
||||
*/
|
||||
public fun <X : SFun<X>> SConst<X>.toMst(): MST.Numeric = MstAlgebra.number(doubleValue)
|
||||
|
||||
/**
|
||||
* Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then.
|
||||
* [Power] operation is limited to constant right-hand side arguments.
|
||||
*
|
||||
* Detailed mapping is:
|
||||
*
|
||||
* - [SVar] -> [MstExtendedField.symbol];
|
||||
* - [SConst] -> [MstExtendedField.number];
|
||||
* - [Sum] -> [MstExtendedField.add];
|
||||
* - [Prod] -> [MstExtendedField.multiply];
|
||||
* - [Power] -> [MstExtendedField.power] (limited to constant exponents only);
|
||||
* - [Negative] -> [MstExtendedField.unaryMinus];
|
||||
* - [Log] -> [MstExtendedField.ln] (left) / [MstExtendedField.ln] (right);
|
||||
* - [Sine] -> [MstExtendedField.sin];
|
||||
* - [Cosine] -> [MstExtendedField.cos];
|
||||
* - [Tangent] -> [MstExtendedField.tan];
|
||||
* - [DProd] is vector operation, and it is requested to be evaluated;
|
||||
* - [SComposition] is also requested to be evaluated eagerly;
|
||||
* - [VSumAll] is requested to be evaluated;
|
||||
* - [Derivative] is requested to be evaluated.
|
||||
*
|
||||
* @receiver the scalar function.
|
||||
* @return a node.
|
||||
*/
|
||||
public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField {
|
||||
when (this@toMst) {
|
||||
is SVar -> toMst()
|
||||
is SConst -> toMst()
|
||||
is Sum -> left.toMst() + right.toMst()
|
||||
is Prod -> left.toMst() * right.toMst()
|
||||
is Power -> left.toMst() pow ((right as? SConst<*>)?.doubleValue ?: (right() as SConst<*>).doubleValue)
|
||||
is Negative -> -input.toMst()
|
||||
is Log -> ln(left.toMst()) / ln(right.toMst())
|
||||
is Sine -> sin(input.toMst())
|
||||
is Cosine -> cos(input.toMst())
|
||||
is Tangent -> tan(input.toMst())
|
||||
is DProd -> this@toMst().toMst()
|
||||
is SComposition -> this@toMst().toMst()
|
||||
is VSumAll<X, *> -> this@toMst().toMst()
|
||||
is Derivative -> this@toMst().toMst()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps [MST.Numeric] to [SConst] directly.
|
||||
*
|
||||
* @receiver the node.
|
||||
* @return a new constant.
|
||||
*/
|
||||
public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
|
||||
|
||||
/**
|
||||
* Maps [MST.Symbolic] to [SVar] directly.
|
||||
*
|
||||
* @receiver the node.
|
||||
* @param proto the prototype instance.
|
||||
* @return a new variable.
|
||||
*/
|
||||
internal fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value)
|
||||
|
||||
/**
|
||||
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
|
||||
*
|
||||
* Detailed mapping is:
|
||||
*
|
||||
* - [MST.Numeric] -> [SConst];
|
||||
* - [MST.Symbolic] -> [SVar];
|
||||
* - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log];
|
||||
* - [MST.Binary] -> [Sum], [Prod], [Power].
|
||||
*
|
||||
* @receiver the node.
|
||||
* @param proto the prototype instance.
|
||||
* @return a scalar function.
|
||||
*/
|
||||
public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
||||
is MST.Numeric -> toSConst()
|
||||
is MST.Symbolic -> toSVar()
|
||||
|
||||
is MST.Unary -> when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> +value.toSFun<X>()
|
||||
SpaceOperations.MINUS_OPERATION -> -value.toSFun<X>()
|
||||
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
|
||||
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
|
||||
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
|
||||
PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun())
|
||||
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
|
||||
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
|
||||
else -> error("Unary operation $operation not defined in $this")
|
||||
}
|
||||
|
||||
is MST.Binary -> when (operation) {
|
||||
SpaceOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
|
||||
SpaceOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
|
||||
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
|
||||
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
|
||||
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
|
||||
else -> error("Binary operation $operation not defined in $this")
|
||||
}
|
||||
}
|
@ -0,0 +1,64 @@
|
||||
package kscience.kmath.kotlingrad
|
||||
|
||||
import edu.umontreal.kotlingrad.experimental.*
|
||||
import kscience.kmath.asm.compile
|
||||
import kscience.kmath.ast.MstAlgebra
|
||||
import kscience.kmath.ast.MstExpression
|
||||
import kscience.kmath.ast.parseMath
|
||||
import kscience.kmath.expressions.invoke
|
||||
import kscience.kmath.operations.RealField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
import kotlin.test.fail
|
||||
|
||||
internal class AdaptingTests {
|
||||
@Test
|
||||
fun symbol() {
|
||||
val c1 = MstAlgebra.symbol("x")
|
||||
assertTrue(c1.toSVar<KMathNumber<Double, RealField>>().name == "x")
|
||||
val c2 = "kitten".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun number() {
|
||||
val c1 = MstAlgebra.number(12354324)
|
||||
assertTrue(c1.toSConst<DReal>().doubleValue == 12354324.0)
|
||||
val c2 = "0.234".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
|
||||
val c3 = "1e-3".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun simpleFunctionShape() {
|
||||
val linear = "2*x+16".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||
if (linear !is Sum) fail()
|
||||
Better to use AssertFalse Better to use AssertFalse
No. No. `fail` is better since it returns Nothing, so Kotlin DFA makes smart-cast of `linear` to `Sum`
|
||||
if (linear.left !is Prod) fail()
|
||||
if (linear.right !is SConst) fail()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun simpleFunctionDerivative() {
|
||||
val x = MstAlgebra.symbol("x").toSVar<KMathNumber<Double, RealField>>()
|
||||
val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
|
||||
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
|
||||
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun moreComplexDerivative() {
|
||||
val x = MstAlgebra.symbol("x").toSVar<KMathNumber<Double, RealField>>()
|
||||
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, RealField>>()
|
||||
val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile()
|
||||
|
||||
val expectedDerivative = MstExpression(
|
||||
RealField,
|
||||
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath()
|
||||
).compile()
|
||||
|
||||
assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1))
|
||||
}
|
||||
}
|
@ -1,4 +1,6 @@
|
||||
plugins { id("ru.mipt.npm.mpp") }
|
||||
plugins {
|
||||
id("ru.mipt.npm.mpp")
|
||||
}
|
||||
|
||||
kotlin.sourceSets {
|
||||
commonMain {
|
||||
|
@ -12,16 +12,18 @@ public object Fitting {
|
||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||
*/
|
||||
public fun <T : Any, I : Any, A> chiSquared(
|
||||
autoDiff: AutoDiffProcessor<T, I, A>,
|
||||
autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
||||
x: Buffer<T>,
|
||||
y: Buffer<T>,
|
||||
yErr: Buffer<T>,
|
||||
model: A.(I) -> I,
|
||||
): DifferentiableExpression<T> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
||||
): DifferentiableExpression<T, Expression<T>> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
||||
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
||||
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
||||
|
||||
return autoDiff.process {
|
||||
var sum = zero
|
||||
|
||||
x.indices.forEach {
|
||||
val xValue = const(x[it])
|
||||
val yValue = const(y[it])
|
||||
@ -29,6 +31,7 @@ public object Fitting {
|
||||
val modelValue = model(xValue)
|
||||
sum += ((yValue - modelValue) / yErrValue).pow(2)
|
||||
}
|
||||
|
||||
sum
|
||||
}
|
||||
}
|
||||
@ -45,6 +48,7 @@ public object Fitting {
|
||||
): Expression<Double> {
|
||||
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
||||
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
||||
|
||||
return Expression { arguments ->
|
||||
x.indices.sumByDouble {
|
||||
val xValue = x[it]
|
||||
|
@ -27,17 +27,17 @@ public interface OptimizationProblem<T : Any> {
|
||||
/**
|
||||
* Define the initial guess for the optimization problem
|
||||
*/
|
||||
public fun initialGuess(map: Map<Symbol, T>): Unit
|
||||
public fun initialGuess(map: Map<Symbol, T>)
|
||||
|
||||
/**
|
||||
* Set an objective function expression
|
||||
*/
|
||||
public fun expression(expression: Expression<T>): Unit
|
||||
public fun expression(expression: Expression<T>)
|
||||
|
||||
/**
|
||||
* Set a differentiable expression as objective function as function and gradient provider
|
||||
*/
|
||||
public fun diffExpression(expression: DifferentiableExpression<T>): Unit
|
||||
public fun diffExpression(expression: DifferentiableExpression<T, Expression<T>>)
|
||||
|
||||
/**
|
||||
* Update the problem from previous optimization run
|
||||
@ -50,9 +50,8 @@ public interface OptimizationProblem<T : Any> {
|
||||
public fun optimize(): OptimizationResult<T>
|
||||
}
|
||||
|
||||
public interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> {
|
||||
public fun interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> {
|
||||
public fun build(symbols: List<Symbol>): P
|
||||
|
||||
}
|
||||
|
||||
public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFactory<T, P>.invoke(
|
||||
@ -60,7 +59,6 @@ public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFac
|
||||
block: P.() -> Unit,
|
||||
): P = build(symbols).apply(block)
|
||||
|
||||
|
||||
/**
|
||||
* Optimize expression without derivatives using specific [OptimizationProblemFactory]
|
||||
*/
|
||||
@ -78,7 +76,7 @@ public fun <T : Any, F : OptimizationProblem<T>> Expression<T>.optimizeWith(
|
||||
/**
|
||||
* Optimize differentiable expression using specific [OptimizationProblemFactory]
|
||||
*/
|
||||
public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.optimizeWith(
|
||||
public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T, Expression<T>>.optimizeWith(
|
||||
factory: OptimizationProblemFactory<T, F>,
|
||||
vararg symbols: Symbol,
|
||||
configuration: F.() -> Unit,
|
||||
@ -88,4 +86,3 @@ public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.op
|
||||
problem.diffExpression(this)
|
||||
return problem.optimize()
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,6 @@
|
||||
plugins { id("ru.mipt.npm.jvm") }
|
||||
plugins {
|
||||
id("ru.mipt.npm.jvm")
|
||||
}
|
||||
|
||||
description = "Binding for https://github.com/JetBrains-Research/viktor"
|
||||
|
||||
|
@ -1,13 +1,11 @@
|
||||
pluginManagement {
|
||||
repositories {
|
||||
mavenLocal()
|
||||
jcenter()
|
||||
gradlePluginPortal()
|
||||
jcenter()
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
|
||||
}
|
||||
|
||||
val toolsVersion = "0.6.4-dev-1.4.20-M2"
|
||||
@ -41,5 +39,6 @@ include(
|
||||
":kmath-geometry",
|
||||
":kmath-ast",
|
||||
":kmath-ejml",
|
||||
":kmath-kotlingrad",
|
||||
":examples"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user
I do not like jitpack dependency here as well as other custom dependencies. You'd better move it to specific projects and explicitly list them in README both for the module and for examples.