Merge remote-tracking branch 'origin/dev' into dev

This commit is contained in:
Alexander Nozik 2020-11-22 19:06:50 +03:00
commit 1538bc0e69
27 changed files with 668 additions and 225 deletions

View File

@ -211,7 +211,15 @@ Release artifacts are accessible from bintray with following configuration (see
```kotlin ```kotlin
repositories { repositories {
jcenter()
maven("https://clojars.org/repo")
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/mipt-npm/kscience") maven("https://dl.bintray.com/mipt-npm/kscience")
maven("https://jitpack.io")
mavenCentral()
} }
dependencies { dependencies {
@ -228,7 +236,15 @@ Development builds are uploaded to the separate repository:
```kotlin ```kotlin
repositories { repositories {
jcenter()
maven("https://clojars.org/repo")
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://jitpack.io")
mavenCentral()
} }
``` ```

View File

@ -1,3 +1,5 @@
import ru.mipt.npm.gradle.KSciencePublishPlugin
plugins { plugins {
id("ru.mipt.npm.project") id("ru.mipt.npm.project")
} }
@ -9,9 +11,16 @@ internal val githubProject: String by extra("kmath")
allprojects { allprojects {
repositories { repositories {
jcenter() jcenter()
maven("https://clojars.org/repo")
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/hotkeytlt/maven") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://dl.bintray.com/mipt-npm/kscience")
maven("https://jitpack.io")
maven("http://logicrunch.research.it.uu.se/maven/")
mavenCentral()
} }
group = "kscience.kmath" group = "kscience.kmath"
@ -19,7 +28,7 @@ allprojects {
} }
subprojects { subprojects {
if (name.startsWith("kmath")) apply<ru.mipt.npm.gradle.KSciencePublishPlugin>() if (name.startsWith("kmath")) apply<KSciencePublishPlugin>()
} }
readme { readme {

View File

@ -8,18 +8,25 @@ plugins {
} }
allOpen.annotation("org.openjdk.jmh.annotations.State") allOpen.annotation("org.openjdk.jmh.annotations.State")
sourceSets.register("benchmarks")
repositories { repositories {
maven("https://dl.bintray.com/mipt-npm/kscience") jcenter()
maven("https://clojars.org/repo")
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
maven("https://dl.bintray.com/hotkeytlt/maven")
maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://dl.bintray.com/kotlin/kotlin-dev/") maven("https://dl.bintray.com/mipt-npm/kscience")
maven("https://jitpack.io")
maven("http://logicrunch.research.it.uu.se/maven/")
mavenCentral() mavenCentral()
} }
sourceSets.register("benchmarks")
dependencies { dependencies {
// implementation(project(":kmath-ast")) implementation(project(":kmath-ast"))
implementation(project(":kmath-kotlingrad"))
implementation(project(":kmath-core")) implementation(project(":kmath-core"))
implementation(project(":kmath-coroutines")) implementation(project(":kmath-coroutines"))
implementation(project(":kmath-commons")) implementation(project(":kmath-commons"))
@ -27,6 +34,20 @@ dependencies {
implementation(project(":kmath-viktor")) implementation(project(":kmath-viktor"))
implementation(project(":kmath-dimensions")) implementation(project(":kmath-dimensions"))
implementation(project(":kmath-ejml")) implementation(project(":kmath-ejml"))
implementation(project(":kmath-nd4j"))
implementation("org.deeplearning4j:deeplearning4j-core:1.0.0-beta7")
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
// uncomment if your system supports AVX2
// val os = System.getProperty("os.name")
//
// if (System.getProperty("os.arch") in arrayOf("x86_64", "amd64")) when {
// os.startsWith("Windows") -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:windows-x86_64-avx2")
// os == "Linux" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:linux-x86_64-avx2")
// os == "Mac OS X" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:macosx-x86_64-avx2")
// } else
implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11") implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11")
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20") implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
implementation("org.slf4j:slf4j-simple:1.7.30") implementation("org.slf4j:slf4j-simple:1.7.30")
@ -55,4 +76,6 @@ kotlin.sourceSets.all {
} }
} }
tasks.withType<KotlinCompile> { kotlinOptions.jvmTarget = "11" } tasks.withType<KotlinCompile> {
kotlinOptions.jvmTarget = "11"
}

View File

@ -1,70 +1,80 @@
package kscience.kmath.ast package kscience.kmath.ast
//
//import kscience.kmath.asm.compile import kscience.kmath.asm.compile
//import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Expression
//import kscience.kmath.expressions.expressionInField import kscience.kmath.expressions.expressionInField
//import kscience.kmath.expressions.invoke import kscience.kmath.expressions.invoke
//import kscience.kmath.operations.Field import kscience.kmath.operations.Field
//import kscience.kmath.operations.RealField import kscience.kmath.operations.RealField
//import kotlin.random.Random import kotlin.random.Random
//import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
//
//class ExpressionsInterpretersBenchmark { internal class ExpressionsInterpretersBenchmark {
// private val algebra: Field<Double> = RealField private val algebra: Field<Double> = RealField
// fun functionalExpression() { fun functionalExpression() {
// val expr = algebra.expressionInField { val expr = algebra.expressionInField {
// variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0)
// } }
//
// invokeAndSum(expr) invokeAndSum(expr)
// } }
//
// fun mstExpression() { fun mstExpression() {
// val expr = algebra.mstInField { val expr = algebra.mstInField {
// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
// } }
//
// invokeAndSum(expr) invokeAndSum(expr)
// } }
//
// fun asmExpression() { fun asmExpression() {
// val expr = algebra.mstInField { val expr = algebra.mstInField {
// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
// }.compile() }.compile()
//
// invokeAndSum(expr) invokeAndSum(expr)
// } }
//
// private fun invokeAndSum(expr: Expression<Double>) { private fun invokeAndSum(expr: Expression<Double>) {
// val random = Random(0) val random = Random(0)
// var sum = 0.0 var sum = 0.0
//
// repeat(1000000) { repeat(1000000) {
// sum += expr("x" to random.nextDouble()) sum += expr("x" to random.nextDouble())
// } }
//
// println(sum) println(sum)
// } }
//} }
//
//fun main() { /**
// val benchmark = ExpressionsInterpretersBenchmark() * This benchmark compares basically evaluation of simple function with MstExpression interpreter, ASM backend and
// * core FunctionalExpressions API.
// val fe = measureTimeMillis { *
// benchmark.functionalExpression() * The expected rating is:
// } *
// * 1. ASM.
// println("fe=$fe") * 2. MST.
// * 3. FE.
// val mst = measureTimeMillis { */
// benchmark.mstExpression() fun main() {
// } val benchmark = ExpressionsInterpretersBenchmark()
//
// println("mst=$mst") val fe = measureTimeMillis {
// benchmark.functionalExpression()
// val asm = measureTimeMillis { }
// benchmark.asmExpression()
// } println("fe=$fe")
//
// println("asm=$asm") val mst = measureTimeMillis {
//} benchmark.mstExpression()
}
println("mst=$mst")
val asm = measureTimeMillis {
benchmark.asmExpression()
}
println("asm=$asm")
}

View File

@ -0,0 +1,24 @@
package kscience.kmath.ast
import kscience.kmath.asm.compile
import kscience.kmath.expressions.derivative
import kscience.kmath.expressions.invoke
import kscience.kmath.expressions.symbol
import kscience.kmath.kotlingrad.differentiable
import kscience.kmath.operations.RealField
/**
* In this example, x^2-4*x-44 function is differentiated with Kotlin, and the autodiff result is compared with
* valid derivative.
*/
fun main() {
val x by symbol
val actualDerivative = MstExpression(RealField, "x^2-4*x-44".parseMath())
.differentiable()
.derivative(x)
.compile()
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
assert(actualDerivative("x" to 123.0) == expectedDerivative("x" to 123.0))
}

View File

@ -1,8 +1,10 @@
package kscience.kmath.structures package kscience.kmath.structures
import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.GlobalScope
import kscience.kmath.nd4j.Nd4jArrayField
import kscience.kmath.operations.RealField import kscience.kmath.operations.RealField
import kscience.kmath.operations.invoke import kscience.kmath.operations.invoke
import org.nd4j.linalg.factory.Nd4j
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
@ -14,6 +16,8 @@ internal inline fun measureAndPrint(title: String, block: () -> Unit) {
} }
fun main() { fun main() {
// initializing Nd4j
Nd4j.zeros(0)
val dim = 1000 val dim = 1000
val n = 1000 val n = 1000
@ -23,6 +27,8 @@ fun main() {
val specializedField = NDField.real(dim, dim) val specializedField = NDField.real(dim, dim)
//A generic boxing field. It should be used for objects, not primitives. //A generic boxing field. It should be used for objects, not primitives.
val genericField = NDField.boxing(RealField, dim, dim) val genericField = NDField.boxing(RealField, dim, dim)
// Nd4j specialized field.
val nd4jField = Nd4jArrayField.real(dim, dim)
measureAndPrint("Automatic field addition") { measureAndPrint("Automatic field addition") {
autoField { autoField {
@ -43,6 +49,13 @@ fun main() {
} }
} }
measureAndPrint("Nd4j specialized addition") {
nd4jField {
var res = one
repeat(n) { res += 1.0 as Number }
}
}
measureAndPrint("Lazy addition") { measureAndPrint("Lazy addition") {
val res = specializedField.one.mapAsync(GlobalScope) { val res = specializedField.one.mapAsync(GlobalScope) {
var c = 0.0 var c = 0.0

View File

@ -6,14 +6,14 @@ import kscience.kmath.operations.*
* [Algebra] over [MST] nodes. * [Algebra] over [MST] nodes.
*/ */
public object MstAlgebra : NumericAlgebra<MST> { public object MstAlgebra : NumericAlgebra<MST> {
override fun number(value: Number): MST = MST.Numeric(value) override fun number(value: Number): MST.Numeric = MST.Numeric(value)
override fun symbol(value: String): MST = MST.Symbolic(value) override fun symbol(value: String): MST.Symbolic = MST.Symbolic(value)
override fun unaryOperation(operation: String, arg: MST): MST = override fun unaryOperation(operation: String, arg: MST): MST.Unary =
MST.Unary(operation, arg) MST.Unary(operation, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST = override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MST.Binary(operation, left, right) MST.Binary(operation, left, right)
} }
@ -21,97 +21,100 @@ public object MstAlgebra : NumericAlgebra<MST> {
* [Space] over [MST] nodes. * [Space] over [MST] nodes.
*/ */
public object MstSpace : Space<MST>, NumericAlgebra<MST> { public object MstSpace : Space<MST>, NumericAlgebra<MST> {
override val zero: MST = number(0.0) override val zero: MST.Numeric by lazy { number(0.0) }
override fun number(value: Number): MST = MstAlgebra.number(value) override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
override fun symbol(value: String): MST = MstAlgebra.symbol(value) override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b) override fun add(a: MST, b: MST): MST.Binary = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k)) override fun multiply(a: MST, k: Number): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
override fun binaryOperation(operation: String, left: MST, right: MST): MST = override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MstAlgebra.binaryOperation(operation, left, right) MstAlgebra.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstAlgebra.unaryOperation(operation, arg)
} }
/** /**
* [Ring] over [MST] nodes. * [Ring] over [MST] nodes.
*/ */
public object MstRing : Ring<MST>, NumericAlgebra<MST> { public object MstRing : Ring<MST>, NumericAlgebra<MST> {
override val zero: MST override val zero: MST.Numeric
get() = MstSpace.zero get() = MstSpace.zero
override val one: MST = number(1.0)
override fun number(value: Number): MST = MstSpace.number(value) override val one: MST.Numeric by lazy { number(1.0) }
override fun symbol(value: String): MST = MstSpace.symbol(value)
override fun add(a: MST, b: MST): MST = MstSpace.add(a, b)
override fun multiply(a: MST, k: Number): MST = MstSpace.multiply(a, k) override fun number(value: Number): MST.Numeric = MstSpace.number(value)
override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
override fun multiply(a: MST, b: MST): MST.Binary = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b) override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
override fun binaryOperation(operation: String, left: MST, right: MST): MST =
MstSpace.binaryOperation(operation, left, right) MstSpace.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstAlgebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstSpace.unaryOperation(operation, arg)
} }
/** /**
* [Field] over [MST] nodes. * [Field] over [MST] nodes.
*/ */
public object MstField : Field<MST> { public object MstField : Field<MST> {
public override val zero: MST public override val zero: MST.Numeric
get() = MstRing.zero get() = MstRing.zero
public override val one: MST public override val one: MST.Numeric
get() = MstRing.one get() = MstRing.one
public override fun symbol(value: String): MST = MstRing.symbol(value) public override fun symbol(value: String): MST.Symbolic = MstRing.symbol(value)
public override fun number(value: Number): MST = MstRing.number(value) public override fun number(value: Number): MST.Numeric = MstRing.number(value)
public override fun add(a: MST, b: MST): MST = MstRing.add(a, b) public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
public override fun multiply(a: MST, k: Number): MST = MstRing.multiply(a, k) public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
public override fun multiply(a: MST, b: MST): MST = MstRing.multiply(a, b) public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
public override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b) public override fun divide(a: MST, b: MST): MST.Binary = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
public override fun binaryOperation(operation: String, left: MST, right: MST): MST = public override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MstRing.binaryOperation(operation, left, right) MstRing.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstRing.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstRing.unaryOperation(operation, arg)
} }
/** /**
* [ExtendedField] over [MST] nodes. * [ExtendedField] over [MST] nodes.
*/ */
public object MstExtendedField : ExtendedField<MST> { public object MstExtendedField : ExtendedField<MST> {
override val zero: MST override val zero: MST.Numeric
get() = MstField.zero get() = MstField.zero
override val one: MST override val one: MST.Numeric
get() = MstField.one get() = MstField.one
override fun symbol(value: String): MST = MstField.symbol(value) override fun symbol(value: String): MST.Symbolic = MstField.symbol(value)
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) override fun number(value: Number): MST.Numeric = MstField.number(value)
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) override fun sin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
override fun tan(arg: MST): MST = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg) override fun cos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg) override fun tan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.TAN_OPERATION, arg)
override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg) override fun asin(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg) override fun acos(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
override fun sinh(arg: MST): MST = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg) override fun atan(arg: MST): MST.Unary = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
override fun cosh(arg: MST): MST = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg) override fun sinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.SINH_OPERATION, arg)
override fun tanh(arg: MST): MST = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg) override fun cosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.COSH_OPERATION, arg)
override fun asinh(arg: MST): MST = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg) override fun tanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.TANH_OPERATION, arg)
override fun acosh(arg: MST): MST = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg) override fun asinh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ASINH_OPERATION, arg)
override fun atanh(arg: MST): MST = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg) override fun acosh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ACOSH_OPERATION, arg)
override fun add(a: MST, b: MST): MST = MstField.add(a, b) override fun atanh(arg: MST): MST.Unary = unaryOperation(HyperbolicOperations.ATANH_OPERATION, arg)
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) override fun multiply(a: MST, k: Number): MST.Binary = MstField.multiply(a, k)
override fun divide(a: MST, b: MST): MST = MstField.divide(a, b) override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
override fun power(arg: MST, pow: Number): MST = binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
override fun exp(arg: MST): MST = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
override fun ln(arg: MST): MST = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST = override fun power(arg: MST, pow: Number): MST.Binary =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))
override fun exp(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.EXP_OPERATION, arg)
override fun ln(arg: MST): MST.Unary = unaryOperation(ExponentialOperations.LN_OPERATION, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST.Binary =
MstField.binaryOperation(operation, left, right) MstField.binaryOperation(operation, left, right)
override fun unaryOperation(operation: String, arg: MST): MST = MstField.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: MST): MST.Unary = MstField.unaryOperation(operation, arg)
} }

