STUD-13. Improve numerical stability of mean statistic algorithm

Replace naive summation that prone to floating-point errors (loss of precision) by Welford’s Online Algorithm which updates mean incrementally and more numerically stable.
 The cons is slightly higher computational overhead.
This commit is contained in:
qwazer 2025-05-05 15:34:49 +03:00
parent 288ec467e6
commit d43ce15b99
3 changed files with 45 additions and 11 deletions
kmath-stat/src
commonMain/kotlin/space/kscience/kmath/stat
commonTest/kotlin/space/kscience/kmath/stat

@ -9,18 +9,28 @@ import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.*
/**
* Arithmetic mean
* Arithmetic mean.
*
* To limit numeric errors, the value of the statistic is computed using the recursive updating algorithm:
*
* Initialize `m` = the first value. For each additional value, update using
* ```
* m = m + (new value - m) / (number of observations)
* ```
*
*/
public class Mean<T>(
private val field: Field<T>,
) : ComposableStatistic<T, Pair<T, Int>, T>, BlockingStatistic<T, T> {
override fun evaluateBlocking(data: Buffer<T>): T = with(field) {
var res = zero
var mean = zero
var delta: T
for (i in data.indices) {
res += data[i]
delta = data[i] - mean
mean += delta / (i + 1)
}
res / data.size
return mean
}
override suspend fun evaluate(data: Buffer<T>): T = super<ComposableStatistic>.evaluate(data)
@ -47,8 +57,6 @@ public class Mean<T>(
}
}
//TODO replace with optimized version which respects overflow
public val Float64Field.mean: Mean<Float64> get() = Mean(Float64Field)
public val Int32Ring.mean: Mean<Int> get() = Mean(Int32Field)
public val Int64Ring.mean: Mean<Long> get() = Mean(Int64Field)

@ -0,0 +1,30 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.stat
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.structures.Float64Buffer
import kotlin.test.Test
import kotlin.test.assertEquals
class MeanStatisticBasicTest {
companion object {
val float64Sample = RandomGenerator.default(123).nextDoubleBuffer(100)
}
@Test
fun meanFloat64() {
assertEquals(0.488, Float64Field.mean(float64Sample), 0.0002)
}
@Test
fun numericalStability(){
val data = doubleArrayOf(1.0e4, 1.0, 1.0, 1.0, 1.0);
assertEquals(2000.8, Float64Field.mean.evaluateBlocking(Float64Buffer(data)));
}
}

@ -13,7 +13,7 @@ import kotlin.test.Test
import kotlin.test.assertEquals
@OptIn(UnstableKMathAPI::class)
class TestBasicStatistics {
class MedianStatisticBasicTest {
companion object {
val float64Sample = RandomGenerator.default(123).nextDoubleBuffer(100)
}
@ -24,8 +24,4 @@ class TestBasicStatistics {
assertEquals(0.5055, Float64Field.median(float64Sample.slice { 0..<last }), 0.0005)
}
@Test
fun meanFloat64() {
assertEquals(0.488, Float64Field.mean(float64Sample), 0.0002)
}
}