From 6f31ddba301b2cfb2a8ebbc900afb566b510f612 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 29 Oct 2020 19:50:45 +0300 Subject: [PATCH] Fix CM DerivativeStructureField constants --- .../expressions/DerivativeStructureExpression.kt | 10 +++++----- .../kmath/commons/optimization/OptimizeTest.kt | 3 +-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index e4311a56b..244dc1314 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -14,8 +14,10 @@ public class DerivativeStructureField( public val order: Int, bindings: Map, ) : ExtendedField, ExpressionAlgebra { - public override val zero: DerivativeStructure by lazy { DerivativeStructure(bindings.size, 0) } - public override val one: DerivativeStructure by lazy { DerivativeStructure(bindings.size, 0, 1.0) } + public val numberOfVariables: Int = bindings.size + + public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } + public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) } /** * A class that implements both [DerivativeStructure] and a [Symbol] @@ -32,8 +34,6 @@ public class DerivativeStructureField( override fun hashCode(): Int = identity.hashCode() } - public val numberOfVariables: Int = bindings.size - /** * Identity-based symbol bindings map */ @@ -41,7 +41,7 @@ public class DerivativeStructureField( key.identity to DerivativeStructureSymbol(numberOfVariables, index, key, value) }.toMap() - override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, 0, value) + override fun const(value: Double): DerivativeStructure = DerivativeStructure(numberOfVariables, order, value) public override fun bindOrNull(symbol: Symbol): DerivativeStructureSymbol? = variables[symbol.identity] diff --git a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt index 4384a5124..fa1978f95 100644 --- a/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/test/kotlin/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -6,7 +6,6 @@ import kscience.kmath.stat.Distribution import kscience.kmath.stat.Fitting import kscience.kmath.stat.RandomGenerator import kscience.kmath.stat.normal -import kscience.kmath.structures.asBuffer import org.junit.jupiter.api.Test import kotlin.math.pow @@ -53,7 +52,7 @@ internal class OptimizeTest { it.pow(2) + it + 1 + chain.nextDouble() } val yErr = x.map { sigma } - val chi2 = Fitting.chiSquared(x.asBuffer(), y.asBuffer(), yErr.asBuffer()) { x -> + val chi2 = Fitting.chiSquared(x, y, yErr) { x -> val cWithDefault = bindOrNull(c) ?: one bind(a) * x.pow(2) + bind(b) * x + cWithDefault }