Fix CM DerivativeStructureField constants
This commit is contained in:
parent
fbe1ab94a4
commit
6f31ddba30
@ -14,8 +14,10 @@ public class DerivativeStructureField(
|
|||||||
public val order: Int,
|
public val order: Int,
|
||||||
bindings: Map<Symbol, Double>,
|
bindings: Map<Symbol, Double>,
|
||||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
||||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, 0) }
|
public val numberOfVariables: Int = bindings.size
|
||||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, 0, 1.0) }
|
|
||||||
|
public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
|
||||||
|
public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A class that implements both [DerivativeStructure] and a [Symbol]
|
* A class that implements both [DerivativeStructure] and a [Symbol]
|
||||||
@ -32,8 +34,6 @@ public class DerivativeStructureField(
|
|||||||
override fun hashCode(): Int = identity.hashCode()
|
override fun hashCode(): Int = identity.hashCode()
|
||||||
}
|
}
|
||||||
|
|
||||||
public val numberOfVariables: Int = bindings.size
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Identity-based symbol bindings map
|
* Identity-based symbol bindings map
|
||||||
*/
|
*/
|
||||||
@ -41,7 +41,7 @@ public class DerivativeStructureField(
|
|||||||
key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value)
|
key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value)
|
||||||
}.toMap()
|
}.toMap()
|
||||||
|
|
||||||
override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, 0, value)
|
override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value)
|
||||||
|
|
||||||
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]
|
||||||
|
|
||||||
|
@ -6,7 +6,6 @@ import kscience.kmath.stat.Distribution
|
|||||||
import kscience.kmath.stat.Fitting
|
import kscience.kmath.stat.Fitting
|
||||||
import kscience.kmath.stat.RandomGenerator
|
import kscience.kmath.stat.RandomGenerator
|
||||||
import kscience.kmath.stat.normal
|
import kscience.kmath.stat.normal
|
||||||
import kscience.kmath.structures.asBuffer
|
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
@ -53,7 +52,7 @@ internal class OptimizeTest {
|
|||||||
it.pow(2) + it + 1 + chain.nextDouble()
|
it.pow(2) + it + 1 + chain.nextDouble()
|
||||||
}
|
}
|
||||||
val yErr = x.map { sigma }
|
val yErr = x.map { sigma }
|
||||||
val chi2 = Fitting.chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x ->
|
val chi2 = Fitting.chiSquared(x, y, yErr) { x ->
|
||||||
val cWithDefault = bindOrNull(c) ?: one
|
val cWithDefault = bindOrNull(c) ?: one
|
||||||
bind(a) * x.pow(2) + bind(b) * x + cWithDefault
|
bind(a) * x.pow(2) + bind(b) * x + cWithDefault
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user