diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt index 467ea3201..301247698 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt @@ -11,7 +11,7 @@ import space.kscience.kmath.operations.* import space.kscience.kmath.operations.Float64BufferOps.Companion.div import space.kscience.kmath.operations.Float64BufferOps.Companion.pow import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.last +import space.kscience.kmath.structures.asBuffer import kotlin.math.sign /** @@ -47,24 +47,40 @@ public class EmpiricalModeDecomposition ( */ private fun findMean(signal: Series): Series? = (seriesAlgebra) { val interpolator = SplineInterpolator(Float64Field) - fun generateEnvelope(extrema: List): Series { + fun generateEnvelope(extrema: List, paddedExtremeValues: DoubleArray): Series { val envelopeFunction = interpolator.interpolate( Buffer(extrema.size) { signal.labels[extrema[it]].toDouble() }, - Buffer(extrema.size) { signal[extrema[it]] } + paddedExtremeValues.asBuffer() ) return signal.mapWithLabel { _, label -> // For some reason PolynomialInterpolator is exclusive and the right boundary // TODO Notify interpolator authors - envelopeFunction(label.toDouble()) ?: signal.last() + envelopeFunction(label.toDouble()) ?: paddedExtremeValues.last() // need to make the interpolator yield values outside boundaries? } } - val maxima = signal.paddedMaxima() - val minima = signal.paddedMinima() - return if (maxima.size < 3 || minima.size < 3) null else { - val upperEnvelope = generateEnvelope(maxima) - val lowerEnvelope = generateEnvelope(minima) - return upperEnvelope.zip(lowerEnvelope) { upper, lower -> upper + lower / 2.0 } + // Extrema padding (experimental) TODO padding needs a dedicated function + val maxima = listOf(0) + signal.peaks() + (signal.size - 1) + val maxValues = DoubleArray(maxima.size) { signal[maxima[it]] } + if (maxValues[0] < maxValues[1]) { + maxValues[0] = maxValues[1] + } + if (maxValues.last() < maxValues[maxValues.lastIndex - 1]) { + maxValues[maxValues.lastIndex] = maxValues[maxValues.lastIndex - 1] + } + + val minima = listOf(0) + signal.troughs() + (signal.size - 1) + val minValues = DoubleArray(minima.size) { signal[minima[it]] } + if (minValues[0] > minValues[1]) { + minValues[0] = minValues[1] + } + if (minValues.last() > minValues[minValues.lastIndex - 1]) { + minValues[minValues.lastIndex] = minValues[minValues.lastIndex - 1] + } + return if (maxima.size < 3 || minima.size < 3) null else { // maybe make an early return? + val upperEnvelope = generateEnvelope(maxima, maxValues) + val lowerEnvelope = generateEnvelope(minima, minValues) + return (upperEnvelope + lowerEnvelope).map { it * 0.5 } } } @@ -186,43 +202,12 @@ private fun > isExtreme(prev: T, elem: T, next: T): Boolean = /** * Brute force count all extrema of a series. */ +@Deprecated("Does not match the algorithm currently in use.") private fun Series.countExtrema(): Int { require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" } return (1 .. size - 2).count { isExtreme(this[it - 1], this[it], this[it + 1]) } } - -/** - * Retrieve indices of knot points for spline interpolation matching the predicate. - * The first and the last points of a series are always included. - */ -private fun > Series.knotPoints(predicate: (T, T, T) -> Boolean): List { - require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" } - val points = mutableListOf(0) - for (index in 1 .. size - 2) { - val left = this[index - 1] - val middle = this[index] - val right = this[index + 1] - if (predicate(left, middle, right)) points.add(index) - } - points.add(size - 1) - return points -} - -/** - * Retrieve indices of knot points used to construct an upper envelope, - * namely maxima together with the first last point in a series. - */ -private fun > Series.paddedMaxima(): List = - knotPoints { left, middle, right -> (middle > left && middle > right) } - -/** - * Retrieve indices of knot points used to construct a lower envelope, - * namely minima together with the first last point in a series. - */ -private fun > Series.paddedMinima(): List = - knotPoints { left, middle, right -> (middle < left && middle < right) } - /** * Check whether the numbers of zeroes and extrema of a series differ by no more than 1. * This is a necessary condition of an empirical mode.