Initial Optimization API
This commit is contained in:
parent
e216fd50f5
commit
cd05ca6e95
@ -11,10 +11,7 @@ allprojects {
|
|||||||
jcenter()
|
jcenter()
|
||||||
maven("https://clojars.org/repo")
|
maven("https://clojars.org/repo")
|
||||||
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
maven("https://dl.bintray.com/egor-bogomolov/astminer/")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
maven("https://dl.bintray.com/hotkeytlt/maven")
|
||||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
|
||||||
maven("https://dl.bintray.com/mipt-npm/dev")
|
|
||||||
maven("https://dl.bintray.com/mipt-npm/kscience")
|
|
||||||
maven("https://jitpack.io")
|
maven("https://jitpack.io")
|
||||||
maven("http://logicrunch.research.it.uu.se/maven/")
|
maven("http://logicrunch.research.it.uu.se/maven/")
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
|
@ -15,10 +15,7 @@ import space.kscience.kmath.expressions.SymbolIndexer
|
|||||||
import space.kscience.kmath.expressions.derivative
|
import space.kscience.kmath.expressions.derivative
|
||||||
import space.kscience.kmath.misc.Symbol
|
import space.kscience.kmath.misc.Symbol
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.optimization.FunctionOptimization
|
import space.kscience.kmath.optimization.*
|
||||||
import space.kscience.kmath.optimization.OptimizationFeature
|
|
||||||
import space.kscience.kmath.optimization.OptimizationProblemFactory
|
|
||||||
import space.kscience.kmath.optimization.OptimizationResult
|
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
public operator fun PointValuePair.component1(): DoubleArray = point
|
public operator fun PointValuePair.component1(): DoubleArray = point
|
||||||
@ -27,7 +24,8 @@ public operator fun PointValuePair.component2(): Double = value
|
|||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public class CMOptimization(
|
public class CMOptimization(
|
||||||
override val symbols: List<Symbol>,
|
override val symbols: List<Symbol>,
|
||||||
) : FunctionOptimization<Double>, SymbolIndexer, OptimizationFeature {
|
) : FunctionOptimization<Double>, NoDerivFunctionOptimization<Double>, SymbolIndexer, OptimizationFeature {
|
||||||
|
|
||||||
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
private val optimizationData: HashMap<KClass<out OptimizationData>, OptimizationData> = HashMap()
|
||||||
private var optimizerBuilder: (() -> MultivariateOptimizer)? = null
|
private var optimizerBuilder: (() -> MultivariateOptimizer)? = null
|
||||||
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(
|
public var convergenceChecker: ConvergenceChecker<PointValuePair> = SimpleValueChecker(
|
||||||
@ -36,6 +34,12 @@ public class CMOptimization(
|
|||||||
DEFAULT_MAX_ITER
|
DEFAULT_MAX_ITER
|
||||||
)
|
)
|
||||||
|
|
||||||
|
override var maximize: Boolean
|
||||||
|
get() = optimizationData[GoalType::class] == GoalType.MAXIMIZE
|
||||||
|
set(value) {
|
||||||
|
optimizationData[GoalType::class] = if (value) GoalType.MAXIMIZE else GoalType.MINIMIZE
|
||||||
|
}
|
||||||
|
|
||||||
public fun addOptimizationData(data: OptimizationData) {
|
public fun addOptimizationData(data: OptimizationData) {
|
||||||
optimizationData[data::class] = data
|
optimizationData[data::class] = data
|
||||||
}
|
}
|
||||||
@ -50,7 +54,7 @@ public class CMOptimization(
|
|||||||
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
addOptimizationData(InitialGuess(map.toDoubleArray()))
|
||||||
}
|
}
|
||||||
|
|
||||||
public override fun expression(expression: Expression<Double>): Unit {
|
public override fun function(expression: Expression<Double>): Unit {
|
||||||
val objectiveFunction = ObjectiveFunction {
|
val objectiveFunction = ObjectiveFunction {
|
||||||
val args = it.toMap()
|
val args = it.toMap()
|
||||||
expression(args)
|
expression(args)
|
||||||
@ -58,8 +62,8 @@ public class CMOptimization(
|
|||||||
addOptimizationData(objectiveFunction)
|
addOptimizationData(objectiveFunction)
|
||||||
}
|
}
|
||||||
|
|
||||||
public override fun diffExpression(expression: DifferentiableExpression<Double, Expression<Double>>) {
|
public override fun diffFunction(expression: DifferentiableExpression<Double, Expression<Double>>) {
|
||||||
expression(expression)
|
function(expression)
|
||||||
val gradientFunction = ObjectiveFunctionGradient {
|
val gradientFunction = ObjectiveFunctionGradient {
|
||||||
val args = it.toMap()
|
val args = it.toMap()
|
||||||
DoubleArray(symbols.size) { index ->
|
DoubleArray(symbols.size) { index ->
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
package space.kscience.kmath.commons.optimization
|
package space.kscience.kmath.commons.optimization
|
||||||
|
|
||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType
|
|
||||||
import space.kscience.kmath.commons.expressions.DerivativeStructureField
|
import space.kscience.kmath.commons.expressions.DerivativeStructureField
|
||||||
import space.kscience.kmath.expressions.DifferentiableExpression
|
import space.kscience.kmath.expressions.DifferentiableExpression
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.Expression
|
||||||
import space.kscience.kmath.misc.Symbol
|
import space.kscience.kmath.misc.Symbol
|
||||||
import space.kscience.kmath.optimization.FunctionOptimization
|
import space.kscience.kmath.optimization.FunctionOptimization
|
||||||
import space.kscience.kmath.optimization.OptimizationResult
|
import space.kscience.kmath.optimization.OptimizationResult
|
||||||
|
import space.kscience.kmath.optimization.noDerivOptimizeWith
|
||||||
import space.kscience.kmath.optimization.optimizeWith
|
import space.kscience.kmath.optimization.optimizeWith
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.asBuffer
|
import space.kscience.kmath.structures.asBuffer
|
||||||
@ -44,7 +44,7 @@ public fun FunctionOptimization.Companion.chiSquared(
|
|||||||
public fun Expression<Double>.optimize(
|
public fun Expression<Double>.optimize(
|
||||||
vararg symbols: Symbol,
|
vararg symbols: Symbol,
|
||||||
configuration: CMOptimization.() -> Unit,
|
configuration: CMOptimization.() -> Unit,
|
||||||
): OptimizationResult<Double> = optimizeWith(CMOptimization, symbols = symbols, configuration)
|
): OptimizationResult<Double> = noDerivOptimizeWith(CMOptimization, symbols = symbols, configuration)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize differentiable expression
|
* Optimize differentiable expression
|
||||||
@ -58,10 +58,11 @@ public fun DifferentiableExpression<Double, Expression<Double>>.minimize(
|
|||||||
vararg startPoint: Pair<Symbol, Double>,
|
vararg startPoint: Pair<Symbol, Double>,
|
||||||
configuration: CMOptimization.() -> Unit = {},
|
configuration: CMOptimization.() -> Unit = {},
|
||||||
): OptimizationResult<Double> {
|
): OptimizationResult<Double> {
|
||||||
require(startPoint.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
val symbols = startPoint.map { it.first }.toTypedArray()
|
||||||
val problem = CMOptimization(startPoint.map { it.first }).apply(configuration)
|
return optimize(*symbols){
|
||||||
problem.diffExpression(this)
|
maximize = false
|
||||||
problem.initialGuess(startPoint.toMap())
|
initialGuess(startPoint.toMap())
|
||||||
problem.goal(GoalType.MINIMIZE)
|
diffFunction(this@minimize)
|
||||||
return problem.optimize()
|
configuration()
|
||||||
|
}
|
||||||
}
|
}
|
@ -1,3 +1,18 @@
|
|||||||
|
public final class space/kscience/kmath/data/ColumnarDataKt {
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/data/XYColumnarData$DefaultImpls {
|
||||||
|
public static fun get (Lspace/kscience/kmath/data/XYColumnarData;Lspace/kscience/kmath/misc/Symbol;)Lspace/kscience/kmath/structures/Buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/data/XYColumnarDataKt {
|
||||||
|
public static synthetic fun asXYData$default (Lspace/kscience/kmath/nd/Structure2D;IIILjava/lang/Object;)Lspace/kscience/kmath/data/XYColumnarData;
|
||||||
|
}
|
||||||
|
|
||||||
|
public final class space/kscience/kmath/data/XYZColumnarData$DefaultImpls {
|
||||||
|
public static fun get (Lspace/kscience/kmath/data/XYZColumnarData;Lspace/kscience/kmath/misc/Symbol;)Lspace/kscience/kmath/structures/Buffer;
|
||||||
|
}
|
||||||
|
|
||||||
public abstract interface class space/kscience/kmath/domains/Domain {
|
public abstract interface class space/kscience/kmath/domains/Domain {
|
||||||
public abstract fun contains (Lspace/kscience/kmath/structures/Buffer;)Z
|
public abstract fun contains (Lspace/kscience/kmath/structures/Buffer;)Z
|
||||||
public abstract fun getDimension ()I
|
public abstract fun getDimension ()I
|
||||||
@ -603,15 +618,6 @@ public final class space/kscience/kmath/misc/CumulativeKt {
|
|||||||
public static final fun cumulativeSumOfLong (Lkotlin/sequences/Sequence;)Lkotlin/sequences/Sequence;
|
public static final fun cumulativeSumOfLong (Lkotlin/sequences/Sequence;)Lkotlin/sequences/Sequence;
|
||||||
}
|
}
|
||||||
|
|
||||||
public final class space/kscience/kmath/misc/NDStructureColumn : space/kscience/kmath/structures/Buffer {
|
|
||||||
public fun <init> (Lspace/kscience/kmath/nd/Structure2D;I)V
|
|
||||||
public fun get (I)Ljava/lang/Object;
|
|
||||||
public final fun getColumn ()I
|
|
||||||
public fun getSize ()I
|
|
||||||
public final fun getStructure ()Lspace/kscience/kmath/nd/Structure2D;
|
|
||||||
public fun iterator ()Ljava/util/Iterator;
|
|
||||||
}
|
|
||||||
|
|
||||||
public final class space/kscience/kmath/misc/StringSymbol : space/kscience/kmath/misc/Symbol {
|
public final class space/kscience/kmath/misc/StringSymbol : space/kscience/kmath/misc/Symbol {
|
||||||
public static final synthetic fun box-impl (Ljava/lang/String;)Lspace/kscience/kmath/misc/StringSymbol;
|
public static final synthetic fun box-impl (Ljava/lang/String;)Lspace/kscience/kmath/misc/StringSymbol;
|
||||||
public static fun constructor-impl (Ljava/lang/String;)Ljava/lang/String;
|
public static fun constructor-impl (Ljava/lang/String;)Ljava/lang/String;
|
||||||
@ -644,17 +650,6 @@ public final class space/kscience/kmath/misc/SymbolKt {
|
|||||||
public abstract interface annotation class space/kscience/kmath/misc/UnstableKMathAPI : java/lang/annotation/Annotation {
|
public abstract interface annotation class space/kscience/kmath/misc/UnstableKMathAPI : java/lang/annotation/Annotation {
|
||||||
}
|
}
|
||||||
|
|
||||||
public final class space/kscience/kmath/misc/XYPointSet$DefaultImpls {
|
|
||||||
public static fun get (Lspace/kscience/kmath/misc/XYPointSet;Lspace/kscience/kmath/misc/Symbol;)Lspace/kscience/kmath/structures/Buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
public final class space/kscience/kmath/misc/XYPointSetKt {
|
|
||||||
}
|
|
||||||
|
|
||||||
public final class space/kscience/kmath/misc/XYZPointSet$DefaultImpls {
|
|
||||||
public static fun get (Lspace/kscience/kmath/misc/XYZPointSet;Lspace/kscience/kmath/misc/Symbol;)Lspace/kscience/kmath/structures/Buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract interface class space/kscience/kmath/nd/AlgebraND {
|
public abstract interface class space/kscience/kmath/nd/AlgebraND {
|
||||||
public static final field Companion Lspace/kscience/kmath/nd/AlgebraND$Companion;
|
public static final field Companion Lspace/kscience/kmath/nd/AlgebraND$Companion;
|
||||||
public abstract fun combine (Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;Lkotlin/jvm/functions/Function3;)Lspace/kscience/kmath/nd/StructureND;
|
public abstract fun combine (Lspace/kscience/kmath/nd/StructureND;Lspace/kscience/kmath/nd/StructureND;Lkotlin/jvm/functions/Function3;)Lspace/kscience/kmath/nd/StructureND;
|
||||||
|
@ -0,0 +1,34 @@
|
|||||||
|
package space.kscience.kmath.data
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.Symbol
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.nd.Structure2D
|
||||||
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A column-based data set with all columns of the same size (not necessary fixed in time).
|
||||||
|
* The column could be retrieved by a [get] operation.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public interface ColumnarData<out T> {
|
||||||
|
public val size: Int
|
||||||
|
|
||||||
|
public operator fun get(symbol: Symbol): Buffer<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A zero-copy method to represent a [Structure2D] as a two-column x-y data.
|
||||||
|
* There could more than two columns in the structure.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T> Structure2D<T>.asColumnarData(mapping: Map<Symbol, Int>): ColumnarData<T> {
|
||||||
|
require(shape[1] >= mapping.maxOf { it.value }) { "Column index out of bounds" }
|
||||||
|
return object : ColumnarData<T> {
|
||||||
|
override val size: Int get() = shape[0]
|
||||||
|
override fun get(symbol: Symbol): Buffer<T> {
|
||||||
|
val index = mapping[symbol] ?: error("No column mapping for symbol $symbol")
|
||||||
|
return columns[index]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,55 @@
|
|||||||
|
package space.kscience.kmath.data
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.Symbol
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.nd.Structure2D
|
||||||
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import kotlin.math.max
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The buffer of X values.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public interface XYColumnarData<T, out X : T, out Y : T> : ColumnarData<T> {
|
||||||
|
/**
|
||||||
|
* The buffer of X values
|
||||||
|
*/
|
||||||
|
public val x: Buffer<X>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The buffer of Y values.
|
||||||
|
*/
|
||||||
|
public val y: Buffer<Y>
|
||||||
|
|
||||||
|
override fun get(symbol: Symbol): Buffer<T> = when (symbol) {
|
||||||
|
Symbol.x -> x
|
||||||
|
Symbol.y -> y
|
||||||
|
else -> error("A column for symbol $symbol not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suppress("FunctionName")
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T, X : T, Y : T> XYColumnarData(x: Buffer<X>, y: Buffer<Y>): XYColumnarData<T, X, Y> {
|
||||||
|
require(x.size == y.size) { "Buffer size mismatch. x buffer size is ${x.size}, y buffer size is ${y.size}" }
|
||||||
|
return object : XYColumnarData<T, X, Y> {
|
||||||
|
override val size: Int = x.size
|
||||||
|
override val x: Buffer<X> = x
|
||||||
|
override val y: Buffer<Y> = y
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A zero-copy method to represent a [Structure2D] as a two-column x-y data.
|
||||||
|
* There could more than two columns in the structure.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T> Structure2D<T>.asXYData(xIndex: Int = 0, yIndex: Int = 1): XYColumnarData<T, T, T> {
|
||||||
|
require(shape[1] >= max(xIndex, yIndex)) { "Column index out of bounds" }
|
||||||
|
return object : XYColumnarData<T, T, T> {
|
||||||
|
override val size: Int get() = this@asXYData.shape[0]
|
||||||
|
override val x: Buffer<T> get() = columns[xIndex]
|
||||||
|
override val y: Buffer<T> get() = columns[yIndex]
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,21 @@
|
|||||||
|
package space.kscience.kmath.data
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.Symbol
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A [XYColumnarData] with guaranteed [x], [y] and [z] columns designated by corresponding symbols.
|
||||||
|
* Inherits [XYColumnarData].
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public interface XYZColumnarData<T, out X : T, out Y : T, out Z : T> : XYColumnarData<T, X, Y> {
|
||||||
|
public val z: Buffer<Z>
|
||||||
|
|
||||||
|
override fun get(symbol: Symbol): Buffer<T> = when (symbol) {
|
||||||
|
Symbol.x -> x
|
||||||
|
Symbol.y -> y
|
||||||
|
Symbol.z -> z
|
||||||
|
else -> error("A column for symbol $symbol not found")
|
||||||
|
}
|
||||||
|
}
|
@ -1,15 +0,0 @@
|
|||||||
package space.kscience.kmath.misc
|
|
||||||
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A column-based data set with all columns of the same size (not necessary fixed in time).
|
|
||||||
* The column could be retrieved by a [get] operation.
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public interface ColumnarData<out T> {
|
|
||||||
public val size: Int
|
|
||||||
|
|
||||||
public operator fun get(symbol: Symbol): Buffer<T>
|
|
||||||
}
|
|
||||||
|
|
@ -1,98 +0,0 @@
|
|||||||
package space.kscience.kmath.misc
|
|
||||||
|
|
||||||
import space.kscience.kmath.nd.Structure2D
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Pair of associated buffers for X and Y axes values.
|
|
||||||
*
|
|
||||||
* @param X the type of X values.
|
|
||||||
* @param Y the type of Y values.
|
|
||||||
*/
|
|
||||||
public interface XYPointSet<X, Y> {
|
|
||||||
/**
|
|
||||||
* The size of all the involved buffers.
|
|
||||||
*/
|
|
||||||
public val size: Int
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The buffer of X values.
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public interface XYPointSet<T, X : T, Y : T> : ColumnarData<T> {
|
|
||||||
public val x: Buffer<X>
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The buffer of Y values.
|
|
||||||
*/
|
|
||||||
public val y: Buffer<Y>
|
|
||||||
|
|
||||||
override fun get(symbol: Symbol): Buffer<T> = when (symbol) {
|
|
||||||
Symbol.x -> x
|
|
||||||
Symbol.y -> y
|
|
||||||
else -> error("A column for symbol $symbol not found")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Triple of associated buffers for X, Y, and Z axes values.
|
|
||||||
*
|
|
||||||
* @param X the type of X values.
|
|
||||||
* @param Y the type of Y values.
|
|
||||||
* @param Z the type of Z values.
|
|
||||||
*/
|
|
||||||
public interface XYZPointSet<X, Y, Z> : XYPointSet<X, Y> {
|
|
||||||
/**
|
|
||||||
* The buffer of Z values.
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public interface XYZPointSet<T, X : T, Y : T, Z : T> : XYPointSet<T, X, Y> {
|
|
||||||
public val z: Buffer<Z>
|
|
||||||
|
|
||||||
override fun get(symbol: Symbol): Buffer<T> = when (symbol) {
|
|
||||||
Symbol.x -> x
|
|
||||||
Symbol.y -> y
|
|
||||||
Symbol.z -> z
|
|
||||||
else -> error("A column for symbol $symbol not found")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<T, *>) {
|
|
||||||
for (i in 0 until points.size - 1)
|
|
||||||
require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" }
|
|
||||||
}
|
|
||||||
|
|
||||||
public class NDStructureColumn<T>(public val structure: Structure2D<T>, public val column: Int) : Buffer<T> {
|
|
||||||
public override val size: Int
|
|
||||||
get() = structure.rowNum
|
|
||||||
|
|
||||||
init {
|
|
||||||
require(column < structure.colNum) { "Column index is outside of structure column range" }
|
|
||||||
}
|
|
||||||
|
|
||||||
public override operator fun get(index: Int): T = structure[index, column]
|
|
||||||
public override operator fun iterator(): Iterator<T> = sequence { repeat(size) { yield(get(it)) } }.iterator()
|
|
||||||
}
|
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public class BufferXYPointSet<T, X : T, Y : T>(
|
|
||||||
public override val x: Buffer<X>,
|
|
||||||
public override val y: Buffer<Y>,
|
|
||||||
) : XYPointSet<T, X, Y> {
|
|
||||||
public override val size: Int get() = x.size
|
|
||||||
|
|
||||||
init {
|
|
||||||
require(x.size == y.size) { "Sizes of x and y buffers should be the same" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T> Structure2D<T>.asXYPointSet(): XYPointSet<T, T, T> {
|
|
||||||
require(shape[1] == 2) { "Structure second dimension should be of size 2" }
|
|
||||||
|
|
||||||
return object : XYPointSet<T, T, T> {
|
|
||||||
override val size: Int get() = this@asXYPointSet.shape[0]
|
|
||||||
override val x: Buffer<T> get() = NDStructureColumn(this@asXYPointSet, 0)
|
|
||||||
override val y: Buffer<T> get() = NDStructureColumn(this@asXYPointSet, 1)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,17 +1,17 @@
|
|||||||
@file:OptIn(UnstableKMathAPI::class)
|
@file:OptIn(UnstableKMathAPI::class)
|
||||||
|
|
||||||
package space.kscience.kmath.interpolation
|
package space.kscience.kmath.interpolation
|
||||||
|
|
||||||
|
import space.kscience.kmath.data.XYColumnarData
|
||||||
import space.kscience.kmath.functions.PiecewisePolynomial
|
import space.kscience.kmath.functions.PiecewisePolynomial
|
||||||
import space.kscience.kmath.functions.value
|
import space.kscience.kmath.functions.value
|
||||||
import space.kscience.kmath.misc.BufferXYPointSet
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.misc.XYPointSet
|
|
||||||
import space.kscience.kmath.operations.Ring
|
import space.kscience.kmath.operations.Ring
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.asBuffer
|
import space.kscience.kmath.structures.asBuffer
|
||||||
|
|
||||||
public fun interface Interpolator<T, X : T, Y : T> {
|
public fun interface Interpolator<T, X : T, Y : T> {
|
||||||
public fun interpolate(points: XYPointSet<T, X, Y>): (X) -> Y
|
public fun interpolate(points: XYColumnarData<T, X, Y>): (X) -> Y
|
||||||
}
|
}
|
||||||
|
|
||||||
public interface PolynomialInterpolator<T : Comparable<T>> : Interpolator<T, T, T> {
|
public interface PolynomialInterpolator<T : Comparable<T>> : Interpolator<T, T, T> {
|
||||||
@ -19,9 +19,9 @@ public interface PolynomialInterpolator<T : Comparable<T>> : Interpolator<T, T,
|
|||||||
|
|
||||||
public fun getDefaultValue(): T = error("Out of bounds")
|
public fun getDefaultValue(): T = error("Out of bounds")
|
||||||
|
|
||||||
public fun interpolatePolynomials(points: XYPointSet<T, T, T>): PiecewisePolynomial<T>
|
public fun interpolatePolynomials(points: XYColumnarData<T, T, T>): PiecewisePolynomial<T>
|
||||||
|
|
||||||
override fun interpolate(points: XYPointSet<T, T, T>): (T) -> T = { x ->
|
override fun interpolate(points: XYColumnarData<T, T, T>): (T) -> T = { x ->
|
||||||
interpolatePolynomials(points).value(algebra, x) ?: getDefaultValue()
|
interpolatePolynomials(points).value(algebra, x) ?: getDefaultValue()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -31,20 +31,20 @@ public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
|||||||
x: Buffer<T>,
|
x: Buffer<T>,
|
||||||
y: Buffer<T>,
|
y: Buffer<T>,
|
||||||
): PiecewisePolynomial<T> {
|
): PiecewisePolynomial<T> {
|
||||||
val pointSet = BufferXYPointSet(x, y)
|
val pointSet = XYColumnarData(x, y)
|
||||||
return interpolatePolynomials(pointSet)
|
return interpolatePolynomials(pointSet)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
||||||
data: Map<T, T>,
|
data: Map<T, T>,
|
||||||
): PiecewisePolynomial<T> {
|
): PiecewisePolynomial<T> {
|
||||||
val pointSet = BufferXYPointSet(data.keys.toList().asBuffer(), data.values.toList().asBuffer())
|
val pointSet = XYColumnarData(data.keys.toList().asBuffer(), data.values.toList().asBuffer())
|
||||||
return interpolatePolynomials(pointSet)
|
return interpolatePolynomials(pointSet)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
public fun <T : Comparable<T>> PolynomialInterpolator<T>.interpolatePolynomials(
|
||||||
data: List<Pair<T, T>>,
|
data: List<Pair<T, T>>,
|
||||||
): PiecewisePolynomial<T> {
|
): PiecewisePolynomial<T> {
|
||||||
val pointSet = BufferXYPointSet(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer())
|
val pointSet = XYColumnarData(data.map { it.first }.asBuffer(), data.map { it.second }.asBuffer())
|
||||||
return interpolatePolynomials(pointSet)
|
return interpolatePolynomials(pointSet)
|
||||||
}
|
}
|
||||||
|
@ -1,15 +1,15 @@
|
|||||||
package space.kscience.kmath.interpolation
|
package space.kscience.kmath.interpolation
|
||||||
|
|
||||||
|
import space.kscience.kmath.data.XYColumnarData
|
||||||
import space.kscience.kmath.functions.OrderedPiecewisePolynomial
|
import space.kscience.kmath.functions.OrderedPiecewisePolynomial
|
||||||
import space.kscience.kmath.functions.PiecewisePolynomial
|
import space.kscience.kmath.functions.PiecewisePolynomial
|
||||||
import space.kscience.kmath.functions.Polynomial
|
import space.kscience.kmath.functions.Polynomial
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.misc.XYPointSet
|
|
||||||
import space.kscience.kmath.operations.Field
|
import space.kscience.kmath.operations.Field
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<*, T, *>) {
|
internal fun <T : Comparable<T>> insureSorted(points: XYColumnarData<*, T, *>) {
|
||||||
for (i in 0 until points.size - 1)
|
for (i in 0 until points.size - 1)
|
||||||
require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" }
|
require(points.x[i + 1] > points.x[i]) { "Input data is not sorted at index $i" }
|
||||||
}
|
}
|
||||||
@ -19,7 +19,7 @@ internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<*, T, *>) {
|
|||||||
*/
|
*/
|
||||||
public class LinearInterpolator<T : Comparable<T>>(public override val algebra: Field<T>) : PolynomialInterpolator<T> {
|
public class LinearInterpolator<T : Comparable<T>>(public override val algebra: Field<T>) : PolynomialInterpolator<T> {
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public override fun interpolatePolynomials(points: XYPointSet<T, T, T>): PiecewisePolynomial<T> = algebra {
|
public override fun interpolatePolynomials(points: XYColumnarData<T, T, T>): PiecewisePolynomial<T> = algebra {
|
||||||
require(points.size > 0) { "Point array should not be empty" }
|
require(points.size > 0) { "Point array should not be empty" }
|
||||||
insureSorted(points)
|
insureSorted(points)
|
||||||
|
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
package space.kscience.kmath.interpolation
|
package space.kscience.kmath.interpolation
|
||||||
|
|
||||||
|
import space.kscience.kmath.data.XYColumnarData
|
||||||
import space.kscience.kmath.functions.OrderedPiecewisePolynomial
|
import space.kscience.kmath.functions.OrderedPiecewisePolynomial
|
||||||
import space.kscience.kmath.functions.PiecewisePolynomial
|
import space.kscience.kmath.functions.PiecewisePolynomial
|
||||||
import space.kscience.kmath.functions.Polynomial
|
import space.kscience.kmath.functions.Polynomial
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.misc.XYPointSet
|
|
||||||
import space.kscience.kmath.operations.Field
|
import space.kscience.kmath.operations.Field
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.structures.MutableBufferFactory
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
@ -23,7 +23,7 @@ public class SplineInterpolator<T : Comparable<T>>(
|
|||||||
//TODO possibly optimize zeroed buffers
|
//TODO possibly optimize zeroed buffers
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public override fun interpolatePolynomials(points: XYPointSet<T, T, T>): PiecewisePolynomial<T> = algebra {
|
public override fun interpolatePolynomials(points: XYColumnarData<T, T, T>): PiecewisePolynomial<T> = algebra {
|
||||||
require(points.size >= 3) { "Can't use spline interpolator with less than 3 points" }
|
require(points.size >= 3) { "Can't use spline interpolator with less than 3 points" }
|
||||||
insureSorted(points)
|
insureSorted(points)
|
||||||
// Number of intervals. The number of data points is n + 1.
|
// Number of intervals. The number of data points is n + 1.
|
||||||
|
@ -18,8 +18,10 @@ import space.kscience.kmath.operations.NumericAlgebra
|
|||||||
* @param A the [NumericAlgebra] of [T].
|
* @param A the [NumericAlgebra] of [T].
|
||||||
* @property expr the underlying [MstExpression].
|
* @property expr the underlying [MstExpression].
|
||||||
*/
|
*/
|
||||||
public inline class DifferentiableMstExpression<T, A>(public val expr: MstExpression<T, A>) :
|
public inline class DifferentiableMstExpression<T: Number, A>(
|
||||||
DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T>, T : Number {
|
public val expr: MstExpression<T, A>,
|
||||||
|
) : DifferentiableExpression<T, MstExpression<T, A>> where A : NumericAlgebra<T> {
|
||||||
|
|
||||||
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
|
public constructor(algebra: A, mst: MST) : this(MstExpression(algebra, mst))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1,17 +0,0 @@
|
|||||||
package space.kscience.kmath.optimization
|
|
||||||
|
|
||||||
import space.kscience.kmath.expressions.DifferentiableExpression
|
|
||||||
import space.kscience.kmath.misc.StringSymbol
|
|
||||||
import space.kscience.kmath.misc.Symbol
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
|
||||||
|
|
||||||
public interface DataFit<T : Any> : Optimization<T> {
|
|
||||||
|
|
||||||
public fun modelAndData(
|
|
||||||
x: Buffer<T>,
|
|
||||||
y: Buffer<T>,
|
|
||||||
yErr: Buffer<T>,
|
|
||||||
model: DifferentiableExpression<T, *>,
|
|
||||||
xSymbol: Symbol = StringSymbol("x"),
|
|
||||||
)
|
|
||||||
}
|
|
@ -4,45 +4,31 @@ import space.kscience.kmath.expressions.AutoDiffProcessor
|
|||||||
import space.kscience.kmath.expressions.DifferentiableExpression
|
import space.kscience.kmath.expressions.DifferentiableExpression
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.Expression
|
||||||
import space.kscience.kmath.expressions.ExpressionAlgebra
|
import space.kscience.kmath.expressions.ExpressionAlgebra
|
||||||
import space.kscience.kmath.misc.StringSymbol
|
|
||||||
import space.kscience.kmath.misc.Symbol
|
import space.kscience.kmath.misc.Symbol
|
||||||
import space.kscience.kmath.operations.ExtendedField
|
import space.kscience.kmath.operations.ExtendedField
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.indices
|
import space.kscience.kmath.structures.indices
|
||||||
import kotlin.math.pow
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A likelihood function optimization problem
|
* A likelihood function optimization problem with provided derivatives
|
||||||
*/
|
*/
|
||||||
public interface FunctionOptimization<T: Any>: Optimization<T>, DataFit<T> {
|
public interface FunctionOptimization<T : Any> : Optimization<T> {
|
||||||
|
/**
|
||||||
|
* The optimization direction. If true search for function maximum, if false, search for the minimum
|
||||||
|
*/
|
||||||
|
public var maximize: Boolean
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Define the initial guess for the optimization problem
|
* Define the initial guess for the optimization problem
|
||||||
*/
|
*/
|
||||||
public fun initialGuess(map: Map<Symbol, T>)
|
public fun initialGuess(map: Map<Symbol, T>)
|
||||||
|
|
||||||
/**
|
|
||||||
* Set an objective function expression
|
|
||||||
*/
|
|
||||||
public fun expression(expression: Expression<T>)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set a differentiable expression as objective function as function and gradient provider
|
* Set a differentiable expression as objective function as function and gradient provider
|
||||||
*/
|
*/
|
||||||
public fun diffExpression(expression: DifferentiableExpression<T, Expression<T>>)
|
public fun diffFunction(expression: DifferentiableExpression<T, Expression<T>>)
|
||||||
|
|
||||||
override fun modelAndData(
|
public companion object {
|
||||||
x: Buffer<T>,
|
|
||||||
y: Buffer<T>,
|
|
||||||
yErr: Buffer<T>,
|
|
||||||
model: DifferentiableExpression<T, *>,
|
|
||||||
xSymbol: Symbol,
|
|
||||||
) {
|
|
||||||
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
|
||||||
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public companion object{
|
|
||||||
/**
|
/**
|
||||||
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
* Generate a chi squared expression from given x-y-sigma data and inline model. Provides automatic differentiation
|
||||||
*/
|
*/
|
||||||
@ -70,46 +56,22 @@ public interface FunctionOptimization<T: Any>: Optimization<T>, DataFit<T> {
|
|||||||
sum
|
sum
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives
|
|
||||||
*/
|
|
||||||
public fun chiSquared(
|
|
||||||
x: Buffer<Double>,
|
|
||||||
y: Buffer<Double>,
|
|
||||||
yErr: Buffer<Double>,
|
|
||||||
model: Expression<Double>,
|
|
||||||
xSymbol: Symbol = StringSymbol("x"),
|
|
||||||
): Expression<Double> {
|
|
||||||
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
|
||||||
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
|
||||||
|
|
||||||
return Expression { arguments ->
|
|
||||||
x.indices.sumByDouble {
|
|
||||||
val xValue = x[it]
|
|
||||||
val yValue = y[it]
|
|
||||||
val yErrValue = yErr[it]
|
|
||||||
val modifiedArgs = arguments + (xSymbol to xValue)
|
|
||||||
val modelValue = model(modifiedArgs)
|
|
||||||
((yValue - modelValue) / yErrValue).pow(2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimize expression without derivatives using specific [OptimizationProblemFactory]
|
* Define a chi-squared-based objective function
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, F : FunctionOptimization<T>> Expression<T>.optimizeWith(
|
public fun <T: Any, I : Any, A> FunctionOptimization<T>.chiSquared(
|
||||||
factory: OptimizationProblemFactory<T, F>,
|
autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
||||||
vararg symbols: Symbol,
|
x: Buffer<T>,
|
||||||
configuration: F.() -> Unit,
|
y: Buffer<T>,
|
||||||
): OptimizationResult<T> {
|
yErr: Buffer<T>,
|
||||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
model: A.(I) -> I,
|
||||||
val problem = factory(symbols.toList(), configuration)
|
) where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> {
|
||||||
problem.expression(this)
|
val chiSquared = FunctionOptimization.chiSquared(autoDiff, x, y, yErr, model)
|
||||||
return problem.optimize()
|
diffFunction(chiSquared)
|
||||||
|
maximize = false
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -122,6 +84,6 @@ public fun <T : Any, F : FunctionOptimization<T>> DifferentiableExpression<T, Ex
|
|||||||
): OptimizationResult<T> {
|
): OptimizationResult<T> {
|
||||||
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
val problem = factory(symbols.toList(), configuration)
|
val problem = factory(symbols.toList(), configuration)
|
||||||
problem.diffExpression(this)
|
problem.diffFunction(this)
|
||||||
return problem.optimize()
|
return problem.optimize()
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,69 @@
|
|||||||
|
package space.kscience.kmath.optimization
|
||||||
|
|
||||||
|
import space.kscience.kmath.expressions.Expression
|
||||||
|
import space.kscience.kmath.misc.Symbol
|
||||||
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.structures.indices
|
||||||
|
import kotlin.math.pow
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A likelihood function optimization problem
|
||||||
|
*/
|
||||||
|
public interface NoDerivFunctionOptimization<T : Any> : Optimization<T> {
|
||||||
|
/**
|
||||||
|
* The optimization direction. If true search for function maximum, if false, search for the minimum
|
||||||
|
*/
|
||||||
|
public var maximize: Boolean
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Define the initial guess for the optimization problem
|
||||||
|
*/
|
||||||
|
public fun initialGuess(map: Map<Symbol, T>)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set an objective function expression
|
||||||
|
*/
|
||||||
|
public fun function(expression: Expression<T>)
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
/**
|
||||||
|
* Generate a chi squared expression from given x-y-sigma model represented by an expression. Does not provide derivatives
|
||||||
|
*/
|
||||||
|
public fun chiSquared(
|
||||||
|
x: Buffer<Double>,
|
||||||
|
y: Buffer<Double>,
|
||||||
|
yErr: Buffer<Double>,
|
||||||
|
model: Expression<Double>,
|
||||||
|
xSymbol: Symbol = Symbol.x,
|
||||||
|
): Expression<Double> {
|
||||||
|
require(x.size == y.size) { "X and y buffers should be of the same size" }
|
||||||
|
require(y.size == yErr.size) { "Y and yErr buffer should of the same size" }
|
||||||
|
|
||||||
|
return Expression { arguments ->
|
||||||
|
x.indices.sumByDouble {
|
||||||
|
val xValue = x[it]
|
||||||
|
val yValue = y[it]
|
||||||
|
val yErrValue = yErr[it]
|
||||||
|
val modifiedArgs = arguments + (xSymbol to xValue)
|
||||||
|
val modelValue = model(modifiedArgs)
|
||||||
|
((yValue - modelValue) / yErrValue).pow(2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Optimize expression without derivatives using specific [OptimizationProblemFactory]
|
||||||
|
*/
|
||||||
|
public fun <T : Any, F : NoDerivFunctionOptimization<T>> Expression<T>.noDerivOptimizeWith(
|
||||||
|
factory: OptimizationProblemFactory<T, F>,
|
||||||
|
vararg symbols: Symbol,
|
||||||
|
configuration: F.() -> Unit,
|
||||||
|
): OptimizationResult<T> {
|
||||||
|
require(symbols.isNotEmpty()) { "Must provide a list of symbols for optimization" }
|
||||||
|
val problem = factory(symbols.toList(), configuration)
|
||||||
|
problem.function(this)
|
||||||
|
return problem.optimize()
|
||||||
|
}
|
@ -0,0 +1,40 @@
|
|||||||
|
package space.kscience.kmath.optimization
|
||||||
|
|
||||||
|
import space.kscience.kmath.data.ColumnarData
|
||||||
|
import space.kscience.kmath.expressions.AutoDiffProcessor
|
||||||
|
import space.kscience.kmath.expressions.DifferentiableExpression
|
||||||
|
import space.kscience.kmath.expressions.Expression
|
||||||
|
import space.kscience.kmath.expressions.ExpressionAlgebra
|
||||||
|
import space.kscience.kmath.misc.Symbol
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.ExtendedField
|
||||||
|
import space.kscience.kmath.operations.Field
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public interface XYFit<T : Any> : Optimization<T> {
|
||||||
|
|
||||||
|
public val algebra: Field<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set X-Y data for this fit optionally including x and y errors
|
||||||
|
*/
|
||||||
|
public fun data(
|
||||||
|
dataSet: ColumnarData<T>,
|
||||||
|
xSymbol: Symbol,
|
||||||
|
ySymbol: Symbol,
|
||||||
|
xErrSymbol: Symbol? = null,
|
||||||
|
yErrSymbol: Symbol? = null,
|
||||||
|
)
|
||||||
|
|
||||||
|
public fun model(model: (T) -> DifferentiableExpression<T, *>)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the differentiable model for this fit
|
||||||
|
*/
|
||||||
|
public fun <I : Any, A> model(
|
||||||
|
autoDiff: AutoDiffProcessor<T, I, A, Expression<T>>,
|
||||||
|
modelFunction: A.(I) -> I,
|
||||||
|
): Unit where A : ExtendedField<I>, A : ExpressionAlgebra<T, I> = model { arg ->
|
||||||
|
autoDiff.process { modelFunction(const(arg)) }
|
||||||
|
}
|
||||||
|
}
|
@ -4,12 +4,11 @@ pluginManagement {
|
|||||||
mavenLocal()
|
mavenLocal()
|
||||||
gradlePluginPortal()
|
gradlePluginPortal()
|
||||||
jcenter()
|
jcenter()
|
||||||
maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
|
||||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
}
|
}
|
||||||
|
|
||||||
val toolsVersion = "0.9.1"
|
val toolsVersion = "0.9.3"
|
||||||
val kotlinVersion = "1.4.31"
|
val kotlinVersion = "1.4.32"
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("kotlinx.benchmark") version "0.2.0-dev-20"
|
id("kotlinx.benchmark") version "0.2.0-dev-20"
|
||||||
|
Loading…
Reference in New Issue
Block a user