Compare commits

..

11 Commits

26 changed files with 507 additions and 93 deletions

View File

@ -0,0 +1,45 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.series
import space.kscience.kmath.operations.algebra
import space.kscience.kmath.operations.bufferAlgebra
import space.kscience.kmath.structures.*
import space.kscience.kmath.operations.invoke
import space.kscience.plotly.*
import space.kscience.plotly.models.Scatter
import kotlin.math.sin
private val customAlgebra = (Double.algebra.bufferAlgebra) { SeriesAlgebra(this) { it.toDouble() } }
fun main(): Unit = (customAlgebra) {
val signal = DoubleArray(800) {
sin(it.toDouble() / 10.0) + 3.5 * sin(it.toDouble() / 60.0)
}.asBuffer().moveTo(0)
val emd = empiricalModeDecomposition(
sConditionThreshold = 1,
maxSiftIterations = 15,
siftingDelta = 1e-2,
nModes = 4
).decompose(signal)
println("EMD: ${emd.modes.size} modes extracted, terminated because ${emd.terminatedBecause}")
fun Plot.series(name: String, buffer: Buffer<Double>, block: Scatter.() -> Unit = {}) {
this.scatter {
this.name = name
this.x.numbers = buffer.labels
this.y.doubles = buffer.toDoubleArray()
block()
}
}
Plotly.plot {
series("Signal", signal)
emd.modes.forEachIndexed { index, it ->
series("Mode ${index+1}", it)
}
}.makeFile(resourceLocation = ResourceLocation.REMOTE)
}

View File

