Merge pull request #150 from mipt-npm/kotlingrad

Add adapters of scalar functions to MST and vice versa
This commit is contained in:
Iaroslav Postovalov 2020-11-02 01:09:44 +07:00 committed by GitHub
commit abe68a4fb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 508 additions and 155 deletions

View File

@ -211,7 +211,15 @@ Release artifacts are accessible from bintray with following configuration (see
```kotlin ```kotlin
repositories { repositories {
jcenter()
maven("https://clojars.org/repo")
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/mipt-npm/kscience") maven("https://dl.bintray.com/mipt-npm/kscience")
maven("https://jitpack.io")
mavenCentral()
} }
dependencies { dependencies {
@ -228,7 +236,15 @@ Development builds are uploaded to the separate repository:
```kotlin ```kotlin
repositories { repositories {
jcenter()
maven("https://clojars.org/repo")
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://jitpack.io")
mavenCentral()
} }
``` ```

View File

@ -1,3 +1,5 @@
import ru.mipt.npm.gradle.KSciencePublishPlugin
plugins { plugins {
id("ru.mipt.npm.project") id("ru.mipt.npm.project")
} }
@ -9,9 +11,16 @@ internal val githubProject: String by extra("kmath")
allprojects { allprojects {
repositories { repositories {
jcenter() jcenter()
maven("https://clojars.org/repo")
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/hotkeytlt/maven") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://dl.bintray.com/mipt-npm/kscience")
maven("https://jitpack.io")
maven("http://logicrunch.research.it.uu.se/maven/")
mavenCentral()
} }
group = "kscience.kmath" group = "kscience.kmath"
@ -19,7 +28,7 @@ allprojects {
} }
subprojects { subprojects {
if (name.startsWith("kmath")) apply<ru.mipt.npm.gradle.KSciencePublishPlugin>() if (name.startsWith("kmath")) apply<KSciencePublishPlugin>()
} }
readme { readme {

View File

@ -8,18 +8,25 @@ plugins {
} }
allOpen.annotation("org.openjdk.jmh.annotations.State") allOpen.annotation("org.openjdk.jmh.annotations.State")
sourceSets.register("benchmarks")
repositories { repositories {
maven("https://dl.bintray.com/mipt-npm/kscience") jcenter()
maven("https://clojars.org/repo")
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://dl.bintray.com/kotlin/kotlin-dev/") maven("https://dl.bintray.com/mipt-npm/kscience")
maven("https://jitpack.io")
maven("http://logicrunch.research.it.uu.se/maven/")
mavenCentral() mavenCentral()
} }
sourceSets.register("benchmarks")
dependencies { dependencies {
implementation(project(":kmath-ast")) implementation(project(":kmath-ast"))
implementation(project(":kmath-kotlingrad"))
implementation(project(":kmath-core")) implementation(project(":kmath-core"))
implementation(project(":kmath-coroutines")) implementation(project(":kmath-coroutines"))
implementation(project(":kmath-commons")) implementation(project(":kmath-commons"))

View File

@ -9,7 +9,7 @@ import kscience.kmath.operations.RealField
import kotlin.random.Random import kotlin.random.Random
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
class ExpressionsInterpretersBenchmark { internal class ExpressionsInterpretersBenchmark {
private val algebra: Field<Double> = RealField private val algebra: Field<Double> = RealField
fun functionalExpression() { fun functionalExpression() {
val expr = algebra.expressionInField { val expr = algebra.expressionInField {
@ -47,6 +47,16 @@ class ExpressionsInterpretersBenchmark {
} }
} }
/**
* This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and
* core FunctionalExpressions API.
*
* The expected rating is:
*
* 1. ASM.
* 2. MST.
* 3. FE.
*/
fun main() { fun main() {
val benchmark = ExpressionsInterpretersBenchmark() val benchmark = ExpressionsInterpretersBenchmark()

View File

@ -0,0 +1,24 @@
package kscience.kmath.ast
import kscience.kmath.asm.compile
import kscience.kmath.expressions.derivative
import kscience.kmath.expressions.invoke
import kscience.kmath.expressions.symbol
import kscience.kmath.kotlingrad.differentiable
import kscience.kmath.operations.RealField
/**
* In this example, x^2-4*x-44 function is differentiated with Kotlin, and the autodiff result is compared with
* valid derivative.
*/
fun main() {
val x by symbol
val actualDerivative = MstExpression(RealField, "x^2-4*x-44".parseMath())
.differentiable()
.derivative(x)
.compile()
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
}

View File

@ -6,14 +6,14 @@ import kscience.kmath.operations.*
* [Algebra] over [MST] nodes. * [Algebra] over [MST] nodes.
*/ */
public object MstAlgebra : NumericAlgebra<MST> { public object MstAlgebra : NumericAlgebra<MST> {
override fun number(value: Number): MST = MST.Numeric(value) override fun number(value: Number): MST.Numeric = MST.Numeric(value)
override fun symbol(value: String): MST = MST.Symbolic(value) override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
override fun unaryOperation(operation: String, arg: MST): MST = override fun unaryOperation(operation: String, arg: MST): MST.Unary =
MST.Unary(operation, arg) MST.Unary(operation, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST = override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MST.Binary(operation, left, right) MST.Binary(operation, left, right)
} }
@ -21,97 +21,100 @@ public object MstAlgebra : NumericAlgebra<MST> {
* [Space] over [MST] nodes. * [Space] over [MST] nodes.
*/ */
public object MstSpace : Space<MST>, NumericAlgebra<MST> { public object MstSpace : Space<MST>, NumericAlgebra<MST> {
override val zero: MST = number(0.0) override val zero: MST.Numeric by lazy { number(0.0) }
override fun number(value: Number): MST = MstAlgebra.number(value) override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
override fun symbol(value: String): MST = MstAlgebra.symbol(value) override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
override fun binaryOperation(operation: String, left: MST, right: MST): MST = override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MstAlgebra.binaryOperation(operation, left, right) MstAlgebra.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg)
} }
/** /**
* [Ring] over [MST] nodes. * [Ring] over [MST] nodes.
*/ */
public object MstRing : Ring<MST>, NumericAlgebra<MST> { public object MstRing : Ring<MST>, NumericAlgebra<MST> {
override val zero: MST override val zero: MST.Numeric
get() = MstSpace.zero get() = MstSpace.zero
override val one: MST = number(1.0)
override fun number(value: Number): MST = MstSpace.number(value) override val one: MST.Numeric by lazy { number(1.0) }
override fun symbol(value: String): MST = MstSpace.symbol(value)
override fun add(a: MST, b: MST): MST = MstSpace.add(a, b)
override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k) override fun number(value: Number): MST.Numeric = MstSpace.number(value)
override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstSpace.binaryOperation(operation, left, right) MstSpace.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg)
} }
/** /**
* [Field] over [MST] nodes. * [Field] over [MST] nodes.
*/ */
public object MstField : Field<MST> { public object MstField : Field<MST> {
public override val zero: MST public override val zero: MST.Numeric
get() = MstRing.zero get() = MstRing.zero
public override val one: MST public override val one: MST.Numeric
get() = MstRing.one get() = MstRing.one
public override fun symbol(value: String): MST = MstRing.symbol(value) public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value)
public override fun number(value: Number): MST = MstRing.number(value) public override fun number(value: Number): MST.Numeric = MstRing.number(value)
public override fun add(a: MST, b: MST): MST = MstRing.add(a, b) public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k) public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b) public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
public override fun binaryOperation(operation: String, left: MST, right: MST): MST = public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MstRing.binaryOperation(operation, left, right) MstRing.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg)
} }
/** /**
* [ExtendedField] over [MST] nodes. * [ExtendedField] over [MST] nodes.
*/ */
public object MstExtendedField : ExtendedField<MST> { public object MstExtendedField : ExtendedField<MST> {
override val zero: MST override val zero: MST.Numeric
get() = MstField.zero get() = MstField.zero
override val one: MST override val one: MST.Numeric
get() = MstField.one get() = MstField.one
override fun symbol(value: String): MST = MstField.symbol(value) override fun symbol(value: String): MST.Symbolic = MstField.symbol(value)
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) override fun number(value: Number): MST.Numeric = MstField.number(value)
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg) override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg) override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg) override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg) override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg) override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg) override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg) override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
override fun add(a: MST, b: MST): MST = MstField.add(a, b) override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
override fun divide(a: MST, b: MST): MST = MstField.divide(a, b) override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST = override fun power(arg: MST, pow: Number): MST.Binary =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MstField.binaryOperation(operation, left, right) MstField.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg)
} }

