Fix CM DerivativeStructureField constants

This commit is contained in:
Alexander Nozik 2020-10-29 19:50:45 +03:00
parent fbe1ab94a4
commit 6f31ddba30
2 changed files with 6 additions and 7 deletions

View File

@ -14,8 +14,10 @@ public class DerivativeStructureField(
public val order: Int, public val order: Int,
bindings: Map<Symbol, Double>, bindings: Map<Symbol, Double>,
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> { ) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, 0) } public val numberOfVariables: Int = bindings.size
public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, 0, 1.0) }
public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) }
/** /**
* A class that implements both [DerivativeStructure] and a [Symbol] * A class that implements both [DerivativeStructure] and a [Symbol]
@ -32,8 +34,6 @@ public class DerivativeStructureField(
override fun hashCode(): Int = identity.hashCode() override fun hashCode(): Int = identity.hashCode()
} }
public val numberOfVariables: Int = bindings.size
/** /**
* Identity-based symbol bindings map * Identity-based symbol bindings map
*/ */
@ -41,7 +41,7 @@ public class DerivativeStructureField(
key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value) key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value)
}.toMap() }.toMap()
override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, 0, value) override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value)
public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity]

View File

@ -6,7 +6,6 @@ import kscience.kmath.stat.Distribution
import kscience.kmath.stat.Fitting import kscience.kmath.stat.Fitting
import kscience.kmath.stat.RandomGenerator import kscience.kmath.stat.RandomGenerator
import kscience.kmath.stat.normal import kscience.kmath.stat.normal
import kscience.kmath.structures.asBuffer
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import kotlin.math.pow import kotlin.math.pow
@ -53,7 +52,7 @@ internal class OptimizeTest {
it.pow(2) + it + 1 + chain.nextDouble() it.pow(2) + it + 1 + chain.nextDouble()
} }
val yErr = x.map { sigma } val yErr = x.map { sigma }
val chi2 = Fitting.chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x -> val chi2 = Fitting.chiSquared(x, y, yErr) { x ->
val cWithDefault = bindOrNull(c) ?: one val cWithDefault = bindOrNull(c) ?: one
bind(a) * x.pow(2) + bind(b) * x + cWithDefault bind(a) * x.pow(2) + bind(b) * x + cWithDefault
} }