@ -35,7 +35,7 @@ class StreamDoubleFieldND(override val shape: ShapeND) : FieldND<Double, Float64
@OptIn(PerformancePitfall::class)
private val StructureND<Double>.buffer: Float64Buffer
get() = when {
shape != this@StreamDoubleFieldND.shape -> throw ShapeMismatchException(
!shape.contentEquals(this@StreamDoubleFieldND.shape) -> throw ShapeMismatchException(
this@StreamDoubleFieldND.shape,
shape
)

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
@ -61,7 +62,7 @@ fun main() {
// figure out MSE of approximation
fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double {
require(yTrue.shape.size == 1)
require(yTrue.shape == yPred.shape)
require(yTrue.shape contentEquals yPred.shape)
val diff = yTrue - yPred
return sqrt(diff.dot(diff)).value()

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.asIterable
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.*
@ -93,7 +94,7 @@ class Dense(
// simple accuracy equal to the proportion of correct answers
fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double {
check(yPred.shape == yTrue.shape)
check(yPred.shape contentEquals yTrue.shape)
val n = yPred.shape[0]
var correctCnt = 0
for (i in 0 until n) {

View File

@ -32,12 +32,12 @@ public class BufferedLinearSpace<T, out A : Ring<T>>(
}
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra {
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D()
}
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra {
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D()
}

View File

@ -33,12 +33,12 @@ public object Float64LinearSpace : LinearSpace<Double, Float64Field> {
}
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D()
}
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D()
}

View File

@ -55,7 +55,7 @@ public interface AlgebraND<T, out C : Algebra<T>> : Algebra<StructureND<T>> {
*/
@PerformancePitfall("Very slow on remote execution algebras")
public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T> {
require(left.shape == right.shape) {
require(left.shape.contentEquals(right.shape)) {
"Expected left and right of the same shape, but left - ${left.shape} and right - ${right.shape}"
}
return structureND(left.shape) { index ->

View File

@ -106,10 +106,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is ColumnStrides) return false
return shape == other.shape
return shape.contentEquals(other.shape)
}
override fun hashCode(): Int = shape.hashCode()
override fun hashCode(): Int = shape.contentHashCode()
public companion object
@ -156,10 +156,10 @@ public class RowStrides(override val shape: ShapeND) : Strides() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is RowStrides) return false
return shape == other.shape
return shape.contentEquals(other.shape)
}
override fun hashCode(): Int = shape.hashCode()
override fun hashCode(): Int = shape.contentHashCode()
public companion object

View File

@ -11,30 +11,19 @@ import kotlin.jvm.JvmInline
/**
* A read-only ND shape
*/
public class ShapeND(@PublishedApi internal val array: IntArray) {
@JvmInline
public value class ShapeND(@PublishedApi internal val array: IntArray) {
public val size: Int get() = array.size
public operator fun get(index: Int): Int = array[index]
override fun toString(): String = array.contentToString()
override fun hashCode(): Int = array.contentHashCode()
override fun equals(other: Any?): Boolean = other is ShapeND && array.contentEquals(other.array)
}
public inline fun ShapeND.forEach(block: (value: Int) -> Unit): Unit = array.forEach(block)
public inline fun ShapeND.forEachIndexed(block: (index: Int, value: Int) -> Unit): Unit = array.forEachIndexed(block)
@Deprecated(
message = "ShapeND is made a usual class with correct `equals`. Use it instead of the `contentEquals`.",
replaceWith = ReplaceWith("this == other"),
level = DeprecationLevel.WARNING,
)
public infix fun ShapeND.contentEquals(other: ShapeND): Boolean = array.contentEquals(other.array)
@Deprecated(
message = "ShapeND is made a usual class with correct `hashCode`. Use it instead of the `contentHashCode`.",
replaceWith = ReplaceWith("this.hashCode()"),
level = DeprecationLevel.WARNING,
)
public fun ShapeND.contentHashCode(): Int = array.contentHashCode()
public val ShapeND.indices: IntRange get() = array.indices

View File

@ -48,12 +48,12 @@ public object Float64ParallelLinearSpace : LinearSpace<Double, Float64Field> {
}
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D()
}
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D()
}

View File

@ -92,7 +92,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
@OptIn(PerformancePitfall::class)
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
require(left.shape == right.shape) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
val leftArray = left.asMultik().array
val rightArray = right.asMultik().array
val data = initMemoryView<T>(leftArray.size, dataType)
@ -124,7 +124,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
@OptIn(PerformancePitfall::class)
override fun StructureND<T>.valueOrNull(): T? = if (shape == ShapeND(1)) {
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals ShapeND(1)) {
get(intArrayOf(0))
} else null
@ -224,7 +224,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
}
val mt = asMultik().array
return if (ShapeND(mt.shape) == shape) {
return if (ShapeND(mt.shape).contentEquals(shape)) {
mt
} else {
@OptIn(UnsafeKMathAPI::class)

View File

@ -63,7 +63,7 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
right: StructureND<T>,
transform: C.(T, T) -> T,
): Nd4jArrayStructure<T> {
require(left.shape == right.shape) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
val new = Nd4j.create(*left.shape.asArray()).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) }
return new

View File

@ -49,7 +49,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
@OptIn(PerformancePitfall::class)
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
require(left.shape == right.shape)
require(left.shape.contentEquals(right.shape))
return mutableStructureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) }
}
@ -203,7 +203,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, Float64Field>
}
override fun StructureND<Double>.valueOrNull(): Double? =
if (shape == ShapeND(1)) ndArray.getDouble(0) else null
if (shape contentEquals ShapeND(1)) ndArray.getDouble(0) else null
// TODO rewrite
override fun diagonalEmbedding(

View File

@ -14,6 +14,7 @@ kotlin.sourceSets {
dependencies {
api(projects.kmathCoroutines)
//implementation(spclibs.atomicfu)
api(project(":kmath-functions"))
}
}

View File

@ -0,0 +1,288 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.series
import space.kscience.kmath.interpolation.SplineInterpolator
import space.kscience.kmath.interpolation.interpolate
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.last
/**
* Empirical mode decomposition of a signal represented as a [Series].
*
* @param seriesAlgebra context to perform operations in.
* @param maxSiftIterations number of iterations in the mode extraction (sifting) process after which
* the result is returned even if no other conditions are met.
* @param sConditionThreshold one of the possible criteria for sifting process termination:
* how many times in a row should a proto-mode satisfy the s-condition (the number of zeros differs from
* the number of extrema no more than by 1) to be considered an empirical mode.
* @param siftingDelta one of the possible criteria for sifting process termination: if relative difference of
* two proto-modes obtained in a sequence is less that this number, the last proto-mode is considered
* an empirical mode and returned.
* @param nModes how many modes should be extracted at most. The algorithm may return fewer modes if it was not
* possible to extract more modes from the signal.
*/
public class EmpiricalModeDecomposition<T: Comparable<T>, A: Field<T>, BA, L: T> (
private val seriesAlgebra: SeriesAlgebra<T, A, BA, L>,
private val sConditionThreshold: Int = 15,
private val maxSiftIterations: Int = 20,
private val siftingDelta: T,
private val nModes: Int = 6
) where BA: BufferAlgebra<T, A>, BA: FieldOps<Buffer<T>> {
/**
* Take a signal, construct an upper and a lower envelopes, find the mean value of two,
* represented as a series.
* @param signal Signal to compute on
*
* @return mean [Series] or `null`. `null` is returned in case
* the signal does not have enough extrema to construct envelopes.
*/
private fun findMean(signal: Series<T>): Series<T>? = (seriesAlgebra) {
val interpolator = SplineInterpolator(elementAlgebra)
val makeBuffer = elementAlgebra.bufferFactory
fun generateEnvelope(extrema: List<Int>, paddedExtremeValues: Buffer<T>): Series<T> {
val envelopeFunction = interpolator.interpolate(
makeBuffer(extrema.size) { signal.labels[extrema[it]] },
paddedExtremeValues
)
return signal.mapWithLabel { _, label ->
// For some reason PolynomialInterpolator is exclusive and the right boundary
// TODO Notify interpolator authors
envelopeFunction(label) ?: paddedExtremeValues.last()
// need to make the interpolator yield values outside boundaries?
}
}
// Extrema padding (experimental) TODO padding needs a dedicated function
val maxima = listOf(0) + signal.peaks() + (signal.size - 1)
val maxValues = makeBuffer(maxima.size) { signal[maxima[it]] }
if (maxValues[0] < maxValues[1]) {
maxValues[0] = maxValues[1]
}
if (maxValues.last() < maxValues[maxValues.size - 2]) {
maxValues[maxValues.size - 1] = maxValues[maxValues.size - 2]
}
val minima = listOf(0) + signal.troughs() + (signal.size - 1)
val minValues = makeBuffer(minima.size) { signal[minima[it]] }
if (minValues[0] > minValues[1]) {
minValues[0] = minValues[1]
}
if (minValues.last() > minValues[minValues.size - 2]) {
minValues[minValues.size - 1] = minValues[minValues.size - 2]
}
return if (maxima.size < 3 || minima.size < 3) null else { // maybe make an early return?
val upperEnvelope = generateEnvelope(maxima, maxValues)
val lowerEnvelope = generateEnvelope(minima, minValues)
return (upperEnvelope + lowerEnvelope).map { it * 0.5 }
}
}
/**
* Extract a single empirical mode from a signal. This process is called sifting, hence the name.
* @param signal Signal to extract a mode from. The first mode is extracted from the initial signal,
* subsequent modes are extracted from the residuals between the signal and all previous modes.
*
* @return [SiftingResult.NotEnoughExtrema] is returned if the signal has too few extrema to extract a mode.
* Success of an appropriate type (See [SiftingResult.Success] class) is returned otherwise.
*/
private fun sift(signal: Series<T>): SiftingResult = siftInner(signal, 1, 0)
/**
* Compute a single iteration of the sifting process.
*/
private tailrec fun siftInner(
prevMode: Series<T>,
iterationNumber: Int,
sNumber: Int
): SiftingResult = (seriesAlgebra) {
val mean = findMean(prevMode) ?:
return if (iterationNumber == 1) SiftingResult.NotEnoughExtrema
else SiftingResult.SignalFlattened(prevMode)
val mode = prevMode.zip(mean) { p, m -> p - m }
val newSNumber = if (sCondition(mode)) sNumber + 1 else sNumber
return when {
iterationNumber >= maxSiftIterations -> SiftingResult.MaxIterationsReached(mode)
sNumber >= sConditionThreshold -> SiftingResult.SNumberReached(mode)
relativeDifference(mode, prevMode) < (elementAlgebra) { siftingDelta * mode.size } ->
SiftingResult.DeltaReached(mode)
else -> siftInner(mode, iterationNumber + 1, newSNumber)
}
}
/**
* Extract [nModes] empirical modes from a signal represented by a time series.
* @param signal Signal to decompose.
* @return [EMDecompositionResult] with an appropriate reason for algorithm termination
* (see [EMDTerminationReason] for possible reasons).
* Modes returned in a list which contains as many modes as it was possible
* to extract before triggering one of the termination conditions.
*/
public fun decompose(signal: Series<T>): EMDecompositionResult<T> = (seriesAlgebra) {
val modes = mutableListOf<Series<T>>()
var residual = signal
repeat(nModes) {
val nextMode = when(val r = sift(residual)) {
SiftingResult.NotEnoughExtrema ->
return EMDecompositionResult(
if (it == 0) EMDTerminationReason.SIGNAL_TOO_FLAT
else EMDTerminationReason.ALL_POSSIBLE_MODES_EXTRACTED,
modes,
residual
)
is SiftingResult.Success<*> -> r.result
}
modes.add(nextMode as Series<T>) // TODO remove unchecked cast
residual = residual.zip(nextMode) { l, r -> l - r }
}
return EMDecompositionResult(EMDTerminationReason.MAX_MODES_REACHED, modes, residual)
}
}
/**
* Shortcut to retrieve a decomposition factory from a [SeriesAlgebra] scope.
* @receiver scope to perform operations in.
* @param maxSiftIterations number of iterations in the mode extraction (sifting) process after which
* the result is returned even if no other conditions are met.
* @param sConditionThreshold one of the possible criteria for sifting process termination:
* how many times in a row should a proto-mode satisfy the s-condition (the number of zeros differs from
* the number of extrema no more than by 1) to be considered an empirical mode.
* @param siftingDelta one of the possible criteria for sifting process termination: if relative difference of
* two proto-modes obtained in a sequence is less that this number, the last proto-mode is considered
* an empirical mode and returned.
* @param nModes how many modes should be extracted at most. The algorithm may return fewer modes if it was not
* possible to extract more modes from the signal.
*/
public fun <T: Comparable<T>, L: T, A: Field<T>, BA> SeriesAlgebra<T, A, BA, L>.empiricalModeDecomposition(
sConditionThreshold: Int = 15,
maxSiftIterations: Int = 20,
siftingDelta: T,
nModes: Int = 3
): EmpiricalModeDecomposition<T, A, BA, L>
where BA: BufferAlgebra<T, A>, BA: FieldOps<Buffer<T>> = EmpiricalModeDecomposition(
seriesAlgebra = this,
sConditionThreshold = sConditionThreshold,
maxSiftIterations = maxSiftIterations,
siftingDelta = siftingDelta,
nModes = nModes
)
/**
* Brute force count all zeros in the series.
*/
private fun <T: Comparable<T>, A: Ring<T>, BA> SeriesAlgebra<T, A, BA, *>.countZeros(
signal: Series<T>
): Int where BA: BufferAlgebra<T, A>, BA: FieldOps<Buffer<T>> {
require(signal.size >= 2) { "Expected series with at least 2 elements, but got ${signal.size} elements" }
data class SignCounter(val prevSign: Int, val zeroCount: Int)
fun strictSign(arg: T): Int = if (arg > elementAlgebra.zero) 1 else -1
return signal.fold(SignCounter(strictSign(signal[0]), 0)) { acc, it ->
val currentSign = strictSign(it)
if (acc.prevSign != currentSign) SignCounter(currentSign, acc.zeroCount + 1)
else SignCounter(currentSign, acc.zeroCount)
}.zeroCount
}
/**
* Compute relative difference of two series.
*/
private fun <T, A: Ring<T>, BA> SeriesAlgebra<T, A, BA, *>.relativeDifference(
current: Series<T>,
previous: Series<T>
): T where BA: BufferAlgebra<T, A>, BA: FieldOps<Buffer<T>> = (bufferAlgebra) {
((current - previous) * (current - previous))
.div(previous * previous)
.fold(elementAlgebra.zero) { acc, it -> acc + it}
}
/**
* Brute force count all extrema of a series.
*/
private fun <T: Comparable<T>> Series<T>.countExtrema(): Int {
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
return peaks().size + troughs().size
}
/**
* Check whether the numbers of zeroes and extrema of a series differ by no more than 1.
* This is a necessary condition of an empirical mode.
*/
private fun <T: Comparable<T>, A: Ring<T>, BA> SeriesAlgebra<T, A, BA, *>.sCondition(
signal: Series<T>
): Boolean where BA: BufferAlgebra<T, A>, BA: FieldOps<Buffer<T>> =
(signal.countExtrema() - countZeros(signal)) in -1..1
internal sealed interface SiftingResult {
/**
* Represents a condition when a mode has been successfully
* extracted in a sifting process.
*/
open class Success<T>(val result: Series<T>): SiftingResult
/**
* Returned when no termination condition was reached and the proto-mode
* has become too flat (with not enough extrema to build envelopes)
* after several sifting iterations.
*/
class SignalFlattened<T>(result: Series<T>) : Success<T>(result)
/**
* Returned when sifting process has been terminated due to the
* S-number condition being reached.
*/
class SNumberReached<T>(result: Series<T>) : Success<T>(result)
/**
* Returned when sifting process has been terminated due to the
* delta condition (Cauchy criterion) being reached.
*/
class DeltaReached<T>(result: Series<T>) : Success<T>(result)
/**
* Returned when sifting process has been terminated after
* executing the maximum number of iterations (specified when creating an instance
* of [EmpiricalModeDecomposition]).
*/
class MaxIterationsReached<T>(result: Series<T>): Success<T>(result)
/**
* Returned when the submitted signal has not enough extrema to build envelopes,
* i.e. when [SignalFlattened] condition has already been reached before the first sifting iteration.
*/
data object NotEnoughExtrema : SiftingResult
}
public enum class EMDTerminationReason {
/**
* Returned when the signal is too flat, i.e. there are too few extrema
* to build envelopes necessary to extract modes.
*/
SIGNAL_TOO_FLAT,
/**
* Returned when there has been extracted as many modes as
* specified when creating the instance of [EmpiricalModeDecomposition]
*/
MAX_MODES_REACHED,
/**
* Returned when the algorithm terminated after finding impossible
* to extract more modes from the signal and the maximum number
* of modes (specified when creating an instance of [EmpiricalModeDecomposition])
* has not yet been reached.
*/
ALL_POSSIBLE_MODES_EXTRACTED
}
public data class EMDecompositionResult<T>(
val terminatedBecause: EMDTerminationReason,
val modes: List<Series<T>>,
val residual: Series<T>
)

View File

@ -169,7 +169,7 @@ public open class SeriesAlgebra<T, out A : Ring<T>, out BA : BufferAlgebra<T, A>
public inline fun Buffer<T>.mapWithLabel(crossinline transform: A.(arg: T, label: L) -> T): Series<T> {
val labels = labels
val buf = elementAlgebra.bufferFactory(size) {
elementAlgebra.transform(getByOffset(it), labels[it])
elementAlgebra.transform(get(it), labels[it])
}
return buf.moveTo(offsetIndices.first)
}

View File

@ -0,0 +1,89 @@
/*
* Copyright 2018-2024 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.series
public enum class PlateauEdgePolicy {
/**
* Midpoints of plateau are returned, edges not belonging to a plateau are ignored.
*
* A midpoint is the index closest to the average of indices of the left and right edges
* of the plateau:
*
* `val midpoint = ((leftEdge + rightEdge) / 2).toInt`
*/
AVERAGE,
/**
* Both left and right edges are returned.
*/
KEEP_ALL_EDGES,
/**
* Only right edges are returned.
*/
KEEP_RIGHT_EDGES,
/**
* Only left edges are returned.
*/
KEEP_LEFT_EDGES,
/**
* Ignore plateau, only peaks (troughs) with values strictly greater (less)
* than values of the adjacent points are returned.
*/
IGNORE
}
public fun <T: Comparable<T>> Series<T>.peaks(
plateauEdgePolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE
): List<Int> = findPeaks(plateauEdgePolicy, { other -> this > other }, { other -> this >= other })
public fun <T: Comparable<T>> Series<T>.troughs(
plateauEdgePolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE
): List<Int> = findPeaks(plateauEdgePolicy, { other -> this < other }, { other -> this <= other })
private fun <T: Comparable<T>> Series<T>.findPeaks(
plateauPolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE,
cmpStrong: T.(T) -> Boolean,
cmpWeak: T.(T) -> Boolean
): List<Int> {
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
if (plateauPolicy == PlateauEdgePolicy.AVERAGE) return peaksWithPlateau(cmpStrong)
fun peakCriterion(left: T, middle: T, right: T): Boolean = when(plateauPolicy) {
PlateauEdgePolicy.KEEP_LEFT_EDGES -> middle.cmpStrong(left) && middle.cmpWeak(right)
PlateauEdgePolicy.KEEP_RIGHT_EDGES -> middle.cmpWeak(left) && middle.cmpStrong(right)
PlateauEdgePolicy.KEEP_ALL_EDGES ->
(middle.cmpStrong(left) && middle.cmpWeak(right)) || (middle.cmpWeak(left) && middle.cmpStrong(right))
else -> middle.cmpStrong(right) && middle.cmpStrong(left)
}
val indices = mutableListOf<Int>()
for (index in 1 .. size - 2) {
val left = this[index - 1]
val middle = this[index]
val right = this[index + 1]
if (peakCriterion(left, middle, right)) indices.add(index)
}
return indices
}
private fun <T: Comparable<T>> Series<T>.peaksWithPlateau(cmpStrong: T.(T) -> Boolean): List<Int> {
val peaks = mutableListOf<Int>()
tailrec fun peaksPlateauInner(index: Int) {
val nextUnequal = (index + 1 ..< size).firstOrNull { this[it] != this[index] } ?: (size - 1)
val newIndex = if (this[index].cmpStrong(this[index - 1]) && this[index].cmpStrong(this[nextUnequal])) {
peaks.add((index + nextUnequal) / 2)
nextUnequal
} else index + 1
if (newIndex < size - 1) peaksPlateauInner(newIndex)
}
peaksPlateauInner(1)
return peaks
}