View File

@ -13,7 +13,7 @@ import kotlin.contracts.contract
* @property mst the [MST] node. * @property mst the [MST] node.
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> { public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> { private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value) override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
@ -21,8 +21,9 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
override fun binaryOperation(operation: String, left: T, right: T): T = override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right) algebra.binaryOperation(operation, left, right)
override fun number(value: Number): T = if (algebra is NumericAlgebra) @Suppress("UNCHECKED_CAST")
algebra.number(value) override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
(algebra as NumericAlgebra<T>).number(value)
else else
error("Numeric nodes are not supported by $this") error("Numeric nodes are not supported by $this")
} }
@ -38,14 +39,14 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst( public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E, mstAlgebra: E,
block: E.() -> MST, block: E.() -> MST,
): MstExpression<T> = MstExpression(this, mstAlgebra.block()) ): MstExpression<T, A> = MstExpression(this, mstAlgebra.block())
/** /**
* Builds [MstExpression] over [Space]. * Builds [MstExpression] over [Space].
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Space<T>> A.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstSpace.block()) return MstExpression(this, MstSpace.block())
} }
@ -55,7 +56,7 @@ public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MS
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Ring<T>> A.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstRing.block()) return MstExpression(this, MstRing.block())
} }
@ -65,7 +66,7 @@ public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST):
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Field<T>> A.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstField.block()) return MstExpression(this, MstField.block())
} }
@ -75,7 +76,7 @@ public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MS
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : ExtendedField<T>> A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstExtendedField.block()) return MstExpression(this, MstExtendedField.block())
} }
@ -85,7 +86,7 @@ public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtend
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInSpace(block) return algebra.mstInSpace(block)
} }
@ -95,7 +96,7 @@ public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInRing(block) return algebra.mstInRing(block)
} }
@ -105,7 +106,7 @@ public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInField(block) return algebra.mstInField(block)
} }
@ -117,7 +118,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
*/ */
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField( public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
block: MstExtendedField.() -> MST, block: MstExtendedField.() -> MST,
): MstExpression<T> { ): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInExtendedField(block) return algebra.mstInExtendedField(block)
} }

