forked from kscience/kmath
Part 2
This commit is contained in:
parent
2a2c5e8765
commit
d2c8423b6f
@ -9,8 +9,7 @@ import space.kscience.kmath.interpolation.SplineInterpolator
|
|||||||
import space.kscience.kmath.interpolation.interpolate
|
import space.kscience.kmath.interpolation.interpolate
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.asBuffer
|
import space.kscience.kmath.structures.last
|
||||||
import kotlin.math.sign
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Empirical mode decomposition of a signal represented as a [Series].
|
* Empirical mode decomposition of a signal represented as a [Series].
|
||||||
@ -45,35 +44,36 @@ public class EmpiricalModeDecomposition<A: Field<Double>, BA, L: Number> (
|
|||||||
*/
|
*/
|
||||||
private fun findMean(signal: Series<Double>): Series<Double>? = (seriesAlgebra) {
|
private fun findMean(signal: Series<Double>): Series<Double>? = (seriesAlgebra) {
|
||||||
val interpolator = SplineInterpolator(elementAlgebra)
|
val interpolator = SplineInterpolator(elementAlgebra)
|
||||||
fun generateEnvelope(extrema: List<Int>, paddedExtremeValues: Array<Double>): Series<Double> {
|
val makeBuffer = elementAlgebra.bufferFactory
|
||||||
|
fun generateEnvelope(extrema: List<Int>, paddedExtremeValues: Buffer<Double>): Series<Double> {
|
||||||
val envelopeFunction = interpolator.interpolate(
|
val envelopeFunction = interpolator.interpolate(
|
||||||
Buffer(extrema.size) { signal.labels[extrema[it]] as Double },
|
Buffer(extrema.size) { signal.labels[extrema[it]].toDouble() },
|
||||||
paddedExtremeValues.asBuffer()
|
paddedExtremeValues
|
||||||
)
|
)
|
||||||
return signal.mapWithLabel { _, label ->
|
return signal.mapWithLabel { _, label ->
|
||||||
// For some reason PolynomialInterpolator is exclusive and the right boundary
|
// For some reason PolynomialInterpolator is exclusive and the right boundary
|
||||||
// TODO Notify interpolator authors
|
// TODO Notify interpolator authors
|
||||||
envelopeFunction(label as Double) ?: paddedExtremeValues.last()
|
envelopeFunction(label.toDouble()) ?: paddedExtremeValues.last()
|
||||||
// need to make the interpolator yield values outside boundaries?
|
// need to make the interpolator yield values outside boundaries?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Extrema padding (experimental) TODO padding needs a dedicated function
|
// Extrema padding (experimental) TODO padding needs a dedicated function
|
||||||
val maxima = listOf(0) + signal.peaks() + (signal.size - 1)
|
val maxima = listOf(0) + signal.peaks() + (signal.size - 1)
|
||||||
val maxValues = Array(maxima.size) { signal[maxima[it]] }
|
val maxValues = makeBuffer(maxima.size) { signal[maxima[it]] }
|
||||||
if (maxValues[0] < maxValues[1]) {
|
if (maxValues[0] < maxValues[1]) {
|
||||||
maxValues[0] = maxValues[1]
|
maxValues[0] = maxValues[1]
|
||||||
}
|
}
|
||||||
if (maxValues.last() < maxValues[maxValues.lastIndex - 1]) {
|
if (maxValues.last() < maxValues[maxValues.size - 2]) {
|
||||||
maxValues[maxValues.lastIndex] = maxValues[maxValues.lastIndex - 1]
|
maxValues[maxValues.size - 1] = maxValues[maxValues.size - 2]
|
||||||
}
|
}
|
||||||
|
|
||||||
val minima = listOf(0) + signal.troughs() + (signal.size - 1)
|
val minima = listOf(0) + signal.troughs() + (signal.size - 1)
|
||||||
val minValues = Array(minima.size) { signal[minima[it]] }
|
val minValues = makeBuffer(minima.size) { signal[minima[it]] }
|
||||||
if (minValues[0] > minValues[1]) {
|
if (minValues[0] > minValues[1]) {
|
||||||
minValues[0] = minValues[1]
|
minValues[0] = minValues[1]
|
||||||
}
|
}
|
||||||
if (minValues.last() > minValues[minValues.lastIndex - 1]) {
|
if (minValues.last() > minValues[minValues.size - 2]) {
|
||||||
minValues[minValues.lastIndex] = minValues[minValues.lastIndex - 1]
|
minValues[minValues.size - 1] = minValues[minValues.size - 2]
|
||||||
}
|
}
|
||||||
return if (maxima.size < 3 || minima.size < 3) null else { // maybe make an early return?
|
return if (maxima.size < 3 || minima.size < 3) null else { // maybe make an early return?
|
||||||
val upperEnvelope = generateEnvelope(maxima, maxValues)
|
val upperEnvelope = generateEnvelope(maxima, maxValues)
|
||||||
@ -104,7 +104,7 @@ public class EmpiricalModeDecomposition<A: Field<Double>, BA, L: Number> (
|
|||||||
return if (iterationNumber == 1) SiftingResult.NotEnoughExtrema
|
return if (iterationNumber == 1) SiftingResult.NotEnoughExtrema
|
||||||
else SiftingResult.SignalFlattened(prevMode)
|
else SiftingResult.SignalFlattened(prevMode)
|
||||||
val mode = prevMode.zip(mean) { p, m -> p - m }
|
val mode = prevMode.zip(mean) { p, m -> p - m }
|
||||||
val newSNumber = if (mode.sCondition()) sNumber + 1 else sNumber
|
val newSNumber = if (sCondition(mode)) sNumber + 1 else sNumber
|
||||||
return when {
|
return when {
|
||||||
iterationNumber >= maxSiftIterations -> SiftingResult.MaxIterationsReached(mode)
|
iterationNumber >= maxSiftIterations -> SiftingResult.MaxIterationsReached(mode)
|
||||||
sNumber >= sConditionThreshold -> SiftingResult.SNumberReached(mode)
|
sNumber >= sConditionThreshold -> SiftingResult.SNumberReached(mode)
|
||||||
@ -121,7 +121,7 @@ public class EmpiricalModeDecomposition<A: Field<Double>, BA, L: Number> (
|
|||||||
* Modes returned in a list which contains as many modes as it was possible
|
* Modes returned in a list which contains as many modes as it was possible
|
||||||
* to extract before triggering one of the termination conditions.
|
* to extract before triggering one of the termination conditions.
|
||||||
*/
|
*/
|
||||||
public fun decompose(signal: Series<Double>): EMDecompositionResult = (seriesAlgebra) {
|
public fun decompose(signal: Series<Double>): EMDecompositionResult<Double> = (seriesAlgebra) {
|
||||||
val modes = mutableListOf<Series<Double>>()
|
val modes = mutableListOf<Series<Double>>()
|
||||||
var residual = signal
|
var residual = signal
|
||||||
repeat(nModes) {
|
repeat(nModes) {
|
||||||
@ -132,9 +132,9 @@ public class EmpiricalModeDecomposition<A: Field<Double>, BA, L: Number> (
|
|||||||
else EMDTerminationReason.ALL_POSSIBLE_MODES_EXTRACTED,
|
else EMDTerminationReason.ALL_POSSIBLE_MODES_EXTRACTED,
|
||||||
modes
|
modes
|
||||||
)
|
)
|
||||||
is SiftingResult.Success -> r.result
|
is SiftingResult.Success<*> -> r.result
|
||||||
}
|
}
|
||||||
modes.add(nextMode)
|
modes.add(nextMode as Series<Double>) // TODO remove unchecked cast
|
||||||
residual = residual.zip(nextMode) { l, r -> l - r }
|
residual = residual.zip(nextMode) { l, r -> l - r }
|
||||||
}
|
}
|
||||||
return EMDecompositionResult(EMDTerminationReason.MAX_MODES_REACHED, modes)
|
return EMDecompositionResult(EMDTerminationReason.MAX_MODES_REACHED, modes)
|
||||||
@ -172,12 +172,15 @@ where BA: BufferAlgebra<Double, A>, BA: FieldOps<Buffer<Double>> = EmpiricalMode
|
|||||||
/**
|
/**
|
||||||
* Brute force count all zeros in the series.
|
* Brute force count all zeros in the series.
|
||||||
*/
|
*/
|
||||||
private fun Series<Double>.countZeros(): Int {
|
private fun <T: Comparable<T>, A: Ring<T>, BA> SeriesAlgebra<T, A, BA, *>.countZeros(
|
||||||
require(size >= 2) { "Expected series with at least 2 elements, but got $size elements" }
|
signal: Series<T>
|
||||||
data class SignCounter(val prevSign: Double, val zeroCount: Int)
|
): Int where BA: BufferAlgebra<T, A>, BA: FieldOps<Buffer<T>> {
|
||||||
|
require(signal.size >= 2) { "Expected series with at least 2 elements, but got ${signal.size} elements" }
|
||||||
|
data class SignCounter(val prevSign: Int, val zeroCount: Int)
|
||||||
|
fun strictSign(arg: T): Int = if (arg > elementAlgebra.zero) 1 else -1
|
||||||
|
|
||||||
return fold(SignCounter(sign(get(0)), 0)) { acc: SignCounter, it: Double ->
|
return signal.fold(SignCounter(strictSign(signal[0]), 0)) { acc, it ->
|
||||||
val currentSign = sign(it)
|
val currentSign = strictSign(it)
|
||||||
if (acc.prevSign != currentSign) SignCounter(currentSign, acc.zeroCount + 1)
|
if (acc.prevSign != currentSign) SignCounter(currentSign, acc.zeroCount + 1)
|
||||||
else SignCounter(currentSign, acc.zeroCount)
|
else SignCounter(currentSign, acc.zeroCount)
|
||||||
}.zeroCount
|
}.zeroCount
|
||||||
@ -198,7 +201,7 @@ private fun <T, A: Ring<T>, BA> SeriesAlgebra<T, A, BA, *>.relativeDifference(
|
|||||||
/**
|
/**
|
||||||
* Brute force count all extrema of a series.
|
* Brute force count all extrema of a series.
|
||||||
*/
|
*/
|
||||||
private fun Series<Double>.countExtrema(): Int {
|
private fun <T: Comparable<T>> Series<T>.countExtrema(): Int {
|
||||||
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
|
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
|
||||||
return peaks().size + troughs().size
|
return peaks().size + troughs().size
|
||||||
}
|
}
|
||||||
@ -207,7 +210,10 @@ private fun Series<Double>.countExtrema(): Int {
|
|||||||
* Check whether the numbers of zeroes and extrema of a series differ by no more than 1.
|
* Check whether the numbers of zeroes and extrema of a series differ by no more than 1.
|
||||||
* This is a necessary condition of an empirical mode.
|
* This is a necessary condition of an empirical mode.
|
||||||
*/
|
*/
|
||||||
private fun Series<Double>.sCondition(): Boolean = (countExtrema() - countZeros()) in -1..1
|
private fun <T: Comparable<T>, A: Ring<T>, BA> SeriesAlgebra<T, A, BA, *>.sCondition(
|
||||||
|
signal: Series<T>
|
||||||
|
): Boolean where BA: BufferAlgebra<T, A>, BA: FieldOps<Buffer<T>> =
|
||||||
|
(signal.countExtrema() - countZeros(signal)) in -1..1
|
||||||
|
|
||||||
internal sealed interface SiftingResult {
|
internal sealed interface SiftingResult {
|
||||||
|
|
||||||
@ -215,33 +221,33 @@ internal sealed interface SiftingResult {
|
|||||||
* Represents a condition when a mode has been successfully
|
* Represents a condition when a mode has been successfully
|
||||||
* extracted in a sifting process.
|
* extracted in a sifting process.
|
||||||
*/
|
*/
|
||||||
open class Success(val result: Series<Double>): SiftingResult
|
open class Success<T>(val result: Series<T>): SiftingResult
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returned when no termination condition was reached and the proto-mode
|
* Returned when no termination condition was reached and the proto-mode
|
||||||
* has become too flat (with not enough extrema to build envelopes)
|
* has become too flat (with not enough extrema to build envelopes)
|
||||||
* after several sifting iterations.
|
* after several sifting iterations.
|
||||||
*/
|
*/
|
||||||
class SignalFlattened(result: Series<Double>) : Success(result)
|
class SignalFlattened<T>(result: Series<T>) : Success<T>(result)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returned when sifting process has been terminated due to the
|
* Returned when sifting process has been terminated due to the
|
||||||
* S-number condition being reached.
|
* S-number condition being reached.
|
||||||
*/
|
*/
|
||||||
class SNumberReached(result: Series<Double>) : Success(result)
|
class SNumberReached<T>(result: Series<T>) : Success<T>(result)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returned when sifting process has been terminated due to the
|
* Returned when sifting process has been terminated due to the
|
||||||
* delta condition (Cauchy criterion) being reached.
|
* delta condition (Cauchy criterion) being reached.
|
||||||
*/
|
*/
|
||||||
class DeltaReached(result: Series<Double>) : Success(result)
|
class DeltaReached<T>(result: Series<T>) : Success<T>(result)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returned when sifting process has been terminated after
|
* Returned when sifting process has been terminated after
|
||||||
* executing the maximum number of iterations (specified when creating an instance
|
* executing the maximum number of iterations (specified when creating an instance
|
||||||
* of [EmpiricalModeDecomposition]).
|
* of [EmpiricalModeDecomposition]).
|
||||||
*/
|
*/
|
||||||
class MaxIterationsReached(result: Series<Double>): Success(result)
|
class MaxIterationsReached<T>(result: Series<T>): Success<T>(result)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returned when the submitted signal has not enough extrema to build envelopes,
|
* Returned when the submitted signal has not enough extrema to build envelopes,
|
||||||
@ -273,7 +279,7 @@ public enum class EMDTerminationReason {
|
|||||||
ALL_POSSIBLE_MODES_EXTRACTED
|
ALL_POSSIBLE_MODES_EXTRACTED
|
||||||
}
|
}
|
||||||
|
|
||||||
public data class EMDecompositionResult(
|
public data class EMDecompositionResult<T>(
|
||||||
val terminatedBecause: EMDTerminationReason,
|
val terminatedBecause: EMDTerminationReason,
|
||||||
val modes: List<Series<Double>>
|
val modes: List<Series<T>>
|
||||||
)
|
)
|
Loading…
Reference in New Issue
Block a user