View File

@ -13,7 +13,7 @@ import kotlin.contracts.contract
* @property mst the [MST] node. * @property mst the [MST] node.
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MST) : Expression<T> { public class MstExpression<T, out A : Algebra<T>>(public val algebra: A, public val mst: MST) : Expression<T> {
private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> { private inner class InnerAlgebra(val arguments: Map<Symbol, T>) : NumericAlgebra<T> {
override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value) override fun symbol(value: String): T = arguments[StringSymbol(value)] ?: algebra.symbol(value)
override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg) override fun unaryOperation(operation: String, arg: T): T = algebra.unaryOperation(operation, arg)
@ -21,8 +21,9 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
override fun binaryOperation(operation: String, left: T, right: T): T = override fun binaryOperation(operation: String, left: T, right: T): T =
algebra.binaryOperation(operation, left, right) algebra.binaryOperation(operation, left, right)
override fun number(value: Number): T = if (algebra is NumericAlgebra) @Suppress("UNCHECKED_CAST")
algebra.number(value) override fun number(value: Number): T = if (algebra is NumericAlgebra<*>)
(algebra as NumericAlgebra<T>).number(value)
else else
error("Numeric nodes are not supported by $this") error("Numeric nodes are not supported by $this")
} }
@ -38,14 +39,14 @@ public class MstExpression<T>(public val algebra: Algebra<T>, public val mst: MS
public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst( public inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.mst(
mstAlgebra: E, mstAlgebra: E,
block: E.() -> MST, block: E.() -> MST,
): MstExpression<T> = MstExpression(this, mstAlgebra.block()) ): MstExpression<T, A> = MstExpression(this, mstAlgebra.block())
/** /**
* Builds [MstExpression] over [Space]. * Builds [MstExpression] over [Space].
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Space<T>> A.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstSpace.block()) return MstExpression(this, MstSpace.block())
} }
@ -55,7 +56,7 @@ public inline fun <reified T : Any> Space<T>.mstInSpace(block: MstSpace.() -> MS
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Ring<T>> A.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstRing.block()) return MstExpression(this, MstRing.block())
} }
@ -65,7 +66,7 @@ public inline fun <reified T : Any> Ring<T>.mstInRing(block: MstRing.() -> MST):
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Field<T>> A.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstField.block()) return MstExpression(this, MstField.block())
} }
@ -75,7 +76,7 @@ public inline fun <reified T : Any> Field<T>.mstInField(block: MstField.() -> MS
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : ExtendedField<T>> A.mstInExtendedField(block: MstExtendedField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return MstExpression(this, MstExtendedField.block()) return MstExpression(this, MstExtendedField.block())
} }
@ -85,7 +86,7 @@ public inline fun <reified T : Any> Field<T>.mstInExtendedField(block: MstExtend
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A>.mstInSpace(block: MstSpace.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInSpace(block) return algebra.mstInSpace(block)
} }
@ -95,7 +96,7 @@ public inline fun <reified T : Any, A : Space<T>> FunctionalExpressionSpace<T, A
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.mstInRing(block: MstRing.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInRing(block) return algebra.mstInRing(block)
} }
@ -105,7 +106,7 @@ public inline fun <reified T : Any, A : Ring<T>> FunctionalExpressionRing<T, A>.
* *
* @author Alexander Nozik * @author Alexander Nozik
*/ */
public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T> { public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A>.mstInField(block: MstField.() -> MST): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInField(block) return algebra.mstInField(block)
} }
@ -117,7 +118,7 @@ public inline fun <reified T : Any, A : Field<T>> FunctionalExpressionField<T, A
*/ */
public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField( public inline fun <reified T : Any, A : ExtendedField<T>> FunctionalExpressionExtendedField<T, A>.mstInExtendedField(
block: MstExtendedField.() -> MST, block: MstExtendedField.() -> MST,
): MstExpression<T> { ): MstExpression<T, A> {
contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(block, InvocationKind.EXACTLY_ONCE) }
return algebra.mstInExtendedField(block) return algebra.mstInExtendedField(block)
} }

