forked from kscience/kmath
Compare commits
5 Commits
4f2ebb17ff
...
47a9bf0e9a
Author | SHA1 | Date | |
---|---|---|---|
47a9bf0e9a | |||
b8bf56bc42 | |||
597c27e615 | |||
40ffa8d6fc | |||
1c5113da29 |
@ -11,7 +11,7 @@ import space.kscience.kmath.operations.*
|
|||||||
import space.kscience.kmath.operations.Float64BufferOps.Companion.div
|
import space.kscience.kmath.operations.Float64BufferOps.Companion.div
|
||||||
import space.kscience.kmath.operations.Float64BufferOps.Companion.pow
|
import space.kscience.kmath.operations.Float64BufferOps.Companion.pow
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.last
|
import space.kscience.kmath.structures.asBuffer
|
||||||
import kotlin.math.sign
|
import kotlin.math.sign
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -47,24 +47,40 @@ public class EmpiricalModeDecomposition<BA, L: Number> (
|
|||||||
*/
|
*/
|
||||||
private fun findMean(signal: Series<Double>): Series<Double>? = (seriesAlgebra) {
|
private fun findMean(signal: Series<Double>): Series<Double>? = (seriesAlgebra) {
|
||||||
val interpolator = SplineInterpolator(Float64Field)
|
val interpolator = SplineInterpolator(Float64Field)
|
||||||
fun generateEnvelope(extrema: List<Int>): Series<Double> {
|
fun generateEnvelope(extrema: List<Int>, paddedExtremeValues: DoubleArray): Series<Double> {
|
||||||
val envelopeFunction = interpolator.interpolate(
|
val envelopeFunction = interpolator.interpolate(
|
||||||
Buffer(extrema.size) { signal.labels[extrema[it]].toDouble() },
|
Buffer(extrema.size) { signal.labels[extrema[it]].toDouble() },
|
||||||
Buffer(extrema.size) { signal[extrema[it]] }
|
paddedExtremeValues.asBuffer()
|
||||||
)
|
)
|
||||||
return signal.mapWithLabel { _, label ->
|
return signal.mapWithLabel { _, label ->
|
||||||
// For some reason PolynomialInterpolator is exclusive and the right boundary
|
// For some reason PolynomialInterpolator is exclusive and the right boundary
|
||||||
// TODO Notify interpolator authors
|
// TODO Notify interpolator authors
|
||||||
envelopeFunction(label.toDouble()) ?: signal.last()
|
envelopeFunction(label.toDouble()) ?: paddedExtremeValues.last()
|
||||||
// need to make the interpolator yield values outside boundaries?
|
// need to make the interpolator yield values outside boundaries?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
val maxima = signal.paddedMaxima()
|
// Extrema padding (experimental) TODO padding needs a dedicated function
|
||||||
val minima = signal.paddedMinima()
|
val maxima = listOf(0) + signal.peaks() + (signal.size - 1)
|
||||||
return if (maxima.size < 3 || minima.size < 3) null else {
|
val maxValues = DoubleArray(maxima.size) { signal[maxima[it]] }
|
||||||
val upperEnvelope = generateEnvelope(maxima)
|
if (maxValues[0] < maxValues[1]) {
|
||||||
val lowerEnvelope = generateEnvelope(minima)
|
maxValues[0] = maxValues[1]
|
||||||
return upperEnvelope.zip(lowerEnvelope) { upper, lower -> upper + lower / 2.0 }
|
}
|
||||||
|
if (maxValues.last() < maxValues[maxValues.lastIndex - 1]) {
|
||||||
|
maxValues[maxValues.lastIndex] = maxValues[maxValues.lastIndex - 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
val minima = listOf(0) + signal.troughs() + (signal.size - 1)
|
||||||
|
val minValues = DoubleArray(minima.size) { signal[minima[it]] }
|
||||||
|
if (minValues[0] > minValues[1]) {
|
||||||
|
minValues[0] = minValues[1]
|
||||||
|
}
|
||||||
|
if (minValues.last() > minValues[minValues.lastIndex - 1]) {
|
||||||
|
minValues[minValues.lastIndex] = minValues[minValues.lastIndex - 1]
|
||||||
|
}
|
||||||
|
return if (maxima.size < 3 || minima.size < 3) null else { // maybe make an early return?
|
||||||
|
val upperEnvelope = generateEnvelope(maxima, maxValues)
|
||||||
|
val lowerEnvelope = generateEnvelope(minima, minValues)
|
||||||
|
return (upperEnvelope + lowerEnvelope).map { it * 0.5 }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,27 +92,19 @@ public class EmpiricalModeDecomposition<BA, L: Number> (
|
|||||||
* @return [SiftingResult.NotEnoughExtrema] is returned if the signal has too few extrema to extract a mode.
|
* @return [SiftingResult.NotEnoughExtrema] is returned if the signal has too few extrema to extract a mode.
|
||||||
* Success of an appropriate type (See [SiftingResult.Success] class) is returned otherwise.
|
* Success of an appropriate type (See [SiftingResult.Success] class) is returned otherwise.
|
||||||
*/
|
*/
|
||||||
private fun sift(signal: Series<Double>): SiftingResult = (seriesAlgebra) {
|
private fun sift(signal: Series<Double>): SiftingResult = siftInner(signal, 1, 0)
|
||||||
val mean = findMean(signal) ?: return SiftingResult.NotEnoughExtrema
|
|
||||||
val protoMode = signal.zip(mean) { s, m -> s - m }
|
|
||||||
val sNumber = if (protoMode.sCondition()) 1 else 0
|
|
||||||
return when {
|
|
||||||
maxSiftIterations == 1 -> SiftingResult.MaxIterationsReached(protoMode)
|
|
||||||
sNumber >= sConditionThreshold -> SiftingResult.SNumberReached(protoMode)
|
|
||||||
relativeDifference(protoMode, signal) < siftingDelta * signal.size -> SiftingResult.DeltaReached(protoMode)
|
|
||||||
else -> siftInner(protoMode, 2, sNumber)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute a single iteration of the sifting process.
|
* Compute a single iteration of the sifting process.
|
||||||
*/
|
*/
|
||||||
private fun siftInner(
|
private tailrec fun siftInner(
|
||||||
prevMode: Series<Double>,
|
prevMode: Series<Double>,
|
||||||
iterationNumber: Int,
|
iterationNumber: Int,
|
||||||
sNumber: Int
|
sNumber: Int
|
||||||
): SiftingResult = (seriesAlgebra) {
|
): SiftingResult = (seriesAlgebra) {
|
||||||
val mean = findMean(prevMode) ?: return SiftingResult.SignalFlattened(prevMode)
|
val mean = findMean(prevMode) ?:
|
||||||
|
return if (iterationNumber == 1) SiftingResult.NotEnoughExtrema
|
||||||
|
else SiftingResult.SignalFlattened(prevMode)
|
||||||
val mode = prevMode.zip(mean) { p, m -> p - m }
|
val mode = prevMode.zip(mean) { p, m -> p - m }
|
||||||
val newSNumber = if (mode.sCondition()) sNumber + 1 else sNumber
|
val newSNumber = if (mode.sCondition()) sNumber + 1 else sNumber
|
||||||
return when {
|
return when {
|
||||||
@ -188,49 +196,14 @@ private fun <BA> SeriesAlgebra<Double, *, BA, *>.relativeDifference(
|
|||||||
.div(previous pow 2)
|
.div(previous pow 2)
|
||||||
.fold(0.0) { acc, d -> acc + d } // TODO replace with Series<>.sum() method when it's implemented
|
.fold(0.0) { acc, d -> acc + d } // TODO replace with Series<>.sum() method when it's implemented
|
||||||
|
|
||||||
private fun <T: Comparable<T>> isExtreme(prev: T, elem: T, next: T): Boolean =
|
|
||||||
(elem > prev && elem > next) || (elem < prev && elem < next)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Brute force count all extrema of a series.
|
* Brute force count all extrema of a series.
|
||||||
*/
|
*/
|
||||||
private fun Series<Double>.countExtrema(): Int {
|
private fun Series<Double>.countExtrema(): Int {
|
||||||
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
|
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
|
||||||
return (1 .. size - 2).count { isExtreme(this[it - 1], this[it], this[it + 1]) }
|
return peaks().size + troughs().size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Retrieve indices of knot points for spline interpolation matching the predicate.
|
|
||||||
* The first and the last points of a series are always included.
|
|
||||||
*/
|
|
||||||
private fun <T: Comparable<T>> Series<T>.knotPoints(predicate: (T, T, T) -> Boolean): List<Int> {
|
|
||||||
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
|
|
||||||
val points = mutableListOf(0)
|
|
||||||
for (index in 1 .. size - 2) {
|
|
||||||
val left = this[index - 1]
|
|
||||||
val middle = this[index]
|
|
||||||
val right = this[index + 1]
|
|
||||||
if (predicate(left, middle, right)) points.add(index)
|
|
||||||
}
|
|
||||||
points.add(size - 1)
|
|
||||||
return points
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Retrieve indices of knot points used to construct an upper envelope,
|
|
||||||
* namely maxima together with the first last point in a series.
|
|
||||||
*/
|
|
||||||
private fun <T: Comparable<T>> Series<T>.paddedMaxima(): List<Int> =
|
|
||||||
knotPoints { left, middle, right -> (middle > left && middle > right) }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Retrieve indices of knot points used to construct a lower envelope,
|
|
||||||
* namely minima together with the first last point in a series.
|
|
||||||
*/
|
|
||||||
private fun <T: Comparable<T>> Series<T>.paddedMinima(): List<Int> =
|
|
||||||
knotPoints { left, middle, right -> (middle < left && middle < right) }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check whether the numbers of zeroes and extrema of a series differ by no more than 1.
|
* Check whether the numbers of zeroes and extrema of a series differ by no more than 1.
|
||||||
* This is a necessary condition of an empirical mode.
|
* This is a necessary condition of an empirical mode.
|
||||||
|
@ -0,0 +1,89 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2024 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.series
|
||||||
|
|
||||||
|
|
||||||
|
public enum class PlateauEdgePolicy {
|
||||||
|
/**
|
||||||
|
* Midpoints of plateau are returned, edges not belonging to a plateau are ignored.
|
||||||
|
*
|
||||||
|
* A midpoint is the index closest to the average of indices of the left and right edges
|
||||||
|
* of the plateau:
|
||||||
|
*
|
||||||
|
* `val midpoint = ((leftEdge + rightEdge) / 2).toInt`
|
||||||
|
*/
|
||||||
|
AVERAGE,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Both left and right edges are returned.
|
||||||
|
*/
|
||||||
|
KEEP_ALL_EDGES,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Only right edges are returned.
|
||||||
|
*/
|
||||||
|
KEEP_RIGHT_EDGES,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Only left edges are returned.
|
||||||
|
*/
|
||||||
|
KEEP_LEFT_EDGES,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Ignore plateau, only peaks (troughs) with values strictly greater (less)
|
||||||
|
* than values of the adjacent points are returned.
|
||||||
|
*/
|
||||||
|
IGNORE
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public fun <T: Comparable<T>> Series<T>.peaks(
|
||||||
|
plateauEdgePolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE
|
||||||
|
): List<Int> = findPeaks(plateauEdgePolicy, { other -> this > other }, { other -> this >= other })
|
||||||
|
|
||||||
|
public fun <T: Comparable<T>> Series<T>.troughs(
|
||||||
|
plateauEdgePolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE
|
||||||
|
): List<Int> = findPeaks(plateauEdgePolicy, { other -> this < other }, { other -> this <= other })
|
||||||
|
|
||||||
|
|
||||||
|
private fun <T: Comparable<T>> Series<T>.findPeaks(
|
||||||
|
plateauPolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE,
|
||||||
|
cmpStrong: T.(T) -> Boolean,
|
||||||
|
cmpWeak: T.(T) -> Boolean
|
||||||
|
): List<Int> {
|
||||||
|
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
|
||||||
|
if (plateauPolicy == PlateauEdgePolicy.AVERAGE) return peaksWithPlateau(cmpStrong)
|
||||||
|
fun peakCriterion(left: T, middle: T, right: T): Boolean = when(plateauPolicy) {
|
||||||
|
PlateauEdgePolicy.KEEP_LEFT_EDGES -> middle.cmpStrong(left) && middle.cmpWeak(right)
|
||||||
|
PlateauEdgePolicy.KEEP_RIGHT_EDGES -> middle.cmpWeak(left) && middle.cmpStrong(right)
|
||||||
|
PlateauEdgePolicy.KEEP_ALL_EDGES ->
|
||||||
|
(middle.cmpStrong(left) && middle.cmpWeak(right)) || (middle.cmpWeak(left) && middle.cmpStrong(right))
|
||||||
|
else -> middle.cmpStrong(right) && middle.cmpStrong(left)
|
||||||
|
}
|
||||||
|
val indices = mutableListOf<Int>()
|
||||||
|
for (index in 1 .. size - 2) {
|
||||||
|
val left = this[index - 1]
|
||||||
|
val middle = this[index]
|
||||||
|
val right = this[index + 1]
|
||||||
|
if (peakCriterion(left, middle, right)) indices.add(index)
|
||||||
|
}
|
||||||
|
return indices
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private fun <T: Comparable<T>> Series<T>.peaksWithPlateau(cmpStrong: T.(T) -> Boolean): List<Int> {
|
||||||
|
val peaks = mutableListOf<Int>()
|
||||||
|
tailrec fun peaksPlateauInner(index: Int) {
|
||||||
|
val nextUnequal = (index + 1 ..< size).firstOrNull { this[it] != this[index] } ?: (size - 1)
|
||||||
|
val newIndex = if (this[index].cmpStrong(this[index - 1]) && this[index].cmpStrong(this[nextUnequal])) {
|
||||||
|
peaks.add((index + nextUnequal) / 2)
|
||||||
|
nextUnequal
|
||||||
|
} else index + 1
|
||||||
|
if (newIndex < size - 1) peaksPlateauInner(newIndex)
|
||||||
|
}
|
||||||
|
peaksPlateauInner(1)
|
||||||
|
return peaks
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user