View File

@ -69,4 +69,5 @@ public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<
* *
* @author Alexander Nozik. * @author Alexander Nozik.
*/ */
public inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class.java, algebra) public inline fun <reified T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
mst.compileWith(T::class.java, algebra)

View File

@ -95,10 +95,10 @@ public class DerivativeStructureField(
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> { public companion object :
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> { AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField, Expression<Double>> {
return DerivativeStructureExpression(function) public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double, Expression<Double>> =
} DerivativeStructureExpression(function)
} }
} }
@ -108,7 +108,7 @@ public class DerivativeStructureField(
*/ */
public class DerivativeStructureExpression( public class DerivativeStructureExpression(
public val function: DerivativeStructureField.() -> DerivativeStructure, public val function: DerivativeStructureField.() -> DerivativeStructure,
) : DifferentiableExpression<Double> { ) : DifferentiableExpression<Double, Expression<Double>> {
public override operator fun invoke(arguments: Map<Symbol, Double>): Double = public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
DerivativeStructureField(0, arguments).function().value DerivativeStructureField(0, arguments).function().value

View File

@ -19,9 +19,8 @@ import kotlin.reflect.KClass
public operator fun PointValuePair.component1(): DoubleArray = point public operator fun PointValuePair.component1(): DoubleArray = point
public operator fun PointValuePair.component2(): Double = value public operator fun PointValuePair.component2(): Double = value
public class CMOptimizationProblem( public class CMOptimizationProblem(override val symbols: List<Symbol>, ) :
override val symbols: List<Symbol>, OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
) : OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap() private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE, public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
@ -49,7 +48,7 @@ public class CMOptimizationProblem(
addOptimizationData(objectiveFunction) addOptimizationData(objectiveFunction)
} }
public override fun diffExpression(expression: DifferentiableExpression<Double>): Unit { public override fun diffExpression(expression: DifferentiableExpression<Double, Expression<Double>>) {
expression(expression) expression(expression)
val gradientFunction = ObjectiveFunctionGradient { val gradientFunction = ObjectiveFunctionGradient {
val args = it.toMap() val args = it.toMap()

View File

@ -12,7 +12,6 @@ import kscience.kmath.structures.asBuffer
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
/** /**
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
*/ */
@ -21,7 +20,7 @@ public fun Fitting.chiSquared(
y: Buffer<Double>, y: Buffer<Double>,
yErr: Buffer<Double>, yErr: Buffer<Double>,
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
): DifferentiableExpression<Double> = chiSquared(DerivativeStructureField, x, y, yErr, model) ): DifferentiableExpression<Double, Expression<Double>> = chiSquared(DerivativeStructureField, x, y, yErr, model)
/** /**
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
@ -31,7 +30,7 @@ public fun Fitting.chiSquared(
y: Iterable<Double>, y: Iterable<Double>,
yErr: Iterable<Double>, yErr: Iterable<Double>,
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
): DifferentiableExpression<Double> = chiSquared( ): DifferentiableExpression<Double, Expression<Double>> = chiSquared(
DerivativeStructureField, DerivativeStructureField,
x.toList().asBuffer(), x.toList().asBuffer(),
y.toList().asBuffer(), y.toList().asBuffer(),
@ -39,7 +38,6 @@ public fun Fitting.chiSquared(
model model
) )
/** /**
* Optimize expression without derivatives * Optimize expression without derivatives
*/ */
@ -48,16 +46,15 @@ public fun Expression<Double>.optimize(
configuration: CMOptimizationProblem.() -> Unit, configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) ): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
/** /**
* Optimize differentiable expression * Optimize differentiable expression
*/ */
public fun DifferentiableExpression<Double>.optimize( public fun DifferentiableExpression<Double, Expression<Double>>.optimize(
vararg symbols: Symbol, vararg symbols: Symbol,
configuration: CMOptimizationProblem.() -> Unit, configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) ): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
public fun DifferentiableExpression<Double>.minimize( public fun DifferentiableExpression<Double, Expression<Double>>.minimize(
vararg startPoint: Pair<Symbol, Double>, vararg startPoint: Pair<Symbol, Double>,
configuration: CMOptimizationProblem.() -> Unit = {}, configuration: CMOptimizationProblem.() -> Unit = {},
): OptimizationResult<Double> { ): OptimizationResult<Double> {

View File

@ -47,14 +47,17 @@ internal class OptimizeTest {
val sigma = 1.0 val sigma = 1.0
val generator = Distribution.normal(0.0, sigma) val generator = Distribution.normal(0.0, sigma)
val chain = generator.sample(RandomGenerator.default(112667)) val chain = generator.sample(RandomGenerator.default(112667))
val x = (1..100).map { it.toDouble() } val x = (1..100).map(Int::toDouble)
val y = x.map { it ->
val y = x.map {
it.pow(2) + it + 1 + chain.nextDouble() it.pow(2) + it + 1 + chain.nextDouble()
} }
val yErr = x.map { sigma }
val chi2 = Fitting.chiSquared(x, y, yErr) { x -> val yErr = List(x.size) { sigma }
val chi2 = Fitting.chiSquared(x, y, yErr) { x1 ->
val cWithDefault = bindOrNull(c) ?: one val cWithDefault = bindOrNull(c) ?: one
bind(a) * x.pow(2) + bind(b) * x + cWithDefault bind(a) * x1.pow(2) + bind(b) * x1 + cWithDefault
} }
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0) val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)

View File

@ -1,29 +1,40 @@
package kscience.kmath.expressions package kscience.kmath.expressions
/** /**
* An expression that provides derivatives * Represents expression which structure can be differentiated.
*
* @param T the type this expression takes as argument and returns.
* @param R the type of expression this expression can be differentiated to.
*/ */
public interface DifferentiableExpression<T> : Expression<T> { public interface DifferentiableExpression<T, out R : Expression<T>> : Expression<T> {
public fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? /**
* Differentiates this expression by ordered collection of [symbols].
*
* @param symbols the symbols.
* @return the derivative or `null`.
*/
public fun derivativeOrNull(symbols: List<Symbol>): R?
} }
public fun <T> DifferentiableExpression<T>.derivative(symbols: List<Symbol>): Expression<T> = public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(symbols: List<Symbol>): R =
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided") derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
public fun <T> DifferentiableExpression<T>.derivative(vararg symbols: Symbol): Expression<T> = public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(vararg symbols: Symbol): R =
derivative(symbols.toList()) derivative(symbols.toList())
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(name: String): R =
derivative(StringSymbol(name)) derivative(StringSymbol(name))
/** /**
* A [DifferentiableExpression] that defines only first derivatives * A [DifferentiableExpression] that defines only first derivatives
*/ */
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> { public abstract class FirstDerivativeExpression<T, R : Expression<T>> : DifferentiableExpression<T,R> {
/**
* Returns first derivative of this expression by given [symbol].
*/
public abstract fun derivativeOrNull(symbol: Symbol): R?
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>? public final override fun derivativeOrNull(symbols: List<Symbol>): R? {
public override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? {
val dSymbol = symbols.firstOrNull() ?: return null val dSymbol = symbols.firstOrNull() ?: return null
return derivativeOrNull(dSymbol) return derivativeOrNull(dSymbol)
} }
@ -32,6 +43,6 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
/** /**
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression] * A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
*/ */
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> { public fun interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>, out R : Expression<T>> {
public fun process(function: A.() -> I): DifferentiableExpression<T> public fun process(function: A.() -> I): DifferentiableExpression<T, R>
} }

View File

@ -22,7 +22,9 @@ public inline class StringSymbol(override val identity: String) : Symbol {
} }
/** /**
* An elementary function that could be invoked on a map of arguments * An elementary function that could be invoked on a map of arguments.
*
* @param T the type this expression takes as argument and returns.
*/ */
public fun interface Expression<T> { public fun interface Expression<T> {
/** /**

View File

@ -68,7 +68,7 @@ public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
): DerivationResult<T> { ): DerivationResult<T> {
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
return SimpleAutoDiffField(this, bindings).derivate(body) return SimpleAutoDiffField(this, bindings).differentiate(body)
} }
public fun <T : Any, F : Field<T>> F.simpleAutoDiff( public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
@ -83,12 +83,21 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
public val context: F, public val context: F,
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> { ) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
public override val zero: AutoDiffValue<T>
get() = const(context.zero)
public override val one: AutoDiffValue<T>
get() = const(context.one)
// this stack contains pairs of blocks and values to apply them to // this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8) private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp: Int = 0 private var sp: Int = 0
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf() private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
}
/** /**
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
* with respect to this variable. * with respect to this variable.
@ -106,11 +115,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
override fun hashCode(): Int = identity.hashCode() override fun hashCode(): Int = identity.hashCode()
} }
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate { public override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
}
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
private fun getDerivative(variable: AutoDiffValue<T>): T = private fun getDerivative(variable: AutoDiffValue<T>): T =
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
@ -119,7 +124,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
} }
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
private fun runBackwardPass() { private fun runBackwardPass() {
while (sp > 0) { while (sp > 0) {
@ -129,9 +133,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
} }
} }
override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: AutoDiffValue<T> get() = const(context.one)
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value) override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
/** /**
@ -165,7 +166,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
} }
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> { internal fun differentiate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
val result = function() val result = function()
result.d = context.one // computing derivative w.r.t result result.d = context.one // computing derivative w.r.t result
runBackwardPass() runBackwardPass()
@ -174,41 +175,41 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
// Overloads for Double constants // Overloads for Double constants
override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> = public override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { this@plus.toDouble() * one + b.value }) { z -> derive(const { this@plus.toDouble() * one + b.value }) { z ->
b.d += z.d b.d += z.d
} }
override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this) public override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> = public override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> = public override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
// Basic math (+, -, *, /) // Basic math (+, -, *, /)
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> = public override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value + b.value }) { z -> derive(const { a.value + b.value }) { z ->
a.d += z.d a.d += z.d
b.d += z.d b.d += z.d
} }
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> = public override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value * b.value }) { z -> derive(const { a.value * b.value }) { z ->
a.d += z.d * b.value a.d += z.d * b.value
b.d += z.d * a.value b.d += z.d * a.value
} }
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> = public override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value / b.value }) { z -> derive(const { a.value / b.value }) { z ->
a.d += z.d / b.value a.d += z.d / b.value
b.d -= z.d * a.value / (b.value * b.value) b.d -= z.d * a.value / (b.value * b.value)
} }
override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> = public override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
derive(const { k.toDouble() * a.value }) { z -> derive(const { k.toDouble() * a.value }) { z ->
a.d += z.d * k.toDouble() a.d += z.d * k.toDouble()
} }
@ -220,15 +221,15 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>( public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
public val field: F, public val field: F,
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>, public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
) : FirstDerivativeExpression<T>() { ) : FirstDerivativeExpression<T, Expression<T>>() {
public override operator fun invoke(arguments: Map<Symbol, T>): T { public override operator fun invoke(arguments: Map<Symbol, T>): T {
//val bindings = arguments.entries.map { it.key.bind(it.value) } //val bindings = arguments.entries.map { it.key.bind(it.value) }
return SimpleAutoDiffField(field, arguments).function().value return SimpleAutoDiffField(field, arguments).function().value
} }
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments -> public override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
//val bindings = arguments.entries.map { it.key.bind(it.value) } //val bindings = arguments.entries.map { it.key.bind(it.value) }
val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function) val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function)
derivationResult.derivative(symbol) derivationResult.derivative(symbol)
} }
} }
@ -236,13 +237,10 @@ public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
/** /**
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression] * Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
*/ */
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> { public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>, Expression<T>> =
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> { AutoDiffProcessor { function ->
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> { SimpleAutoDiffExpression(field, function)
return SimpleAutoDiffExpression(field, function)
}
} }
}
// Extensions for differentiation of various basic mathematical functions // Extensions for differentiation of various basic mathematical functions
@ -392,4 +390,4 @@ public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> = public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).atanh(arg) (this as SimpleAutoDiffField<T, F>).atanh(arg)
} }

View File

@ -0,0 +1,9 @@
plugins {
id("ru.mipt.npm.jvm")
}
dependencies {
implementation("com.github.breandan:kaliningraph:0.1.2")
implementation("com.github.breandan:kotlingrad:0.3.7")
api(project(":kmath-ast"))
}

View File

@ -0,0 +1,53 @@
package kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.experimental.SFun
import kscience.kmath.ast.MST
import kscience.kmath.ast.MstAlgebra
import kscience.kmath.ast.MstExpression
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Symbol
import kscience.kmath.operations.NumericAlgebra
/**
* Represents wrapper of [MstExpression] implementing [DifferentiableExpression].
*
* The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin, then converting
* [SFun] back to [MST].
*
* @param T the type of number.
* @param A the [NumericAlgebra] of [T].
* @property expr the underlying [MstExpression].
*/
public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpression<T, A>) :
DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T>, T : Number {
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
/**
* The [MstExpression.algebra] of [expr].
*/
public val algebra: A
get() = expr.algebra
/**
* The [MstExpression.mst] of [expr].
*/
public val mst: MST
get() = expr.mst
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
public override fun derivativeOrNull(symbols: List<Symbol>): MstExpression<T, A> = MstExpression(
algebra,
symbols.map(Symbol::identity)
.map(MstAlgebra::symbol)
.map { it.toSVar<KMathNumber<T, A>>() }
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
.toMst(),
)
}
/**
* Wraps this [MstExpression] into [DifferentiableMstExpression].
*/
public fun <T : Number, A : NumericAlgebra<T>> MstExpression<T, A>.differentiable(): DifferentiableMstExpression<T, A> =
DifferentiableMstExpression(this)

View File

@ -0,0 +1,18 @@
package kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.experimental.RealNumber
import edu.umontreal.kotlingrad.experimental.SConst
import kscience.kmath.operations.NumericAlgebra
/**
* Implements [RealNumber] by delegating its functionality to [NumericAlgebra].
*
* @param T the type of number.
* @param A the [NumericAlgebra] of [T].
* @property algebra the algebra.
* @param value the value of this number.
*/
public class KMathNumber<T, A>(public val algebra: A, value: T) :
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
}

View File

@ -0,0 +1,124 @@
package kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.experimental.*
import kscience.kmath.ast.MST
import kscience.kmath.ast.MstAlgebra
import kscience.kmath.ast.MstExtendedField
import kscience.kmath.ast.MstExtendedField.unaryMinus
import kscience.kmath.operations.*
/**
* Maps [SVar] to [MST.Symbolic] directly.
*
* @receiver the variable.
* @return a node.
*/
public fun <X : SFun<X>> SVar<X>.toMst(): MST.Symbolic = MstAlgebra.symbol(name)
/**
* Maps [SVar] to [MST.Numeric] directly.
*
* @receiver the constant.
* @return a node.
*/
public fun <X : SFun<X>> SConst<X>.toMst(): MST.Numeric = MstAlgebra.number(doubleValue)
/**
* Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then.
* [Power] operation is limited to constant right-hand side arguments.
*
* Detailed mapping is:
*
* - [SVar] -> [MstExtendedField.symbol];
* - [SConst] -> [MstExtendedField.number];
* - [Sum] -> [MstExtendedField.add];
* - [Prod] -> [MstExtendedField.multiply];
* - [Power] -> [MstExtendedField.power] (limited to constant exponents only);
* - [Negative] -> [MstExtendedField.unaryMinus];
* - [Log] -> [MstExtendedField.ln] (left) / [MstExtendedField.ln] (right);
* - [Sine] -> [MstExtendedField.sin];
* - [Cosine] -> [MstExtendedField.cos];
* - [Tangent] -> [MstExtendedField.tan];
* - [DProd] is vector operation, and it is requested to be evaluated;
* - [SComposition] is also requested to be evaluated eagerly;
* - [VSumAll] is requested to be evaluated;
* - [Derivative] is requested to be evaluated.
*
* @receiver the scalar function.
* @return a node.
*/
public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField {
when (this@toMst) {
is SVar -> toMst()
is SConst -> toMst()
is Sum -> left.toMst() + right.toMst()
is Prod -> left.toMst() * right.toMst()
is Power -> left.toMst() pow ((right as? SConst<*>)?.doubleValue ?: (right() as SConst<*>).doubleValue)
is Negative -> -input.toMst()
is Log -> ln(left.toMst()) / ln(right.toMst())
is Sine -> sin(input.toMst())
is Cosine -> cos(input.toMst())
is Tangent -> tan(input.toMst())
is DProd -> this@toMst().toMst()
is SComposition -> this@toMst().toMst()
is VSumAll<X, *> -> this@toMst().toMst()
is Derivative -> this@toMst().toMst()
}
}
/**
* Maps [MST.Numeric] to [SConst] directly.
*
* @receiver the node.
* @return a new constant.
*/
public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
/**
* Maps [MST.Symbolic] to [SVar] directly.
*
* @receiver the node.
* @param proto the prototype instance.
* @return a new variable.
*/
internal fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value)
/**
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
*
* Detailed mapping is:
*
* - [MST.Numeric] -> [SConst];
* - [MST.Symbolic] -> [SVar];
* - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log];
* - [MST.Binary] -> [Sum], [Prod], [Power].
*
* @receiver the node.
* @param proto the prototype instance.
* @return a scalar function.
*/
public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
is MST.Numeric -> toSConst()
is MST.Symbolic -> toSVar()
is MST.Unary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> +value.toSFun<X>()
SpaceOperations.MINUS_OPERATION -> -value.toSFun<X>()
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun())
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
else -> error("Unary operation $operation not defined in $this")
}
is MST.Binary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
SpaceOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
else -> error("Binary operation $operation not defined in $this")
}
}

