Improved peak finding algorithm
This commit is contained in:
parent
40ffa8d6fc
commit
597c27e615
@ -0,0 +1,89 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2024 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.series
|
||||||
|
|
||||||
|
|
||||||
|
public enum class PlateauEdgePolicy {
|
||||||
|
/**
|
||||||
|
* Midpoints of plateau are returned, edges not belonging to a plateau are ignored.
|
||||||
|
*
|
||||||
|
* A midpoint is the index closest to the average of indices of the left and right edges
|
||||||
|
* of the plateau:
|
||||||
|
*
|
||||||
|
* `val midpoint = ((leftEdge + rightEdge) / 2).toInt`
|
||||||
|
*/
|
||||||
|
AVERAGE,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Both left and right edges are returned.
|
||||||
|
*/
|
||||||
|
KEEP_ALL_EDGES,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Only right edges are returned.
|
||||||
|
*/
|
||||||
|
KEEP_RIGHT_EDGES,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Only left edges are returned.
|
||||||
|
*/
|
||||||
|
KEEP_LEFT_EDGES,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Ignore plateau, only peaks (troughs) with values strictly greater (less)
|
||||||
|
* than values of the adjacent points are returned.
|
||||||
|
*/
|
||||||
|
IGNORE
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public fun <T: Comparable<T>> Series<T>.peaks(
|
||||||
|
plateauEdgePolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE
|
||||||
|
): List<Int> = findPeaks(plateauEdgePolicy, { other -> this > other }, { other -> this >= other })
|
||||||
|
|
||||||
|
public fun <T: Comparable<T>> Series<T>.troughs(
|
||||||
|
plateauEdgePolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE
|
||||||
|
): List<Int> = findPeaks(plateauEdgePolicy, { other -> this < other }, { other -> this <= other })
|
||||||
|
|
||||||
|
|
||||||
|
private fun <T: Comparable<T>> Series<T>.findPeaks(
|
||||||
|
plateauPolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE,
|
||||||
|
cmpStrong: T.(T) -> Boolean,
|
||||||
|
cmpWeak: T.(T) -> Boolean
|
||||||
|
): List<Int> {
|
||||||
|
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
|
||||||
|
if (plateauPolicy == PlateauEdgePolicy.AVERAGE) return peaksWithPlateau(cmpStrong)
|
||||||
|
fun peakCriterion(left: T, middle: T, right: T): Boolean = when(plateauPolicy) {
|
||||||
|
PlateauEdgePolicy.KEEP_LEFT_EDGES -> middle.cmpStrong(left) && middle.cmpWeak(right)
|
||||||
|
PlateauEdgePolicy.KEEP_RIGHT_EDGES -> middle.cmpWeak(left) && middle.cmpStrong(right)
|
||||||
|
PlateauEdgePolicy.KEEP_ALL_EDGES ->
|
||||||
|
(middle.cmpStrong(left) && middle.cmpWeak(right)) || (middle.cmpWeak(left) && middle.cmpStrong(right))
|
||||||
|
else -> middle.cmpStrong(right) && middle.cmpStrong(left)
|
||||||
|
}
|
||||||
|
val indices = mutableListOf<Int>()
|
||||||
|
for (index in 1 .. size - 2) {
|
||||||
|
val left = this[index - 1]
|
||||||
|
val middle = this[index]
|
||||||
|
val right = this[index + 1]
|
||||||
|
if (peakCriterion(left, middle, right)) indices.add(index)
|
||||||
|
}
|
||||||
|
return indices
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private fun <T: Comparable<T>> Series<T>.peaksWithPlateau(cmpStrong: T.(T) -> Boolean): List<Int> {
|
||||||
|
val peaks = mutableListOf<Int>()
|
||||||
|
tailrec fun peaksPlateauInner(index: Int) {
|
||||||
|
val nextUnequal = (index + 1 ..< size).firstOrNull { this[it] != this[index] } ?: (size - 1)
|
||||||
|
val newIndex = if (this[index].cmpStrong(this[index - 1]) && this[index].cmpStrong(this[nextUnequal])) {
|
||||||
|
peaks.add((index + nextUnequal) / 2)
|
||||||
|
nextUnequal
|
||||||
|
} else index + 1
|
||||||
|
if (newIndex < size - 1) peaksPlateauInner(newIndex)
|
||||||
|
}
|
||||||
|
peaksPlateauInner(1)
|
||||||
|
return peaks
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user