View File

@ -23,6 +23,7 @@ import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.asArray
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra
@ -104,7 +105,7 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
protected abstract fun const(value: T): Constant<TT>
@OptIn(PerformancePitfall::class)
override fun StructureND<T>.valueOrNull(): T? = if (shape == ShapeND(1))
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals ShapeND(1))
get(intArrayOf(0)) else null
/**

View File

@ -95,7 +95,7 @@ public open class DoubleTensorAlgebra :
override fun StructureND<Double>.valueOrNull(): Double? {
val dt = asDoubleTensor()
return if (dt.shape == ShapeND(1)) dt.source[0] else null
return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null
}
override fun StructureND<Double>.value(): Double = valueOrNull()

View File

@ -89,7 +89,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, Int32Ring> {
override fun StructureND<Int>.valueOrNull(): Int? {
val dt = asIntTensor()
return if (dt.shape == ShapeND(1)) dt.source[0] else null
return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null
}
override fun StructureND<Int>.value(): Int = valueOrNull()
@ -387,7 +387,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, Int32Ring> {
public fun stack(tensors: List<Tensor<Int>>): IntTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val shape = tensors[0].shape
check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" }
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
val resShape = ShapeND(tensors.size) + shape
// val resBuffer: List<Int> = tensors.flatMap {
// it.asIntTensor().source.array.drop(it.asIntTensor().bufferStart)

View File

@ -7,6 +7,7 @@ package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.nd.linearSize
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.DoubleTensor
@ -31,7 +32,7 @@ internal fun checkBufferShapeConsistency(shape: ShapeND, buffer: DoubleArray) =
@PublishedApi
internal fun <T> checkShapesCompatible(a: StructureND<T>, b: StructureND<T>): Unit =
check(a.shape == b.shape) {
check(a.shape contentEquals b.shape) {
"Incompatible shapes ${a.shape} and ${b.shape} "
}

View File

@ -47,7 +47,7 @@ public fun DoubleTensorAlgebra.randomNormalLike(structure: WithShape, seed: Long
public fun stack(tensors: List<Tensor<Double>>): DoubleTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val shape = tensors[0].shape
check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" }
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
val resShape = ShapeND(tensors.size) + shape
// val resBuffer: List<Double> = tensors.flatMap {
// it.asDoubleTensor().source.array.drop(it.asDoubleTensor().bufferStart)
@ -91,7 +91,7 @@ public fun DoubleTensorAlgebra.luPivot(
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
checkSquareMatrix(luTensor.shape)
check(
luTensor.shape.first(luTensor.shape.size - 2) == pivotsTensor.shape.first(pivotsTensor.shape.size - 1) ||
luTensor.shape.first(luTensor.shape.size - 2) contentEquals pivotsTensor.shape.first(pivotsTensor.shape.size - 1) ||
luTensor.shape.last() == pivotsTensor.shape.last() - 1
) { "Inappropriate shapes of input tensors" }

View File

@ -6,31 +6,29 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors
import space.kscience.kmath.tensors.core.internal.broadcastShapes
import space.kscience.kmath.tensors.core.internal.broadcastTensors
import space.kscience.kmath.tensors.core.internal.broadcastTo
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class TestBroadcasting {
@Test
fun testBroadcastShapes() = DoubleTensorAlgebra {
assertEquals(
assertTrue(
broadcastShapes(
listOf(ShapeND(2, 3), ShapeND(1, 3), ShapeND(1, 1, 1))
),
ShapeND(1, 2, 3)
) contentEquals ShapeND(1, 2, 3)
)
assertEquals(
assertTrue(
broadcastShapes(
listOf(ShapeND(6, 7), ShapeND(5, 6, 1), ShapeND(7), ShapeND(5, 1, 7))
),
ShapeND(5, 6, 7)
) contentEquals ShapeND(5, 6, 7)
)
}
@ -40,7 +38,7 @@ internal class TestBroadcasting {
val tensor2 = fromArray(ShapeND(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
val res = broadcastTo(tensor2, tensor1.shape)
assertTrue(res.shape == ShapeND(2, 3))
assertTrue(res.shape contentEquals ShapeND(2, 3))
assertTrue(res.source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
}
@ -52,9 +50,9 @@ internal class TestBroadcasting {
val res = broadcastTensors(tensor1, tensor2, tensor3)
assertEquals(res[0].shape, ShapeND(1, 2, 3))
assertEquals(res[1].shape, ShapeND(1, 2, 3))
assertEquals(res[2].shape, ShapeND(1, 2, 3))
assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3))
assertTrue(res[1].shape contentEquals ShapeND(1, 2, 3))
assertTrue(res[2].shape contentEquals ShapeND(1, 2, 3))
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
@ -69,9 +67,9 @@ internal class TestBroadcasting {
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
assertEquals(res[0].shape, ShapeND(1, 2, 3))
assertEquals(res[1].shape, ShapeND(1, 1, 3))
assertEquals(res[2].shape, ShapeND(1, 1, 1))
assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3))
assertTrue(res[1].shape contentEquals ShapeND(1, 1, 3))
assertTrue(res[2].shape contentEquals ShapeND(1, 1, 1))
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0))
@ -86,9 +84,9 @@ internal class TestBroadcasting {
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
assertEquals(res[0].shape, ShapeND(4, 2, 5, 3, 2, 3))
assertEquals(res[1].shape, ShapeND(4, 2, 5, 3, 3, 3))
assertEquals(res[2].shape, ShapeND(4, 2, 5, 3, 1, 1))
assertTrue(res[0].shape contentEquals ShapeND(4, 2, 5, 3, 2, 3))
assertTrue(res[1].shape contentEquals ShapeND(4, 2, 5, 3, 3, 3))
assertTrue(res[2].shape contentEquals ShapeND(4, 2, 5, 3, 1, 1))
}
@Test
@ -101,16 +99,16 @@ internal class TestBroadcasting {
val tensor31 = tensor3 - tensor1
val tensor32 = tensor3 - tensor2
assertEquals(tensor21.shape, ShapeND(2, 3))
assertTrue(tensor21.shape contentEquals ShapeND(2, 3))
assertTrue(tensor21.source contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertEquals(tensor31.shape, ShapeND(1, 2, 3))
assertTrue(tensor31.shape contentEquals ShapeND(1, 2, 3))
assertTrue(
tensor31.source
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
)
assertEquals(tensor32.shape, ShapeND(1, 1, 3))
assertTrue(tensor32.shape contentEquals ShapeND(1, 1, 3))
assertTrue(tensor32.source contentEquals doubleArrayOf(490.0, 480.0, 470.0))
}

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.contentEquals
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.internal.svd1d
import kotlin.math.abs
@ -112,8 +113,8 @@ internal class TestDoubleLinearOpsTensorAlgebra {
val (q, r) = qr(tensor)
assertEquals(q.shape, shape)
assertEquals(r.shape, shape)
assertTrue { q.shape contentEquals shape }
assertTrue { r.shape contentEquals shape }
assertTrue((q matmul r).eq(tensor))
@ -132,9 +133,9 @@ internal class TestDoubleLinearOpsTensorAlgebra {
val (p, l, u) = lu(tensor)
assertEquals(p.shape, shape)
assertEquals(l.shape, shape)
assertEquals(u.shape, shape)
assertTrue { p.shape contentEquals shape }
assertTrue { l.shape contentEquals shape }
assertTrue { u.shape contentEquals shape }
assertTrue((p matmul tensor).eq(l matmul u))
}
@ -156,7 +157,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
val res = svd1d(tensor2)
assertEquals(res.shape, ShapeND(2))
assertTrue(res.shape contentEquals ShapeND(2))
assertTrue { abs(abs(res.source[0]) - 0.386) < 0.01 }
assertTrue { abs(abs(res.source[1]) - 0.922) < 0.01 }
}

View File

@ -6,8 +6,7 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.nd.get
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.testutils.assertBufferEquals
import kotlin.test.Test
@ -44,7 +43,7 @@ internal class TestDoubleTensorAlgebra {
val res = tensor.transposed(0, 0)
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(0.0))
assertEquals(res.shape, ShapeND(1))
assertTrue(res.shape contentEquals ShapeND(1))
}
@Test
@ -53,7 +52,7 @@ internal class TestDoubleTensorAlgebra {
val res = tensor.transposed(1, 0)
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
assertEquals(res.shape, ShapeND(2, 3))
assertTrue(res.shape contentEquals ShapeND(2, 3))
}
@Test
@ -63,9 +62,9 @@ internal class TestDoubleTensorAlgebra {
val res02 = tensor.transposed(-3, 2)
val res12 = tensor.transposed()
assertEquals(res01.shape, ShapeND(2, 1, 3))
assertEquals(res02.shape, ShapeND(3, 2, 1))
assertEquals(res12.shape, ShapeND(1, 3, 2))
assertTrue(res01.shape contentEquals ShapeND(2, 1, 3))
assertTrue(res02.shape contentEquals ShapeND(3, 2, 1))
assertTrue(res12.shape contentEquals ShapeND(1, 3, 2))
assertTrue(res01.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res02.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
@ -115,19 +114,19 @@ internal class TestDoubleTensorAlgebra {
val res12 = tensor1.dot(tensor2)
assertTrue(res12.source contentEquals doubleArrayOf(140.0, 320.0))
assertEquals(res12.shape, ShapeND(2))
assertTrue(res12.shape contentEquals ShapeND(2))
val res32 = tensor3.matmul(tensor2)
assertTrue(res32.source contentEquals doubleArrayOf(-140.0))
assertEquals(res32.shape, ShapeND(1, 1))
assertTrue(res32.shape contentEquals ShapeND(1, 1))
val res22 = tensor2.dot(tensor2)
assertTrue(res22.source contentEquals doubleArrayOf(1400.0))
assertEquals(res22.shape, ShapeND(1))
assertTrue(res22.shape contentEquals ShapeND(1))
val res11 = tensor1.dot(tensor11)
assertTrue(res11.source contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
assertEquals(res11.shape, ShapeND(2, 2))
assertTrue(res11.shape contentEquals ShapeND(2, 2))
val res45 = tensor4.matmul(tensor5)
assertTrue(
@ -136,7 +135,7 @@ internal class TestDoubleTensorAlgebra {
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
)
)
assertEquals(res45.shape, ShapeND(2, 3, 3))
assertTrue(res45.shape contentEquals ShapeND(2, 3, 3))
}
@Test
@ -145,35 +144,35 @@ internal class TestDoubleTensorAlgebra {
val tensor2 = fromArray(ShapeND(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val tensor3 = zeros(ShapeND(2, 3, 4, 5))
assertEquals(
diagonalEmbedding(tensor3, 0, 3, 4).shape,
ShapeND(2, 3, 4, 5, 5)
assertTrue(
diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals
ShapeND(2, 3, 4, 5, 5)
)
assertEquals(
diagonalEmbedding(tensor3, 1, 3, 4).shape,
ShapeND(2, 3, 4, 6, 6)
assertTrue(
diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals
ShapeND(2, 3, 4, 6, 6)
)
assertEquals(
diagonalEmbedding(tensor3, 2, 0, 3).shape,
ShapeND(7, 2, 3, 7, 4)
assertTrue(
diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals
ShapeND(7, 2, 3, 7, 4)
)
val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0)
assertEquals(diagonal1.shape, ShapeND(3, 3))
assertTrue(diagonal1.shape contentEquals ShapeND(3, 3))
assertTrue(
diagonal1.source contentEquals
doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0)
)
val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0)
assertEquals(diagonal1Offset.shape, ShapeND(4, 4))
assertTrue(diagonal1Offset.shape contentEquals ShapeND(4, 4))
assertTrue(
diagonal1Offset.source contentEquals
doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0)
)
val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2)
assertEquals(diagonal2.shape, ShapeND(4, 2, 4))
assertTrue(diagonal2.shape contentEquals ShapeND(4, 2, 4))
assertTrue(
diagonal2.source contentEquals
doubleArrayOf(
@ -203,7 +202,7 @@ internal class TestDoubleTensorAlgebra {
val l = tensor.getTensor(0).map { it + 1.0 }
val r = tensor.getTensor(1).map { it - 1.0 }
val res = l + r
assertEquals(ShapeND(5, 5), res.shape)
assertTrue { ShapeND(5, 5) contentEquals res.shape }
assertEquals(2.0, res[4, 4])
}
}

View File

@ -67,7 +67,7 @@ public open class ViktorFieldOpsND :
right: StructureND<Double>,
transform: Float64Field.(Double, Double) -> Double,
): ViktorStructureND {
require(left.shape == right.shape)
require(left.shape.contentEquals(right.shape))
return F64Array(*left.shape.asArray()).apply {
ColumnStrides(left.shape).asSequence().forEach { index ->
set(value = Float64Field.transform(left[index], right[index]), indices = index)