diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt index a6de7a75a..f5f77e4d1 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt @@ -9,8 +9,7 @@ import space.kscience.kmath.interpolation.SplineInterpolator import space.kscience.kmath.interpolation.interpolate import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.asBuffer -import kotlin.math.sign +import space.kscience.kmath.structures.last /** * Empirical mode decomposition of a signal represented as a [Series]. @@ -45,35 +44,36 @@ public class EmpiricalModeDecomposition, BA, L: Number> ( */ private fun findMean(signal: Series): Series? = (seriesAlgebra) { val interpolator = SplineInterpolator(elementAlgebra) - fun generateEnvelope(extrema: List, paddedExtremeValues: Array): Series { + val makeBuffer = elementAlgebra.bufferFactory + fun generateEnvelope(extrema: List, paddedExtremeValues: Buffer): Series { val envelopeFunction = interpolator.interpolate( - Buffer(extrema.size) { signal.labels[extrema[it]] as Double }, - paddedExtremeValues.asBuffer() + Buffer(extrema.size) { signal.labels[extrema[it]].toDouble() }, + paddedExtremeValues ) return signal.mapWithLabel { _, label -> // For some reason PolynomialInterpolator is exclusive and the right boundary // TODO Notify interpolator authors - envelopeFunction(label as Double) ?: paddedExtremeValues.last() + envelopeFunction(label.toDouble()) ?: paddedExtremeValues.last() // need to make the interpolator yield values outside boundaries? } } // Extrema padding (experimental) TODO padding needs a dedicated function val maxima = listOf(0) + signal.peaks() + (signal.size - 1) - val maxValues = Array(maxima.size) { signal[maxima[it]] } + val maxValues = makeBuffer(maxima.size) { signal[maxima[it]] } if (maxValues[0] < maxValues[1]) { maxValues[0] = maxValues[1] } - if (maxValues.last() < maxValues[maxValues.lastIndex - 1]) { - maxValues[maxValues.lastIndex] = maxValues[maxValues.lastIndex - 1] + if (maxValues.last() < maxValues[maxValues.size - 2]) { + maxValues[maxValues.size - 1] = maxValues[maxValues.size - 2] } val minima = listOf(0) + signal.troughs() + (signal.size - 1) - val minValues = Array(minima.size) { signal[minima[it]] } + val minValues = makeBuffer(minima.size) { signal[minima[it]] } if (minValues[0] > minValues[1]) { minValues[0] = minValues[1] } - if (minValues.last() > minValues[minValues.lastIndex - 1]) { - minValues[minValues.lastIndex] = minValues[minValues.lastIndex - 1] + if (minValues.last() > minValues[minValues.size - 2]) { + minValues[minValues.size - 1] = minValues[minValues.size - 2] } return if (maxima.size < 3 || minima.size < 3) null else { // maybe make an early return? val upperEnvelope = generateEnvelope(maxima, maxValues) @@ -104,7 +104,7 @@ public class EmpiricalModeDecomposition, BA, L: Number> ( return if (iterationNumber == 1) SiftingResult.NotEnoughExtrema else SiftingResult.SignalFlattened(prevMode) val mode = prevMode.zip(mean) { p, m -> p - m } - val newSNumber = if (mode.sCondition()) sNumber + 1 else sNumber + val newSNumber = if (sCondition(mode)) sNumber + 1 else sNumber return when { iterationNumber >= maxSiftIterations -> SiftingResult.MaxIterationsReached(mode) sNumber >= sConditionThreshold -> SiftingResult.SNumberReached(mode) @@ -121,7 +121,7 @@ public class EmpiricalModeDecomposition, BA, L: Number> ( * Modes returned in a list which contains as many modes as it was possible * to extract before triggering one of the termination conditions. */ - public fun decompose(signal: Series): EMDecompositionResult = (seriesAlgebra) { + public fun decompose(signal: Series): EMDecompositionResult = (seriesAlgebra) { val modes = mutableListOf>() var residual = signal repeat(nModes) { @@ -132,9 +132,9 @@ public class EmpiricalModeDecomposition, BA, L: Number> ( else EMDTerminationReason.ALL_POSSIBLE_MODES_EXTRACTED, modes ) - is SiftingResult.Success -> r.result + is SiftingResult.Success<*> -> r.result } - modes.add(nextMode) + modes.add(nextMode as Series) // TODO remove unchecked cast residual = residual.zip(nextMode) { l, r -> l - r } } return EMDecompositionResult(EMDTerminationReason.MAX_MODES_REACHED, modes) @@ -172,12 +172,15 @@ where BA: BufferAlgebra, BA: FieldOps> = EmpiricalMode /** * Brute force count all zeros in the series. */ -private fun Series.countZeros(): Int { - require(size >= 2) { "Expected series with at least 2 elements, but got $size elements" } - data class SignCounter(val prevSign: Double, val zeroCount: Int) +private fun , A: Ring, BA> SeriesAlgebra.countZeros( + signal: Series +): Int where BA: BufferAlgebra, BA: FieldOps> { + require(signal.size >= 2) { "Expected series with at least 2 elements, but got ${signal.size} elements" } + data class SignCounter(val prevSign: Int, val zeroCount: Int) + fun strictSign(arg: T): Int = if (arg > elementAlgebra.zero) 1 else -1 - return fold(SignCounter(sign(get(0)), 0)) { acc: SignCounter, it: Double -> - val currentSign = sign(it) + return signal.fold(SignCounter(strictSign(signal[0]), 0)) { acc, it -> + val currentSign = strictSign(it) if (acc.prevSign != currentSign) SignCounter(currentSign, acc.zeroCount + 1) else SignCounter(currentSign, acc.zeroCount) }.zeroCount @@ -189,7 +192,7 @@ private fun Series.countZeros(): Int { private fun , BA> SeriesAlgebra.relativeDifference( current: Series, previous: Series -):T where BA: BufferAlgebra, BA: FieldOps> = (bufferAlgebra) { +): T where BA: BufferAlgebra, BA: FieldOps> = (bufferAlgebra) { ((current - previous) * (current - previous)) .div(previous * previous) .fold(elementAlgebra.zero) { acc, it -> acc + it} @@ -198,7 +201,7 @@ private fun , BA> SeriesAlgebra.relativeDifference( /** * Brute force count all extrema of a series. */ -private fun Series.countExtrema(): Int { +private fun > Series.countExtrema(): Int { require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" } return peaks().size + troughs().size } @@ -207,7 +210,10 @@ private fun Series.countExtrema(): Int { * Check whether the numbers of zeroes and extrema of a series differ by no more than 1. * This is a necessary condition of an empirical mode. */ -private fun Series.sCondition(): Boolean = (countExtrema() - countZeros()) in -1..1 +private fun , A: Ring, BA> SeriesAlgebra.sCondition( + signal: Series +): Boolean where BA: BufferAlgebra, BA: FieldOps> = + (signal.countExtrema() - countZeros(signal)) in -1..1 internal sealed interface SiftingResult { @@ -215,33 +221,33 @@ internal sealed interface SiftingResult { * Represents a condition when a mode has been successfully * extracted in a sifting process. */ - open class Success(val result: Series): SiftingResult + open class Success(val result: Series): SiftingResult /** * Returned when no termination condition was reached and the proto-mode * has become too flat (with not enough extrema to build envelopes) * after several sifting iterations. */ - class SignalFlattened(result: Series) : Success(result) + class SignalFlattened(result: Series) : Success(result) /** * Returned when sifting process has been terminated due to the * S-number condition being reached. */ - class SNumberReached(result: Series) : Success(result) + class SNumberReached(result: Series) : Success(result) /** * Returned when sifting process has been terminated due to the * delta condition (Cauchy criterion) being reached. */ - class DeltaReached(result: Series) : Success(result) + class DeltaReached(result: Series) : Success(result) /** * Returned when sifting process has been terminated after * executing the maximum number of iterations (specified when creating an instance * of [EmpiricalModeDecomposition]). */ - class MaxIterationsReached(result: Series): Success(result) + class MaxIterationsReached(result: Series): Success(result) /** * Returned when the submitted signal has not enough extrema to build envelopes, @@ -273,7 +279,7 @@ public enum class EMDTerminationReason { ALL_POSSIBLE_MODES_EXTRACTED } -public data class EMDecompositionResult( +public data class EMDecompositionResult( val terminatedBecause: EMDTerminationReason, - val modes: List> + val modes: List> ) \ No newline at end of file