View File

@ -0,0 +1,64 @@
package kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.experimental.*
import kscience.kmath.asm.compile
import kscience.kmath.ast.MstAlgebra
import kscience.kmath.ast.MstExpression
import kscience.kmath.ast.parseMath
import kscience.kmath.expressions.invoke
import kscience.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import kotlin.test.fail
internal class AdaptingTests {
@Test
fun symbol() {
val c1 = MstAlgebra.symbol("x")
assertTrue(c1.toSVar<KMathNumber<Double, RealField>>().name == "x")
val c2 = "kitten".parseMath().toSFun<KMathNumber<Double, RealField>>()
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
}
@Test
fun number() {
val c1 = MstAlgebra.number(12354324)
assertTrue(c1.toSConst<DReal>().doubleValue == 12354324.0)
val c2 = "0.234".parseMath().toSFun<KMathNumber<Double, RealField>>()
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
val c3 = "1e-3".parseMath().toSFun<KMathNumber<Double, RealField>>()
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
}
@Test
fun simpleFunctionShape() {
val linear = "2*x+16".parseMath().toSFun<KMathNumber<Double, RealField>>()
if (linear !is Sum) fail()
if (linear.left !is Prod) fail()
if (linear.right !is SConst) fail()
}
@Test
fun simpleFunctionDerivative() {
val x = MstAlgebra.symbol("x").toSVar<KMathNumber<Double, RealField>>()
val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, RealField>>()
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
}
@Test
fun moreComplexDerivative() {
val x = MstAlgebra.symbol("x").toSVar<KMathNumber<Double, RealField>>()
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, RealField>>()
val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile()
val expectedDerivative = MstExpression(
RealField,
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath()
).compile()
assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1))
}
}

