STUD-13. Improve numerical stability of mean statistic algorithm
Replace the naive summation, which is prone to floating-point errors (loss of precision), with Welford’s online algorithm, which updates the mean incrementally and is more numerically stable. The downside is a slightly higher computational overhead.
This commit is contained in:
parent
288ec467e6
commit
d43ce15b99
kmath-stat/src
commonMain/kotlin/space/kscience/kmath/stat
commonTest/kotlin/space/kscience/kmath/stat
@ -9,18 +9,28 @@ import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.structures.*
|
||||
|
||||
/**
|
||||
* Arithmetic mean
|
||||
* Arithmetic mean.
|
||||
*
|
||||
* To limit numeric errors, the value of the statistic is computed using the recursive updating algorithm:
|
||||
*
|
||||
* Initialize `m` = the first value. For each additional value, update using
|
||||
* ```
|
||||
* m = m + (new value - m) / (number of observations)
|
||||
* ```
|
||||
*
|
||||
*/
|
||||
public class Mean<T>(
|
||||
private val field: Field<T>,
|
||||
) : ComposableStatistic<T, Pair<T, Int>, T>, BlockingStatistic<T, T> {
|
||||
|
||||
/**
 * Computes the arithmetic mean of [data] in a single pass.
 *
 * Instead of building one running sum and dividing at the end, the estimate is
 * refined element by element:
 * ```
 * mean = mean + (value - mean) / (number of observations so far)
 * ```
 *
 * When [data] has no elements the loop body never executes and the initial
 * `zero` is returned.
 *
 * @param data the observations to average.
 * @return the incrementally computed mean of [data].
 */
override fun evaluateBlocking(data: Buffer<T>): T = with(field) {
    var mean = zero
    for (i in data.indices) {
        // `i` is zero-based, so `i + 1` is the count of observations including this one.
        mean += (data[i] - mean) / (i + 1)
    }
    mean
}
|
||||
|
||||
override suspend fun evaluate(data: Buffer<T>): T = super<ComposableStatistic>.evaluate(data)
|
||||
@ -47,8 +57,6 @@ public class Mean<T>(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//TODO replace with optimized version which respects overflow
|
||||
/** [Mean] statistic for [Float64] values, constructed over [Float64Field]. */
public val Float64Field.mean: Mean<Float64>
    get() = Mean(Float64Field)

/**
 * [Mean] statistic for [Int] values, constructed over [Int32Field].
 *
 * NOTE(review): the mean is updated with a division per element; presumably that
 * division is rounded for integer fields, so accuracy should be confirmed.
 */
public val Int32Ring.mean: Mean<Int>
    get() = Mean(Int32Field)

/**
 * [Mean] statistic for [Long] values, constructed over [Int64Field].
 *
 * NOTE(review): the mean is updated with a division per element; presumably that
 * division is rounded for integer fields, so accuracy should be confirmed.
 */
public val Int64Ring.mean: Mean<Long>
    get() = Mean(Int64Field)
|
||||
|
@ -0,0 +1,30 @@
|
||||
/*
|
||||
* Copyright 2018-2024 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.stat
|
||||
|
||||
import space.kscience.kmath.operations.Float64Field
|
||||
import space.kscience.kmath.random.RandomGenerator
|
||||
import space.kscience.kmath.structures.Float64Buffer
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
/**
 * Basic checks of the `mean` statistic on `Float64` data.
 */
class MeanStatisticBasicTest {

    /** The mean of the shared sample must be 0.488 within an absolute tolerance of 0.0002. */
    @Test
    fun meanFloat64() {
        assertEquals(0.488, Float64Field.mean(float64Sample), 0.0002)
    }

    /**
     * One large value followed by four small ones: (10000 + 1 + 1 + 1 + 1) / 5 = 2000.8.
     * The comparison is exact (no tolerance is passed).
     */
    @Test
    fun numericalStability() {
        val data = doubleArrayOf(1.0e4, 1.0, 1.0, 1.0, 1.0)
        assertEquals(2000.8, Float64Field.mean.evaluateBlocking(Float64Buffer(data)))
    }

    companion object {
        // 100 doubles from a generator created with the fixed argument 123
        // (presumably the seed, which would make the sample reproducible).
        val float64Sample = RandomGenerator.default(123).nextDoubleBuffer(100)
    }
}
|
@ -13,7 +13,7 @@ import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
class TestBasicStatistics {
|
||||
class MedianStatisticBasicTest {
|
||||
companion object {
|
||||
val float64Sample = RandomGenerator.default(123).nextDoubleBuffer(100)
|
||||
}
|
||||
@ -24,8 +24,4 @@ class TestBasicStatistics {
|
||||
assertEquals(0.5055, Float64Field.median(float64Sample.slice { 0..<last }), 0.0005)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun meanFloat64() {
|
||||
assertEquals(0.488, Float64Field.mean(float64Sample), 0.0002)
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user