View File

@ -69,4 +69,5 @@ public inline fun <reified T : Any> Algebra<T>.expression(mst: MST): Expression<
* *
* @author Alexander Nozik. * @author Alexander Nozik.
*/ */
public inline fun <reified T : Any> MstExpression<T>.compile(): Expression<T> = mst.compileWith(T::class.java, algebra) public inline fun <reified T : Any> MstExpression<T, Algebra<T>>.compile(): Expression<T> =
mst.compileWith(T::class.java, algebra)

View File

@ -95,10 +95,10 @@ public class DerivativeStructureField(
public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this public override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this public override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
public companion object : AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField> { public companion object :
override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double> { AutoDiffProcessor<Double, DerivativeStructure, DerivativeStructureField, Expression<Double>> {
return DerivativeStructureExpression(function) public override fun process(function: DerivativeStructureField.() -> DerivativeStructure): DifferentiableExpression<Double, Expression<Double>> =
} DerivativeStructureExpression(function)
} }
} }
@ -108,7 +108,7 @@ public class DerivativeStructureField(
*/ */
public class DerivativeStructureExpression( public class DerivativeStructureExpression(
public val function: DerivativeStructureField.() -> DerivativeStructure, public val function: DerivativeStructureField.() -> DerivativeStructure,
) : DifferentiableExpression<Double> { ) : DifferentiableExpression<Double, Expression<Double>> {
public override operator fun invoke(arguments: Map<Symbol, Double>): Double = public override operator fun invoke(arguments: Map<Symbol, Double>): Double =
DerivativeStructureField(0, arguments).function().value DerivativeStructureField(0, arguments).function().value

View File

@ -19,9 +19,8 @@ import kotlin.reflect.KClass
public operator fun PointValuePair.component1(): DoubleArray = point public operator fun PointValuePair.component1(): DoubleArray = point
public operator fun PointValuePair.component2(): Double = value public operator fun PointValuePair.component2(): Double = value
public class CMOptimizationProblem( public class CMOptimizationProblem(override val symbols: List<Symbol>, ) :
override val symbols: List<Symbol>, OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
) : OptimizationProblem<Double>, SymbolIndexer, OptimizationFeature {
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap() private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null private var optimizatorBuilder: (() -> MultivariateOptimizer)? = null
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE, public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(DEFAULT_RELATIVE_TOLERANCE,
@ -49,7 +48,7 @@ public class CMOptimizationProblem(
addOptimizationData(objectiveFunction) addOptimizationData(objectiveFunction)
} }
public override fun diffExpression(expression: DifferentiableExpression<Double>): Unit { public override fun diffExpression(expression: DifferentiableExpression<Double, Expression<Double>>) {
expression(expression) expression(expression)
val gradientFunction = ObjectiveFunctionGradient { val gradientFunction = ObjectiveFunctionGradient {
val args = it.toMap() val args = it.toMap()

View File

@ -12,7 +12,6 @@ import kscience.kmath.structures.asBuffer
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
/** /**
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
*/ */
@ -21,7 +20,7 @@ public fun Fitting.chiSquared(
y: Buffer<Double>, y: Buffer<Double>,
yErr: Buffer<Double>, yErr: Buffer<Double>,
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
): DifferentiableExpression<Double> = chiSquared(DerivativeStructureField, x, y, yErr, model) ): DifferentiableExpression<Double, Expression<Double>> = chiSquared(DerivativeStructureField, x, y, yErr, model)
/** /**
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
@ -31,7 +30,7 @@ public fun Fitting.chiSquared(
y: Iterable<Double>, y: Iterable<Double>,
yErr: Iterable<Double>, yErr: Iterable<Double>,
model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure, model: DerivativeStructureField.(x: DerivativeStructure) -> DerivativeStructure,
): DifferentiableExpression<Double> = chiSquared( ): DifferentiableExpression<Double, Expression<Double>> = chiSquared(
DerivativeStructureField, DerivativeStructureField,
x.toList().asBuffer(), x.toList().asBuffer(),
y.toList().asBuffer(), y.toList().asBuffer(),
@ -39,7 +38,6 @@ public fun Fitting.chiSquared(
model model
) )
/** /**
* Optimize expression without derivatives * Optimize expression without derivatives
*/ */
@ -48,16 +46,15 @@ public fun Expression<Double>.optimize(
configuration: CMOptimizationProblem.() -> Unit, configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) ): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
/** /**
* Optimize differentiable expression * Optimize differentiable expression
*/ */
public fun DifferentiableExpression<Double>.optimize( public fun DifferentiableExpression<Double, Expression<Double>>.optimize(
vararg symbols: Symbol, vararg symbols: Symbol,
configuration: CMOptimizationProblem.() -> Unit, configuration: CMOptimizationProblem.() -> Unit,
): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration) ): OptimizationResult<Double> = optimizeWith(CMOptimizationProblem, symbols = symbols, configuration)
public fun DifferentiableExpression<Double>.minimize( public fun DifferentiableExpression<Double, Expression<Double>>.minimize(
vararg startPoint: Pair<Symbol, Double>, vararg startPoint: Pair<Symbol, Double>,
configuration: CMOptimizationProblem.() -> Unit = {}, configuration: CMOptimizationProblem.() -> Unit = {},
): OptimizationResult<Double> { ): OptimizationResult<Double> {

View File

@ -47,14 +47,17 @@ internal class OptimizeTest {
val sigma = 1.0 val sigma = 1.0
val generator = Distribution.normal(0.0, sigma) val generator = Distribution.normal(0.0, sigma)
val chain = generator.sample(RandomGenerator.default(112667)) val chain = generator.sample(RandomGenerator.default(112667))
val x = (1..100).map { it.toDouble() } val x = (1..100).map(Int::toDouble)
val y = x.map { it ->
val y = x.map {
it.pow(2) + it + 1 + chain.nextDouble() it.pow(2) + it + 1 + chain.nextDouble()
} }
val yErr = x.map { sigma }
val chi2 = Fitting.chiSquared(x, y, yErr) { x -> val yErr = List(x.size) { sigma }
val chi2 = Fitting.chiSquared(x, y, yErr) { x1 ->
val cWithDefault = bindOrNull(c) ?: one val cWithDefault = bindOrNull(c) ?: one
bind(a) * x.pow(2) + bind(b) * x + cWithDefault bind(a) * x1.pow(2) + bind(b) * x1 + cWithDefault
} }
val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0) val result = chi2.minimize(a to 1.5, b to 0.9, c to 1.0)

View File

@ -1,29 +1,40 @@
package kscience.kmath.expressions package kscience.kmath.expressions
/** /**
* An expression that provides derivatives * Represents expression which structure can be differentiated.
*
* @param T the type this expression takes as argument and returns.
* @param R the type of expression this expression can be differentiated to.
*/ */
public interface DifferentiableExpression<T> : Expression<T> { public interface DifferentiableExpression<T, out R : Expression<T>> : Expression<T> {
public fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? /**
* Differentiates this expression by ordered collection of [symbols].
*
* @param symbols the symbols.
* @return the derivative or `null`.
*/
public fun derivativeOrNull(symbols: List<Symbol>): R?
} }
public fun <T> DifferentiableExpression<T>.derivative(symbols: List<Symbol>): Expression<T> = public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(symbols: List<Symbol>): R =
derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided") derivativeOrNull(symbols) ?: error("Derivative by symbols $symbols not provided")
public fun <T> DifferentiableExpression<T>.derivative(vararg symbols: Symbol): Expression<T> = public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(vararg symbols: Symbol): R =
derivative(symbols.toList()) derivative(symbols.toList())
public fun <T> DifferentiableExpression<T>.derivative(name: String): Expression<T> = public fun <T, R : Expression<T>> DifferentiableExpression<T, R>.derivative(name: String): R =
derivative(StringSymbol(name)) derivative(StringSymbol(name))
/** /**
* A [DifferentiableExpression] that defines only first derivatives * A [DifferentiableExpression] that defines only first derivatives
*/ */
public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T> { public abstract class FirstDerivativeExpression<T, R : Expression<T>> : DifferentiableExpression<T,R> {
/**
* Returns first derivative of this expression by given [symbol].
*/
public abstract fun derivativeOrNull(symbol: Symbol): R?
public abstract fun derivativeOrNull(symbol: Symbol): Expression<T>? public final override fun derivativeOrNull(symbols: List<Symbol>): R? {
public override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? {
val dSymbol = symbols.firstOrNull() ?: return null val dSymbol = symbols.firstOrNull() ?: return null
return derivativeOrNull(dSymbol) return derivativeOrNull(dSymbol)
} }
@ -32,6 +43,6 @@ public abstract class FirstDerivativeExpression<T> : DifferentiableExpression<T>
/** /**
* A factory that converts an expression in autodiff variables to a [DifferentiableExpression] * A factory that converts an expression in autodiff variables to a [DifferentiableExpression]
*/ */
public interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>> { public fun interface AutoDiffProcessor<T : Any, I : Any, A : ExpressionAlgebra<T, I>, out R : Expression<T>> {
public fun process(function: A.() -> I): DifferentiableExpression<T> public fun process(function: A.() -> I): DifferentiableExpression<T, R>
} }

View File

@ -22,7 +22,9 @@ public inline class StringSymbol(override val identity: String) : Symbol {
} }
/** /**
* An elementary function that could be invoked on a map of arguments * An elementary function that could be invoked on a map of arguments.
*
* @param T the type this expression takes as argument and returns.
*/ */
public fun interface Expression<T> { public fun interface Expression<T> {
/** /**

View File

@ -68,7 +68,7 @@ public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
): DerivationResult<T> { ): DerivationResult<T> {
contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(body, InvocationKind.EXACTLY_ONCE) }
return SimpleAutoDiffField(this, bindings).derivate(body) return SimpleAutoDiffField(this, bindings).differentiate(body)
} }
public fun <T : Any, F : Field<T>> F.simpleAutoDiff( public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
@ -83,12 +83,21 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
public val context: F, public val context: F,
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> { ) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
public override val zero: AutoDiffValue<T>
get() = const(context.zero)
public override val one: AutoDiffValue<T>
get() = const(context.one)
// this stack contains pairs of blocks and values to apply them to // this stack contains pairs of blocks and values to apply them to
private var stack: Array<Any?> = arrayOfNulls<Any?>(8) private var stack: Array<Any?> = arrayOfNulls<Any?>(8)
private var sp: Int = 0 private var sp: Int = 0
private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf() private val derivatives: MutableMap<AutoDiffValue<T>, T> = hashMapOf()
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate {
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
}
/** /**
* Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result * Differentiable variable with value and derivative of differentiation ([simpleAutoDiff]) result
* with respect to this variable. * with respect to this variable.
@ -106,11 +115,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
override fun hashCode(): Int = identity.hashCode() override fun hashCode(): Int = identity.hashCode()
} }
private val bindings: Map<String, AutoDiffVariableWithDerivative<T>> = bindings.entries.associate { public override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
it.key.identity to AutoDiffVariableWithDerivative(it.key.identity, it.value, context.zero)
}
override fun bindOrNull(symbol: Symbol): AutoDiffValue<T>? = bindings[symbol.identity]
private fun getDerivative(variable: AutoDiffValue<T>): T = private fun getDerivative(variable: AutoDiffValue<T>): T =
(variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero (variable as? AutoDiffVariableWithDerivative)?.d ?: derivatives[variable] ?: context.zero
@ -119,7 +124,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value if (variable is AutoDiffVariableWithDerivative) variable.d = value else derivatives[variable] = value
} }
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
private fun runBackwardPass() { private fun runBackwardPass() {
while (sp > 0) { while (sp > 0) {
@ -129,9 +133,6 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
} }
} }
override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: AutoDiffValue<T> get() = const(context.one)
override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value) override fun const(value: T): AutoDiffValue<T> = AutoDiffValue(value)
/** /**
@ -165,7 +166,7 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
} }
internal fun derivate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> { internal fun differentiate(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DerivationResult<T> {
val result = function() val result = function()
result.d = context.one // computing derivative w.r.t result result.d = context.one // computing derivative w.r.t result
runBackwardPass() runBackwardPass()
@ -174,41 +175,41 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
// Overloads for Double constants // Overloads for Double constants
override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> = public override operator fun Number.plus(b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { this@plus.toDouble() * one + b.value }) { z -> derive(const { this@plus.toDouble() * one + b.value }) { z ->
b.d += z.d b.d += z.d
} }
override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this) public override operator fun AutoDiffValue<T>.plus(b: Number): AutoDiffValue<T> = b.plus(this)
override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> = public override operator fun Number.minus(b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d } derive(const { this@minus.toDouble() * one - b.value }) { z -> b.d -= z.d }
override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> = public override operator fun AutoDiffValue<T>.minus(b: Number): AutoDiffValue<T> =
derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d } derive(const { this@minus.value - one * b.toDouble() }) { z -> this@minus.d += z.d }
// Basic math (+, -, *, /) // Basic math (+, -, *, /)
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> = public override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value + b.value }) { z -> derive(const { a.value + b.value }) { z ->
a.d += z.d a.d += z.d
b.d += z.d b.d += z.d
} }
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> = public override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value * b.value }) { z -> derive(const { a.value * b.value }) { z ->
a.d += z.d * b.value a.d += z.d * b.value
b.d += z.d * a.value b.d += z.d * a.value
} }
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> = public override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value / b.value }) { z -> derive(const { a.value / b.value }) { z ->
a.d += z.d / b.value a.d += z.d / b.value
b.d -= z.d * a.value / (b.value * b.value) b.d -= z.d * a.value / (b.value * b.value)
} }
override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> = public override fun multiply(a: AutoDiffValue<T>, k: Number): AutoDiffValue<T> =
derive(const { k.toDouble() * a.value }) { z -> derive(const { k.toDouble() * a.value }) { z ->
a.d += z.d * k.toDouble() a.d += z.d * k.toDouble()
} }
@ -220,15 +221,15 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
public class SimpleAutoDiffExpression<T : Any, F : Field<T>>( public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
public val field: F, public val field: F,
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>, public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
) : FirstDerivativeExpression<T>() { ) : FirstDerivativeExpression<T, Expression<T>>() {
public override operator fun invoke(arguments: Map<Symbol, T>): T { public override operator fun invoke(arguments: Map<Symbol, T>): T {
//val bindings = arguments.entries.map { it.key.bind(it.value) } //val bindings = arguments.entries.map { it.key.bind(it.value) }
return SimpleAutoDiffField(field, arguments).function().value return SimpleAutoDiffField(field, arguments).function().value
} }
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments -> public override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
//val bindings = arguments.entries.map { it.key.bind(it.value) } //val bindings = arguments.entries.map { it.key.bind(it.value) }
val derivationResult = SimpleAutoDiffField(field, arguments).derivate(function) val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function)
derivationResult.derivative(symbol) derivationResult.derivative(symbol)
} }
} }
@ -236,13 +237,10 @@ public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
/** /**
* Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression] * Generate [AutoDiffProcessor] for [SimpleAutoDiffExpression]
*/ */
public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> { public fun <T : Any, F : Field<T>> simpleAutoDiff(field: F): AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>, Expression<T>> =
return object : AutoDiffProcessor<T, AutoDiffValue<T>, SimpleAutoDiffField<T, F>> { AutoDiffProcessor { function ->
override fun process(function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>): DifferentiableExpression<T> { SimpleAutoDiffExpression(field, function)
return SimpleAutoDiffExpression(field, function)
}
} }
}
// Extensions for differentiation of various basic mathematical functions // Extensions for differentiation of various basic mathematical functions
@ -392,4 +390,4 @@ public class SimpleAutoDiffExtendedField<T : Any, F : ExtendedField<T>>(
public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> = public override fun atanh(arg: AutoDiffValue<T>): AutoDiffValue<T> =
(this as SimpleAutoDiffField<T, F>).atanh(arg) (this as SimpleAutoDiffField<T, F>).atanh(arg)
} }

View File

@ -0,0 +1,9 @@
plugins {
id("ru.mipt.npm.jvm")
}
dependencies {
implementation("com.github.breandan:kaliningraph:0.1.2")
implementation("com.github.breandan:kotlingrad:0.3.7")
api(project(":kmath-ast"))
}

View File

@ -0,0 +1,53 @@
package kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.experimental.SFun
import kscience.kmath.ast.MST
import kscience.kmath.ast.MstAlgebra
import kscience.kmath.ast.MstExpression
import kscience.kmath.expressions.DifferentiableExpression
import kscience.kmath.expressions.Symbol
import kscience.kmath.operations.NumericAlgebra
/**
* Represents wrapper of [MstExpression] implementing [DifferentiableExpression].
*
* The principle of this API is converting the [mst] to an [SFun], differentiating it with Kotlin, then converting
* [SFun] back to [MST].
*
* @param T the type of number.
* @param A the [NumericAlgebra] of [T].
* @property expr the underlying [MstExpression].
*/
public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpression<T, A>) :
DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T>, T : Number {
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
/**
* The [MstExpression.algebra] of [expr].
*/
public val algebra: A
get() = expr.algebra
/**
* The [MstExpression.mst] of [expr].
*/
public val mst: MST
get() = expr.mst
public override fun invoke(arguments: Map<Symbol, T>): T = expr(arguments)
public override fun derivativeOrNull(symbols: List<Symbol>): MstExpression<T, A> = MstExpression(
algebra,
symbols.map(Symbol::identity)
.map(MstAlgebra::symbol)
.map { it.toSVar<KMathNumber<T, A>>() }
.fold(mst.toSFun(), SFun<KMathNumber<T, A>>::d)
.toMst(),
)
}
/**
* Wraps this [MstExpression] into [DifferentiableMstExpression].
*/
public fun <T : Number, A : NumericAlgebra<T>> MstExpression<T, A>.differentiable(): DifferentiableMstExpression<T, A> =
DifferentiableMstExpression(this)

View File

@ -0,0 +1,18 @@
package kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.experimental.RealNumber
import edu.umontreal.kotlingrad.experimental.SConst
import kscience.kmath.operations.NumericAlgebra
/**
* Implements [RealNumber] by delegating its functionality to [NumericAlgebra].
*
* @param T the type of number.
* @param A the [NumericAlgebra] of [T].
* @property algebra the algebra.
* @param value the value of this number.
*/
public class KMathNumber<T, A>(public val algebra: A, value: T) :
RealNumber<KMathNumber<T, A>, T>(value) where T : Number, A : NumericAlgebra<T> {
public override fun wrap(number: Number): SConst<KMathNumber<T, A>> = SConst(algebra.number(number))
}

View File

@ -0,0 +1,124 @@
package kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.experimental.*
import kscience.kmath.ast.MST
import kscience.kmath.ast.MstAlgebra
import kscience.kmath.ast.MstExtendedField
import kscience.kmath.ast.MstExtendedField.unaryMinus
import kscience.kmath.operations.*
/**
* Maps [SVar] to [MST.Symbolic] directly.
*
* @receiver the variable.
* @return a node.
*/
public fun <X : SFun<X>> SVar<X>.toMst(): MST.Symbolic = MstAlgebra.symbol(name)
/**
* Maps [SVar] to [MST.Numeric] directly.
*
* @receiver the constant.
* @return a node.
*/
public fun <X : SFun<X>> SConst<X>.toMst(): MST.Numeric = MstAlgebra.number(doubleValue)
/**
* Maps [SFun] objects to [MST]. Some unsupported operations like [Derivative] are bound and converted then.
* [Power] operation is limited to constant right-hand side arguments.
*
* Detailed mapping is:
*
* - [SVar] -> [MstExtendedField.symbol];
* - [SConst] -> [MstExtendedField.number];
* - [Sum] -> [MstExtendedField.add];
* - [Prod] -> [MstExtendedField.multiply];
* - [Power] -> [MstExtendedField.power] (limited to constant exponents only);
* - [Negative] -> [MstExtendedField.unaryMinus];
* - [Log] -> [MstExtendedField.ln] (left) / [MstExtendedField.ln] (right);
* - [Sine] -> [MstExtendedField.sin];
* - [Cosine] -> [MstExtendedField.cos];
* - [Tangent] -> [MstExtendedField.tan];
* - [DProd] is vector operation, and it is requested to be evaluated;
* - [SComposition] is also requested to be evaluated eagerly;
* - [VSumAll] is requested to be evaluated;
* - [Derivative] is requested to be evaluated.
*
* @receiver the scalar function.
* @return a node.
*/
public fun <X : SFun<X>> SFun<X>.toMst(): MST = MstExtendedField {
when (this@toMst) {
is SVar -> toMst()
is SConst -> toMst()
is Sum -> left.toMst() + right.toMst()
is Prod -> left.toMst() * right.toMst()
is Power -> left.toMst() pow ((right as? SConst<*>)?.doubleValue ?: (right() as SConst<*>).doubleValue)
is Negative -> -input.toMst()
is Log -> ln(left.toMst()) / ln(right.toMst())
is Sine -> sin(input.toMst())
is Cosine -> cos(input.toMst())
is Tangent -> tan(input.toMst())
is DProd -> this@toMst().toMst()
is SComposition -> this@toMst().toMst()
is VSumAll<X, *> -> this@toMst().toMst()
is Derivative -> this@toMst().toMst()
}
}
/**
* Maps [MST.Numeric] to [SConst] directly.
*
* @receiver the node.
* @return a new constant.
*/
public fun <X : SFun<X>> MST.Numeric.toSConst(): SConst<X> = SConst(value)
/**
* Maps [MST.Symbolic] to [SVar] directly.
*
* @receiver the node.
* @param proto the prototype instance.
* @return a new variable.
*/
internal fun <X : SFun<X>> MST.Symbolic.toSVar(): SVar<X> = SVar(value)
/**
* Maps [MST] objects to [SFun]. Unsupported operations throw [IllegalStateException].
*
* Detailed mapping is:
*
* - [MST.Numeric] -> [SConst];
* - [MST.Symbolic] -> [SVar];
* - [MST.Unary] -> [Negative], [Sine], [Cosine], [Tangent], [Power], [Log];
* - [MST.Binary] -> [Sum], [Prod], [Power].
*
* @receiver the node.
* @param proto the prototype instance.
* @return a scalar function.
*/
public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
is MST.Numeric -> toSConst()
is MST.Symbolic -> toSVar()
is MST.Unary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> +value.toSFun<X>()
SpaceOperations.MINUS_OPERATION -> -value.toSFun<X>()
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
PowerOperations.SQRT_OPERATION -> sqrt(value.toSFun())
ExponentialOperations.EXP_OPERATION -> exp(value.toSFun())
ExponentialOperations.LN_OPERATION -> value.toSFun<X>().ln()
else -> error("Unary operation $operation not defined in $this")
}
is MST.Binary -> when (operation) {
SpaceOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
SpaceOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
else -> error("Binary operation $operation not defined in $this")
}
}

View File

@ -0,0 +1,64 @@
package kscience.kmath.kotlingrad
import edu.umontreal.kotlingrad.experimental.*
import kscience.kmath.asm.compile
import kscience.kmath.ast.MstAlgebra
import kscience.kmath.ast.MstExpression
import kscience.kmath.ast.parseMath
import kscience.kmath.expressions.invoke
import kscience.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import kotlin.test.fail
internal class AdaptingTests {
@Test
fun symbol() {
val c1 = MstAlgebra.symbol("x")
assertTrue(c1.toSVar<KMathNumber<Double, RealField>>().name == "x")
val c2 = "kitten".parseMath().toSFun<KMathNumber<Double, RealField>>()
if (c2 is SVar) assertTrue(c2.name == "kitten") else fail()
}
@Test
fun number() {
val c1 = MstAlgebra.number(12354324)
assertTrue(c1.toSConst<DReal>().doubleValue == 12354324.0)
val c2 = "0.234".parseMath().toSFun<KMathNumber<Double, RealField>>()
if (c2 is SConst) assertTrue(c2.doubleValue == 0.234) else fail()
val c3 = "1e-3".parseMath().toSFun<KMathNumber<Double, RealField>>()
if (c3 is SConst) assertEquals(0.001, c3.value) else fail()
}
@Test
fun simpleFunctionShape() {
val linear = "2*x+16".parseMath().toSFun<KMathNumber<Double, RealField>>()
if (linear !is Sum) fail()
if (linear.left !is Prod) fail()
if (linear.right !is SConst) fail()
}
@Test
fun simpleFunctionDerivative() {
val x = MstAlgebra.symbol("x").toSVar<KMathNumber<Double, RealField>>()
val quadratic = "x^2-4*x-44".parseMath().toSFun<KMathNumber<Double, RealField>>()
val actualDerivative = MstExpression(RealField, quadratic.d(x).toMst()).compile()
val expectedDerivative = MstExpression(RealField, "2*x-4".parseMath()).compile()
assertEquals(actualDerivative("x" to 123.0), expectedDerivative("x" to 123.0))
}
@Test
fun moreComplexDerivative() {
val x = MstAlgebra.symbol("x").toSVar<KMathNumber<Double, RealField>>()
val composition = "-sqrt(sin(x^2)-cos(x)^2-16*x)".parseMath().toSFun<KMathNumber<Double, RealField>>()
val actualDerivative = MstExpression(RealField, composition.d(x).toMst()).compile()
val expectedDerivative = MstExpression(
RealField,
"-(2*x*cos(x^2)+2*sin(x)*cos(x)-16)/(2*sqrt(sin(x^2)-16*x-cos(x)^2))".parseMath()
).compile()
assertEquals(actualDerivative("x" to 0.1), expectedDerivative("x" to 0.1))
}
}

View File

@ -126,6 +126,36 @@ public interface Nd4jArrayRing<T, R> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4j
check(b) check(b)
return b.ndArray.rsub(this).wrap() return b.ndArray.rsub(this).wrap()
} }
public companion object {
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
ThreadLocal.withInitial { hashMapOf() }
private val longNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, LongNd4jArrayRing>> =
ThreadLocal.withInitial { hashMapOf() }
/**
* Creates an [NDRing] for [Int] values or pull it from cache if it was created previously.
*/
public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) }
/**
* Creates an [NDRing] for [Long] values or pull it from cache if it was created previously.
*/
public fun long(vararg shape: Int): Nd4jArrayRing<Long, LongRing> =
longNd4jArrayRingCache.get().getOrPut(shape) { LongNd4jArrayRing(shape) }
/**
* Creates a most suitable implementation of [NDRing] using reified class.
*/
@Suppress("UNCHECKED_CAST")
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayRing<T, out Ring<T>> = when {
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, out Ring<T>>
T::class == Long::class -> long(*shape) as Nd4jArrayRing<T, out Ring<T>>
else -> throw UnsupportedOperationException("This factory method only supports Int and Long types.")
}
}
} }
/** /**
@ -145,6 +175,37 @@ public interface Nd4jArrayField<T, F> : NDField<T, F, Nd4jArrayStructure<T>>, Nd
check(b) check(b)
return b.ndArray.rdiv(this).wrap() return b.ndArray.rdiv(this).wrap()
} }
public companion object {
private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, FloatNd4jArrayField>> =
ThreadLocal.withInitial { hashMapOf() }
private val realNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, RealNd4jArrayField>> =
ThreadLocal.withInitial { hashMapOf() }
/**
* Creates an [NDField] for [Float] values or pull it from cache if it was created previously.
*/
public fun float(vararg shape: Int): Nd4jArrayRing<Float, FloatField> =
floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) }
/**
* Creates an [NDField] for [Double] values or pull it from cache if it was created previously.
*/
public fun real(vararg shape: Int): Nd4jArrayRing<Double, RealField> =
realNd4jArrayFieldCache.get().getOrPut(shape) { RealNd4jArrayField(shape) }
/**
* Creates a most suitable implementation of [NDRing] using reified class.
*/
@Suppress("UNCHECKED_CAST")
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, out Field<T>> = when {
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, out Field<T>>
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, out Field<T>>
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
}
}
} }
/** /**

View File

@ -1,4 +1,6 @@
plugins { id("ru.mipt.npm.mpp") } plugins {
id("ru.mipt.npm.mpp")
}
kotlin.sourceSets { kotlin.sourceSets {
commonMain { commonMain {

View File

@ -12,16 +12,18 @@ public object Fitting {
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation * Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
*/ */
public fun <T : Any, I : Any, A> chiSquared( public fun <T : Any, I : Any, A> chiSquared(
autoDiff: AutoDiffProcessor<T, I, A>, autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
x: Buffer<T>, x: Buffer<T>,
y: Buffer<T>, y: Buffer<T>,
yErr: Buffer<T>, yErr: Buffer<T>,
model: A.(I) -> I, model: A.(I) -> I,
): DifferentiableExpression<T> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> { ): DifferentiableExpression<T, Expression<T>> where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
require(x.size == y.size) { "X and y buffers should be of the same size" } require(x.size == y.size) { "X and y buffers should be of the same size" }
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
return autoDiff.process { return autoDiff.process {
var sum = zero var sum = zero
x.indices.forEach { x.indices.forEach {
val xValue = const(x[it]) val xValue = const(x[it])
val yValue = const(y[it]) val yValue = const(y[it])
@ -29,6 +31,7 @@ public object Fitting {
val modelValue = model(xValue) val modelValue = model(xValue)
sum += ((yValue - modelValue) / yErrValue).pow(2) sum += ((yValue - modelValue) / yErrValue).pow(2)
} }
sum sum
} }
} }
@ -45,6 +48,7 @@ public object Fitting {
): Expression<Double> { ): Expression<Double> {
require(x.size == y.size) { "X and y buffers should be of the same size" } require(x.size == y.size) { "X and y buffers should be of the same size" }
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" } require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
return Expression { arguments -> return Expression { arguments ->
x.indices.sumByDouble { x.indices.sumByDouble {
val xValue = x[it] val xValue = x[it]
@ -56,4 +60,4 @@ public object Fitting {
} }
} }
} }
} }