View File

@ -1,4 +1,6 @@
plugins { id("ru.mipt.npm.mpp") } plugins {
id("ru.mipt.npm.mpp")
}
kotlin.sourceSets { kotlin.sourceSets {
commonMain { commonMain {

View File

@ -12,16 +12,18 @@ public object Fitting {
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
*/ */
public fun <T : Any, I : Any, A> chiSquared( public fun <T : Any, I : Any, A> chiSquared(
autoDiff: AutoDiffProcessor<T, I, A>, autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
x: Buffer<T>, x: Buffer<T>,
y: Buffer<T>, y: Buffer<T>,
yErr: Buffer<T>, yErr: Buffer<T>,
model: A.(I) -> I, model: A.(I) -> I,
): DifferentiableExpression<T> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> { ): DifferentiableExpression<T, Expression<T>> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
require(x.size == y.size) { "X and y buffers should be of the same size" } require(x.size == y.size) { "X and y buffers should be of the same size" }
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
return autoDiff.process { return autoDiff.process {
var sum = zero var sum = zero
x.indices.forEach { x.indices.forEach {
val xValue = const(x[it]) val xValue = const(x[it])
val yValue = const(y[it]) val yValue = const(y[it])
@ -29,6 +31,7 @@ public object Fitting {
val modelValue = model(xValue) val modelValue = model(xValue)
sum += ((yValue - modelValue) / yErrValue).pow(2) sum += ((yValue - modelValue) / yErrValue).pow(2)
} }
sum sum
} }
} }
@ -45,6 +48,7 @@ public object Fitting {
): Expression<Double> { ): Expression<Double> {
require(x.size == y.size) { "X and y buffers should be of the same size" } require(x.size == y.size) { "X and y buffers should be of the same size" }
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
return Expression { arguments -> return Expression { arguments ->
x.indices.sumByDouble { x.indices.sumByDouble {
val xValue = x[it] val xValue = x[it]
@ -56,4 +60,4 @@ public object Fitting {
} }
} }
} }
} }

View File

@ -27,17 +27,17 @@ public interface OptimizationProblem<T : Any> {
/** /**
* Define the initial guess for the optimization problem * Define the initial guess for the optimization problem
*/ */
public fun initialGuess(map: Map<Symbol, T>): Unit public fun initialGuess(map: Map<Symbol, T>)
/** /**
* Set an objective function expression * Set an objective function expression
*/ */
public fun expression(expression: Expression<T>): Unit public fun expression(expression: Expression<T>)
/** /**
* Set a differentiable expression as objective function as function and gradient provider * Set a differentiable expression as objective function as function and gradient provider
*/ */
public fun diffExpression(expression: DifferentiableExpression<T>): Unit public fun diffExpression(expression: DifferentiableExpression<T, Expression<T>>)
/** /**
* Update the problem from previous optimization run * Update the problem from previous optimization run
@ -50,9 +50,8 @@ public interface OptimizationProblem<T : Any> {
public fun optimize(): OptimizationResult<T> public fun optimize(): OptimizationResult<T>
} }
public interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> { public fun interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> {
public fun build(symbols: List<Symbol>): P public fun build(symbols: List<Symbol>): P
} }
public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFactory<T, P>.invoke( public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFactory<T, P>.invoke(
@ -60,7 +59,6 @@ public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFac
block: P.() -> Unit, block: P.() -> Unit,
): P = build(symbols).apply(block) ): P = build(symbols).apply(block)
/** /**
* Optimize expression without derivatives using specific [OptimizationProblemFactory] * Optimize expression without derivatives using specific [OptimizationProblemFactory]
*/ */
@ -78,7 +76,7 @@ public fun <T : Any, F : OptimizationProblem<T>> Expression<T>.optimizeWith(
/** /**
* Optimize differentiable expression using specific [OptimizationProblemFactory] * Optimize differentiable expression using specific [OptimizationProblemFactory]
*/ */
public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.optimizeWith( public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T, Expression<T>>.optimizeWith(
factory: OptimizationProblemFactory<T, F>, factory: OptimizationProblemFactory<T, F>,
vararg symbols: Symbol, vararg symbols: Symbol,
configuration: F.() -> Unit, configuration: F.() -> Unit,
@ -88,4 +86,3 @@ public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.op
problem.diffExpression(this) problem.diffExpression(this)
return problem.optimize() return problem.optimize()
} }

View File

@ -1,4 +1,6 @@
plugins { id("ru.mipt.npm.jvm") } plugins {
id("ru.mipt.npm.jvm")
}
description = "Binding for https://github.com/JetBrains-Research/viktor" description = "Binding for https://github.com/JetBrains-Research/viktor"

View File

@ -1,13 +1,11 @@
pluginManagement { pluginManagement {
repositories { repositories {
mavenLocal()
jcenter()
gradlePluginPortal() gradlePluginPortal()
jcenter()
maven("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("https://dl.bintray.com/mipt-npm/kscience") maven("https://dl.bintray.com/mipt-npm/kscience")
maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
} }
val toolsVersion = "0.6.4-dev-1.4.20-M2" val toolsVersion = "0.6.4-dev-1.4.20-M2"
@ -41,5 +39,6 @@ include(
":kmath-geometry", ":kmath-geometry",
":kmath-ast", ":kmath-ast",
":kmath-ejml", ":kmath-ejml",
":kmath-kotlingrad",
":examples" ":examples"
) )