diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt index ea8a0aa62..a6de7a75a 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/EmpiricalModeDecomposition.kt @@ -8,8 +8,6 @@ package space.kscience.kmath.series import space.kscience.kmath.interpolation.SplineInterpolator import space.kscience.kmath.interpolation.interpolate import space.kscience.kmath.operations.* -import space.kscience.kmath.operations.Float64BufferOps.Companion.div -import space.kscience.kmath.operations.Float64BufferOps.Companion.pow import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.asBuffer import kotlin.math.sign @@ -29,13 +27,13 @@ import kotlin.math.sign * @param nModes how many modes should be extracted at most. The algorithm may return fewer modes if it was not * possible to extract more modes from the signal. */ -public class EmpiricalModeDecomposition ( - private val seriesAlgebra: SeriesAlgebra, +public class EmpiricalModeDecomposition, BA, L: Number> ( + private val seriesAlgebra: SeriesAlgebra, private val sConditionThreshold: Int = 15, private val maxSiftIterations: Int = 20, private val siftingDelta: Double = 1e-2, private val nModes: Int = 6 -) where BA: BufferAlgebra, BA: RingOps> { +) where BA: BufferAlgebra, BA: FieldOps> { /** * Take a signal, construct an upper and a lower envelopes, find the mean value of two, @@ -46,22 +44,22 @@ public class EmpiricalModeDecomposition ( * the signal does not have enough extrema to construct envelopes. */ private fun findMean(signal: Series): Series? = (seriesAlgebra) { - val interpolator = SplineInterpolator(Float64Field) - fun generateEnvelope(extrema: List, paddedExtremeValues: DoubleArray): Series { + val interpolator = SplineInterpolator(elementAlgebra) + fun generateEnvelope(extrema: List, paddedExtremeValues: Array): Series { val envelopeFunction = interpolator.interpolate( - Buffer(extrema.size) { signal.labels[extrema[it]].toDouble() }, + Buffer(extrema.size) { signal.labels[extrema[it]] as Double }, paddedExtremeValues.asBuffer() ) return signal.mapWithLabel { _, label -> // For some reason PolynomialInterpolator is exclusive and the right boundary // TODO Notify interpolator authors - envelopeFunction(label.toDouble()) ?: paddedExtremeValues.last() + envelopeFunction(label as Double) ?: paddedExtremeValues.last() // need to make the interpolator yield values outside boundaries? } } // Extrema padding (experimental) TODO padding needs a dedicated function val maxima = listOf(0) + signal.peaks() + (signal.size - 1) - val maxValues = DoubleArray(maxima.size) { signal[maxima[it]] } + val maxValues = Array(maxima.size) { signal[maxima[it]] } if (maxValues[0] < maxValues[1]) { maxValues[0] = maxValues[1] } @@ -70,7 +68,7 @@ public class EmpiricalModeDecomposition ( } val minima = listOf(0) + signal.troughs() + (signal.size - 1) - val minValues = DoubleArray(minima.size) { signal[minima[it]] } + val minValues = Array(minima.size) { signal[minima[it]] } if (minValues[0] > minValues[1]) { minValues[0] = minValues[1] } @@ -157,13 +155,13 @@ public class EmpiricalModeDecomposition ( * @param nModes how many modes should be extracted at most. The algorithm may return fewer modes if it was not * possible to extract more modes from the signal. */ -public fun SeriesAlgebra.empiricalModeDecomposition( +public fun , BA> SeriesAlgebra.empiricalModeDecomposition( sConditionThreshold: Int = 15, maxSiftIterations: Int = 20, siftingDelta: Double = 1e-2, nModes: Int = 3 -): EmpiricalModeDecomposition -where BA: BufferAlgebra, BA: RingOps> = EmpiricalModeDecomposition( +): EmpiricalModeDecomposition +where BA: BufferAlgebra, BA: FieldOps> = EmpiricalModeDecomposition( seriesAlgebra = this, sConditionThreshold = sConditionThreshold, maxSiftIterations = maxSiftIterations, @@ -188,13 +186,14 @@ private fun Series.countZeros(): Int { /** * Compute relative difference of two series. */ -private fun SeriesAlgebra.relativeDifference( - current: Series, - previous: Series -):Double where BA: BufferAlgebra, BA: RingOps> = - (current - previous).pow(2) - .div(previous pow 2) - .fold(0.0) { acc, d -> acc + d } // TODO replace with Series<>.sum() method when it's implemented +private fun , BA> SeriesAlgebra.relativeDifference( + current: Series, + previous: Series +):T where BA: BufferAlgebra, BA: FieldOps> = (bufferAlgebra) { + ((current - previous) * (current - previous)) + .div(previous * previous) + .fold(elementAlgebra.zero) { acc, it -> acc + it} +} /** * Brute force count all extrema of a series.