From 2e084edc9b4fb7e22d06c6050bd74538dcc2c5b6 Mon Sep 17 00:00:00 2001 From: Igor Dunaev Date: Mon, 20 May 2024 17:02:45 +0300 Subject: [PATCH] Implement generic algorithm --- .../space/kscience/kmath/series/emdDemo.kt | 10 +- .../series/EmpiricalModeDecomposition.kt | 118 ++++++++++-------- 2 files changed, 70 insertions(+), 58 deletions(-) diff --git a/examples/src/main/kotlin/space/kscience/kmath/series/emdDemo.kt b/examples/src/main/kotlin/space/kscience/kmath/series/emdDemo.kt index c69b163cb..a46f3768a 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/series/emdDemo.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/series/emdDemo.kt @@ -5,20 +5,24 @@ package space.kscience.kmath.series +import space.kscience.kmath.operations.algebra +import space.kscience.kmath.operations.bufferAlgebra import space.kscience.kmath.structures.* import space.kscience.kmath.operations.invoke import space.kscience.plotly.* import space.kscience.plotly.models.Scatter import kotlin.math.sin -fun main(): Unit = (Double.seriesAlgebra()) { +private val customAlgebra = (Double.algebra.bufferAlgebra) { SeriesAlgebra(this) { it.toDouble() } } + +fun main(): Unit = (customAlgebra) { val signal = DoubleArray(800) { sin(it.toDouble() / 10.0) + 3.5 * sin(it.toDouble() / 60.0) }.asBuffer().moveTo(0) - val emd = empiricalModeDecomposition( sConditionThreshold = 1, maxSiftIterations = 15, + siftingDelta = 1e-2, nModes = 4 ).decompose(signal) println("EMD: ${emd.modes.size} modes extracted, terminated because ${emd.terminatedBecause}") @@ -26,7 +30,7 @@ fun main(): Unit = (Double.seriesAlgebra()) { fun Plot.series(name: String, buffer: Buffer, block: Scatter.() -> Unit = {}) { this.scatter { this.name = name - this.x.numbers = buffer.offsetIndices + this.x.numbers = buffer.labels this.y.doubles = buffer.toDoubleArray() block() } diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt index ea8a0aa62..700b15757 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt @@ -8,11 +8,8 @@ package space.kscience.kmath.series import space.kscience.kmath.interpolation.SplineInterpolator import space.kscience.kmath.interpolation.interpolate import space.kscience.kmath.operations.* -import space.kscience.kmath.operations.Float64BufferOps.Companion.div -import space.kscience.kmath.operations.Float64BufferOps.Companion.pow import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.asBuffer -import kotlin.math.sign +import space.kscience.kmath.structures.last /** * Empirical mode decomposition of a signal represented as a [Series]. @@ -29,13 +26,13 @@ import kotlin.math.sign * @param nModes how many modes should be extracted at most. The algorithm may return fewer modes if it was not * possible to extract more modes from the signal. */ -public class EmpiricalModeDecomposition ( - private val seriesAlgebra: SeriesAlgebra, +public class EmpiricalModeDecomposition, A: Field, BA, L: T> ( + private val seriesAlgebra: SeriesAlgebra, private val sConditionThreshold: Int = 15, private val maxSiftIterations: Int = 20, - private val siftingDelta: Double = 1e-2, + private val siftingDelta: T, private val nModes: Int = 6 -) where BA: BufferAlgebra, BA: RingOps> { +) where BA: BufferAlgebra, BA: FieldOps> { /** * Take a signal, construct an upper and a lower envelopes, find the mean value of two, @@ -45,37 +42,38 @@ public class EmpiricalModeDecomposition ( * @return mean [Series] or `null`. `null` is returned in case * the signal does not have enough extrema to construct envelopes. */ - private fun findMean(signal: Series): Series? = (seriesAlgebra) { - val interpolator = SplineInterpolator(Float64Field) - fun generateEnvelope(extrema: List, paddedExtremeValues: DoubleArray): Series { + private fun findMean(signal: Series): Series? = (seriesAlgebra) { + val interpolator = SplineInterpolator(elementAlgebra) + val makeBuffer = elementAlgebra.bufferFactory + fun generateEnvelope(extrema: List, paddedExtremeValues: Buffer): Series { val envelopeFunction = interpolator.interpolate( - Buffer(extrema.size) { signal.labels[extrema[it]].toDouble() }, - paddedExtremeValues.asBuffer() + makeBuffer(extrema.size) { signal.labels[extrema[it]] }, + paddedExtremeValues ) return signal.mapWithLabel { _, label -> // For some reason PolynomialInterpolator is exclusive and the right boundary // TODO Notify interpolator authors - envelopeFunction(label.toDouble()) ?: paddedExtremeValues.last() + envelopeFunction(label) ?: paddedExtremeValues.last() // need to make the interpolator yield values outside boundaries? } } // Extrema padding (experimental) TODO padding needs a dedicated function val maxima = listOf(0) + signal.peaks() + (signal.size - 1) - val maxValues = DoubleArray(maxima.size) { signal[maxima[it]] } + val maxValues = makeBuffer(maxima.size) { signal[maxima[it]] } if (maxValues[0] < maxValues[1]) { maxValues[0] = maxValues[1] } - if (maxValues.last() < maxValues[maxValues.lastIndex - 1]) { - maxValues[maxValues.lastIndex] = maxValues[maxValues.lastIndex - 1] + if (maxValues.last() < maxValues[maxValues.size - 2]) { + maxValues[maxValues.size - 1] = maxValues[maxValues.size - 2] } val minima = listOf(0) + signal.troughs() + (signal.size - 1) - val minValues = DoubleArray(minima.size) { signal[minima[it]] } + val minValues = makeBuffer(minima.size) { signal[minima[it]] } if (minValues[0] > minValues[1]) { minValues[0] = minValues[1] } - if (minValues.last() > minValues[minValues.lastIndex - 1]) { - minValues[minValues.lastIndex] = minValues[minValues.lastIndex - 1] + if (minValues.last() > minValues[minValues.size - 2]) { + minValues[minValues.size - 1] = minValues[minValues.size - 2] } return if (maxima.size < 3 || minima.size < 3) null else { // maybe make an early return? val upperEnvelope = generateEnvelope(maxima, maxValues) @@ -92,13 +90,13 @@ public class EmpiricalModeDecomposition ( * @return [SiftingResult.NotEnoughExtrema] is returned if the signal has too few extrema to extract a mode. * Success of an appropriate type (See [SiftingResult.Success] class) is returned otherwise. */ - private fun sift(signal: Series): SiftingResult = siftInner(signal, 1, 0) + private fun sift(signal: Series): SiftingResult = siftInner(signal, 1, 0) /** * Compute a single iteration of the sifting process. */ private tailrec fun siftInner( - prevMode: Series, + prevMode: Series, iterationNumber: Int, sNumber: Int ): SiftingResult = (seriesAlgebra) { @@ -106,11 +104,12 @@ public class EmpiricalModeDecomposition ( return if (iterationNumber == 1) SiftingResult.NotEnoughExtrema else SiftingResult.SignalFlattened(prevMode) val mode = prevMode.zip(mean) { p, m -> p - m } - val newSNumber = if (mode.sCondition()) sNumber + 1 else sNumber + val newSNumber = if (sCondition(mode)) sNumber + 1 else sNumber return when { iterationNumber >= maxSiftIterations -> SiftingResult.MaxIterationsReached(mode) sNumber >= sConditionThreshold -> SiftingResult.SNumberReached(mode) - relativeDifference(mode, prevMode) < siftingDelta * mode.size -> SiftingResult.DeltaReached(mode) + relativeDifference(mode, prevMode) < (elementAlgebra) { siftingDelta * mode.size } -> + SiftingResult.DeltaReached(mode) else -> siftInner(mode, iterationNumber + 1, newSNumber) } } @@ -123,8 +122,8 @@ public class EmpiricalModeDecomposition ( * Modes returned in a list which contains as many modes as it was possible * to extract before triggering one of the termination conditions. */ - public fun decompose(signal: Series): EMDecompositionResult = (seriesAlgebra) { - val modes = mutableListOf>() + public fun decompose(signal: Series): EMDecompositionResult = (seriesAlgebra) { + val modes = mutableListOf>() var residual = signal repeat(nModes) { val nextMode = when(val r = sift(residual)) { @@ -132,14 +131,15 @@ public class EmpiricalModeDecomposition ( return EMDecompositionResult( if (it == 0) EMDTerminationReason.SIGNAL_TOO_FLAT else EMDTerminationReason.ALL_POSSIBLE_MODES_EXTRACTED, - modes + modes, + residual ) - is SiftingResult.Success -> r.result + is SiftingResult.Success<*> -> r.result } - modes.add(nextMode) + modes.add(nextMode as Series) // TODO remove unchecked cast residual = residual.zip(nextMode) { l, r -> l - r } } - return EMDecompositionResult(EMDTerminationReason.MAX_MODES_REACHED, modes) + return EMDecompositionResult(EMDTerminationReason.MAX_MODES_REACHED, modes, residual) } } @@ -157,13 +157,13 @@ public class EmpiricalModeDecomposition ( * @param nModes how many modes should be extracted at most. The algorithm may return fewer modes if it was not * possible to extract more modes from the signal. */ -public fun SeriesAlgebra.empiricalModeDecomposition( +public fun , L: T, A: Field, BA> SeriesAlgebra.empiricalModeDecomposition( sConditionThreshold: Int = 15, maxSiftIterations: Int = 20, - siftingDelta: Double = 1e-2, + siftingDelta: T, nModes: Int = 3 -): EmpiricalModeDecomposition -where BA: BufferAlgebra, BA: RingOps> = EmpiricalModeDecomposition( +): EmpiricalModeDecomposition +where BA: BufferAlgebra, BA: FieldOps> = EmpiricalModeDecomposition( seriesAlgebra = this, sConditionThreshold = sConditionThreshold, maxSiftIterations = maxSiftIterations, @@ -174,12 +174,15 @@ where BA: BufferAlgebra, BA: RingOps> = EmpiricalModeD /** * Brute force count all zeros in the series. */ -private fun Series.countZeros(): Int { - require(size >= 2) { "Expected series with at least 2 elements, but got $size elements" } - data class SignCounter(val prevSign: Double, val zeroCount: Int) +private fun , A: Ring, BA> SeriesAlgebra.countZeros( + signal: Series +): Int where BA: BufferAlgebra, BA: FieldOps> { + require(signal.size >= 2) { "Expected series with at least 2 elements, but got ${signal.size} elements" } + data class SignCounter(val prevSign: Int, val zeroCount: Int) + fun strictSign(arg: T): Int = if (arg > elementAlgebra.zero) 1 else -1 - return fold(SignCounter(sign(get(0)), 0)) { acc: SignCounter, it: Double -> - val currentSign = sign(it) + return signal.fold(SignCounter(strictSign(signal[0]), 0)) { acc, it -> + val currentSign = strictSign(it) if (acc.prevSign != currentSign) SignCounter(currentSign, acc.zeroCount + 1) else SignCounter(currentSign, acc.zeroCount) }.zeroCount @@ -188,18 +191,19 @@ private fun Series.countZeros(): Int { /** * Compute relative difference of two series. */ -private fun SeriesAlgebra.relativeDifference( - current: Series, - previous: Series -):Double where BA: BufferAlgebra, BA: RingOps> = - (current - previous).pow(2) - .div(previous pow 2) - .fold(0.0) { acc, d -> acc + d } // TODO replace with Series<>.sum() method when it's implemented +private fun , BA> SeriesAlgebra.relativeDifference( + current: Series, + previous: Series +): T where BA: BufferAlgebra, BA: FieldOps> = (bufferAlgebra) { + ((current - previous) * (current - previous)) + .div(previous * previous) + .fold(elementAlgebra.zero) { acc, it -> acc + it} +} /** * Brute force count all extrema of a series. */ -private fun Series.countExtrema(): Int { +private fun > Series.countExtrema(): Int { require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" } return peaks().size + troughs().size } @@ -208,7 +212,10 @@ private fun Series.countExtrema(): Int { * Check whether the numbers of zeroes and extrema of a series differ by no more than 1. * This is a necessary condition of an empirical mode. */ -private fun Series.sCondition(): Boolean = (countExtrema() - countZeros()) in -1..1 +private fun , A: Ring, BA> SeriesAlgebra.sCondition( + signal: Series +): Boolean where BA: BufferAlgebra, BA: FieldOps> = + (signal.countExtrema() - countZeros(signal)) in -1..1 internal sealed interface SiftingResult { @@ -216,33 +223,33 @@ internal sealed interface SiftingResult { * Represents a condition when a mode has been successfully * extracted in a sifting process. */ - open class Success(val result: Series): SiftingResult + open class Success(val result: Series): SiftingResult /** * Returned when no termination condition was reached and the proto-mode * has become too flat (with not enough extrema to build envelopes) * after several sifting iterations. */ - class SignalFlattened(result: Series) : Success(result) + class SignalFlattened(result: Series) : Success(result) /** * Returned when sifting process has been terminated due to the * S-number condition being reached. */ - class SNumberReached(result: Series) : Success(result) + class SNumberReached(result: Series) : Success(result) /** * Returned when sifting process has been terminated due to the * delta condition (Cauchy criterion) being reached. */ - class DeltaReached(result: Series) : Success(result) + class DeltaReached(result: Series) : Success(result) /** * Returned when sifting process has been terminated after * executing the maximum number of iterations (specified when creating an instance * of [EmpiricalModeDecomposition]). */ - class MaxIterationsReached(result: Series): Success(result) + class MaxIterationsReached(result: Series): Success(result) /** * Returned when the submitted signal has not enough extrema to build envelopes, @@ -274,7 +281,8 @@ public enum class EMDTerminationReason { ALL_POSSIBLE_MODES_EXTRACTED } -public data class EMDecompositionResult( +public data class EMDecompositionResult( val terminatedBecause: EMDTerminationReason, - val modes: List> + val modes: List>, + val residual: Series ) \ No newline at end of file