View File

@ -27,17 +27,17 @@ public interface OptimizationProblem<T : Any> {
/** /**
* Define the initial guess for the optimization problem * Define the initial guess for the optimization problem
*/ */
public fun initialGuess(map: Map<Symbol, T>): Unit public fun initialGuess(map: Map<Symbol, T>)
/** /**
* Set an objective function expression * Set an objective function expression
*/ */
public fun expression(expression: Expression<T>): Unit public fun expression(expression: Expression<T>)
/** /**
* Set a differentiable expression as objective function as function and gradient provider * Set a differentiable expression as objective function as function and gradient provider
*/ */
public fun diffExpression(expression: DifferentiableExpression<T>): Unit public fun diffExpression(expression: DifferentiableExpression<T, Expression<T>>)
/** /**
* Update the problem from previous optimization run * Update the problem from previous optimization run
@ -50,9 +50,8 @@ public interface OptimizationProblem<T : Any> {
public fun optimize(): OptimizationResult<T> public fun optimize(): OptimizationResult<T>
} }
public interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> { public fun interface OptimizationProblemFactory<T : Any, out P : OptimizationProblem<T>> {
public fun build(symbols: List<Symbol>): P public fun build(symbols: List<Symbol>): P
} }
public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFactory<T, P>.invoke( public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFactory<T, P>.invoke(
@ -60,7 +59,6 @@ public operator fun <T : Any, P : OptimizationProblem<T>> OptimizationProblemFac
block: P.() -> Unit, block: P.() -> Unit,
): P = build(symbols).apply(block) ): P = build(symbols).apply(block)
/** /**
* Optimize expression without derivatives using specific [OptimizationProblemFactory] * Optimize expression without derivatives using specific [OptimizationProblemFactory]
*/ */
@ -78,7 +76,7 @@ public fun <T : Any, F : OptimizationProblem<T>> Expression<T>.optimizeWith(
/** /**
* Optimize differentiable expression using specific [OptimizationProblemFactory] * Optimize differentiable expression using specific [OptimizationProblemFactory]
*/ */
public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.optimizeWith( public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T, Expression<T>>.optimizeWith(
factory: OptimizationProblemFactory<T, F>, factory: OptimizationProblemFactory<T, F>,
vararg symbols: Symbol, vararg symbols: Symbol,
configuration: F.() -> Unit, configuration: F.() -> Unit,
@ -88,4 +86,3 @@ public fun <T : Any, F : OptimizationProblem<T>> DifferentiableExpression<T>.op
problem.diffExpression(this) problem.diffExpression(this)
return problem.optimize() return problem.optimize()
} }

View File

@ -1,4 +1,6 @@
plugins { id("ru.mipt.npm.jvm") } plugins {
id("ru.mipt.npm.jvm")
}
description = "Binding for https://github.com/JetBrains-Research/viktor" description = "Binding for https://github.com/JetBrains-Research/viktor"

View File

@ -1,13 +1,11 @@
pluginManagement { pluginManagement {
repositories { repositories {
mavenLocal()
jcenter()
gradlePluginPortal() gradlePluginPortal()
jcenter()
maven("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/kotlin/kotlin-eap")
maven("https://dl.bintray.com/mipt-npm/kscience") maven("https://dl.bintray.com/mipt-npm/kscience")
maven("https://dl.bintray.com/mipt-npm/dev") maven("https://dl.bintray.com/mipt-npm/dev")
maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/kotlin/kotlinx")
maven("https://dl.bintray.com/kotlin/kotlin-dev/")
} }
val toolsVersion = "0.6.4-dev-1.4.20-M2" val toolsVersion = "0.6.4-dev-1.4.20-M2"
@ -41,5 +39,6 @@ include(
":kmath-geometry", ":kmath-geometry",
":kmath-ast", ":kmath-ast",
":kmath-ejml", ":kmath-ejml",
":kmath-kotlingrad",
":examples" ":examples"
) )