Compare commits

...

5 Commits

2 changed files with 121 additions and 59 deletions

View File

@ -11,7 +11,7 @@ import space.kscience.kmath.operations.*
import space.kscience.kmath.operations.Float64BufferOps.Companion.div import space.kscience.kmath.operations.Float64BufferOps.Companion.div
import space.kscience.kmath.operations.Float64BufferOps.Companion.pow import space.kscience.kmath.operations.Float64BufferOps.Companion.pow
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.last import space.kscience.kmath.structures.asBuffer
import kotlin.math.sign import kotlin.math.sign
/** /**
@ -47,24 +47,40 @@ public class EmpiricalModeDecomposition<BA, L: Number> (
*/ */
private fun findMean(signal: Series<Double>): Series<Double>? = (seriesAlgebra) { private fun findMean(signal: Series<Double>): Series<Double>? = (seriesAlgebra) {
val interpolator = SplineInterpolator(Float64Field) val interpolator = SplineInterpolator(Float64Field)
fun generateEnvelope(extrema: List<Int>): Series<Double> { fun generateEnvelope(extrema: List<Int>, paddedExtremeValues: DoubleArray): Series<Double> {
val envelopeFunction = interpolator.interpolate( val envelopeFunction = interpolator.interpolate(
Buffer(extrema.size) { signal.labels[extrema[it]].toDouble() }, Buffer(extrema.size) { signal.labels[extrema[it]].toDouble() },
Buffer(extrema.size) { signal[extrema[it]] } paddedExtremeValues.asBuffer()
) )
return signal.mapWithLabel { _, label -> return signal.mapWithLabel { _, label ->
// For some reason PolynomialInterpolator is exclusive and the right boundary // For some reason PolynomialInterpolator is exclusive and the right boundary
// TODO Notify interpolator authors // TODO Notify interpolator authors
envelopeFunction(label.toDouble()) ?: signal.last() envelopeFunction(label.toDouble()) ?: paddedExtremeValues.last()
// need to make the interpolator yield values outside boundaries? // need to make the interpolator yield values outside boundaries?
} }
} }
val maxima = signal.paddedMaxima() // Extrema padding (experimental) TODO padding needs a dedicated function
val minima = signal.paddedMinima() val maxima = listOf(0) + signal.peaks() + (signal.size - 1)
return if (maxima.size < 3 || minima.size < 3) null else { val maxValues = DoubleArray(maxima.size) { signal[maxima[it]] }
val upperEnvelope = generateEnvelope(maxima) if (maxValues[0] < maxValues[1]) {
val lowerEnvelope = generateEnvelope(minima) maxValues[0] = maxValues[1]
return upperEnvelope.zip(lowerEnvelope) { upper, lower -> upper + lower / 2.0 } }
if (maxValues.last() < maxValues[maxValues.lastIndex - 1]) {
maxValues[maxValues.lastIndex] = maxValues[maxValues.lastIndex - 1]
}
val minima = listOf(0) + signal.troughs() + (signal.size - 1)
val minValues = DoubleArray(minima.size) { signal[minima[it]] }
if (minValues[0] > minValues[1]) {
minValues[0] = minValues[1]
}
if (minValues.last() > minValues[minValues.lastIndex - 1]) {
minValues[minValues.lastIndex] = minValues[minValues.lastIndex - 1]
}
return if (maxima.size < 3 || minima.size < 3) null else { // maybe make an early return?
val upperEnvelope = generateEnvelope(maxima, maxValues)
val lowerEnvelope = generateEnvelope(minima, minValues)
return (upperEnvelope + lowerEnvelope).map { it * 0.5 }
} }
} }
@ -76,27 +92,19 @@ public class EmpiricalModeDecomposition<BA, L: Number> (
* @return [SiftingResult.NotEnoughExtrema] is returned if the signal has too few extrema to extract a mode. * @return [SiftingResult.NotEnoughExtrema] is returned if the signal has too few extrema to extract a mode.
* Success of an appropriate type (See [SiftingResult.Success] class) is returned otherwise. * Success of an appropriate type (See [SiftingResult.Success] class) is returned otherwise.
*/ */
private fun sift(signal: Series<Double>): SiftingResult = (seriesAlgebra) { private fun sift(signal: Series<Double>): SiftingResult = siftInner(signal, 1, 0)
val mean = findMean(signal) ?: return SiftingResult.NotEnoughExtrema
val protoMode = signal.zip(mean) { s, m -> s - m }
val sNumber = if (protoMode.sCondition()) 1 else 0
return when {
maxSiftIterations == 1 -> SiftingResult.MaxIterationsReached(protoMode)
sNumber >= sConditionThreshold -> SiftingResult.SNumberReached(protoMode)
relativeDifference(protoMode, signal) < siftingDelta * signal.size -> SiftingResult.DeltaReached(protoMode)
else -> siftInner(protoMode, 2, sNumber)
}
}
/** /**
* Compute a single iteration of the sifting process. * Compute a single iteration of the sifting process.
*/ */
private fun siftInner( private tailrec fun siftInner(
prevMode: Series<Double>, prevMode: Series<Double>,
iterationNumber: Int, iterationNumber: Int,
sNumber: Int sNumber: Int
): SiftingResult = (seriesAlgebra) { ): SiftingResult = (seriesAlgebra) {
val mean = findMean(prevMode) ?: return SiftingResult.SignalFlattened(prevMode) val mean = findMean(prevMode) ?:
return if (iterationNumber == 1) SiftingResult.NotEnoughExtrema
else SiftingResult.SignalFlattened(prevMode)
val mode = prevMode.zip(mean) { p, m -> p - m } val mode = prevMode.zip(mean) { p, m -> p - m }
val newSNumber = if (mode.sCondition()) sNumber + 1 else sNumber val newSNumber = if (mode.sCondition()) sNumber + 1 else sNumber
return when { return when {
@ -188,49 +196,14 @@ private fun <BA> SeriesAlgebra<Double, *, BA, *>.relativeDifference(
.div(previous pow 2) .div(previous pow 2)
.fold(0.0) { acc, d -> acc + d } // TODO replace with Series<>.sum() method when it's implemented .fold(0.0) { acc, d -> acc + d } // TODO replace with Series<>.sum() method when it's implemented
private fun <T: Comparable<T>> isExtreme(prev: T, elem: T, next: T): Boolean =
(elem > prev && elem > next) || (elem < prev && elem < next)
/** /**
* Brute force count all extrema of a series. * Brute force count all extrema of a series.
*/ */
private fun Series<Double>.countExtrema(): Int { private fun Series<Double>.countExtrema(): Int {
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" } require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
return (1 .. size - 2).count { isExtreme(this[it - 1], this[it], this[it + 1]) } return peaks().size + troughs().size
} }
/**
* Retrieve indices of knot points for spline interpolation matching the predicate.
* The first and the last points of a series are always included.
*/
private fun <T: Comparable<T>> Series<T>.knotPoints(predicate: (T, T, T) -> Boolean): List<Int> {
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
val points = mutableListOf(0)
for (index in 1 .. size - 2) {
val left = this[index - 1]
val middle = this[index]
val right = this[index + 1]
if (predicate(left, middle, right)) points.add(index)
}
points.add(size - 1)
return points
}
/**
* Retrieve indices of knot points used to construct an upper envelope,
* namely maxima together with the first last point in a series.
*/
private fun <T: Comparable<T>> Series<T>.paddedMaxima(): List<Int> =
knotPoints { left, middle, right -> (middle > left && middle > right) }
/**
* Retrieve indices of knot points used to construct a lower envelope,
* namely minima together with the first last point in a series.
*/
private fun <T: Comparable<T>> Series<T>.paddedMinima(): List<Int> =
knotPoints { left, middle, right -> (middle < left && middle < right) }
/** /**
* Check whether the numbers of zeroes and extrema of a series differ by no more than 1. * Check whether the numbers of zeroes and extrema of a series differ by no more than 1.
* This is a necessary condition of an empirical mode. * This is a necessary condition of an empirical mode.

View File

@ -0,0 +1,89 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.series
public enum class PlateauEdgePolicy {
/**
* Midpoints of plateau are returned, edges not belonging to a plateau are ignored.
*
* A midpoint is the index closest to the average of indices of the left and right edges
* of the plateau:
*
* `val midpoint = ((leftEdge + rightEdge) / 2).toInt`
*/
AVERAGE,
/**
* Both left and right edges are returned.
*/
KEEP_ALL_EDGES,
/**
* Only right edges are returned.
*/
KEEP_RIGHT_EDGES,
/**
* Only left edges are returned.
*/
KEEP_LEFT_EDGES,
/**
* Ignore plateau, only peaks (troughs) with values strictly greater (less)
* than values of the adjacent points are returned.
*/
IGNORE
}
public fun <T: Comparable<T>> Series<T>.peaks(
plateauEdgePolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE
): List<Int> = findPeaks(plateauEdgePolicy, { other -> this > other }, { other -> this >= other })
public fun <T: Comparable<T>> Series<T>.troughs(
plateauEdgePolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE
): List<Int> = findPeaks(plateauEdgePolicy, { other -> this < other }, { other -> this <= other })
private fun <T: Comparable<T>> Series<T>.findPeaks(
plateauPolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE,
cmpStrong: T.(T) -> Boolean,
cmpWeak: T.(T) -> Boolean
): List<Int> {
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
if (plateauPolicy == PlateauEdgePolicy.AVERAGE) return peaksWithPlateau(cmpStrong)
fun peakCriterion(left: T, middle: T, right: T): Boolean = when(plateauPolicy) {
PlateauEdgePolicy.KEEP_LEFT_EDGES -> middle.cmpStrong(left) && middle.cmpWeak(right)
PlateauEdgePolicy.KEEP_RIGHT_EDGES -> middle.cmpWeak(left) && middle.cmpStrong(right)
PlateauEdgePolicy.KEEP_ALL_EDGES ->
(middle.cmpStrong(left) && middle.cmpWeak(right)) || (middle.cmpWeak(left) && middle.cmpStrong(right))
else -> middle.cmpStrong(right) && middle.cmpStrong(left)
}
val indices = mutableListOf<Int>()
for (index in 1 .. size - 2) {
val left = this[index - 1]
val middle = this[index]
val right = this[index + 1]
if (peakCriterion(left, middle, right)) indices.add(index)
}
return indices
}
private fun <T: Comparable<T>> Series<T>.peaksWithPlateau(cmpStrong: T.(T) -> Boolean): List<Int> {
val peaks = mutableListOf<Int>()
tailrec fun peaksPlateauInner(index: Int) {
val nextUnequal = (index + 1 ..< size).firstOrNull { this[it] != this[index] } ?: (size - 1)
val newIndex = if (this[index].cmpStrong(this[index - 1]) && this[index].cmpStrong(this[nextUnequal])) {
peaks.add((index + nextUnequal) / 2)
nextUnequal
} else index + 1
if (newIndex < size - 1) peaksPlateauInner(newIndex)
}
peaksPlateauInner(1)
return peaks
}