forked from kscience/kmath
Compare commits
11 Commits
dev
...
tmp/generi
Author | SHA1 | Date | |
---|---|---|---|
5efd5e8304 | |||
d2c8423b6f | |||
2a2c5e8765 | |||
47a9bf0e9a | |||
b8bf56bc42 | |||
597c27e615 | |||
40ffa8d6fc | |||
1c5113da29 | |||
4f2ebb17ff | |||
8fa42450b2 | |||
e0125f0b26 |
@ -0,0 +1,45 @@
|
||||
/*
|
||||
* Copyright 2018-2024 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.series
|
||||
|
||||
import space.kscience.kmath.operations.algebra
|
||||
import space.kscience.kmath.operations.bufferAlgebra
|
||||
import space.kscience.kmath.structures.*
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.plotly.*
|
||||
import space.kscience.plotly.models.Scatter
|
||||
import kotlin.math.sin
|
||||
|
||||
private val customAlgebra = (Double.algebra.bufferAlgebra) { SeriesAlgebra(this) { it.toDouble() } }
|
||||
|
||||
fun main(): Unit = (customAlgebra) {
|
||||
val signal = DoubleArray(800) {
|
||||
sin(it.toDouble() / 10.0) + 3.5 * sin(it.toDouble() / 60.0)
|
||||
}.asBuffer().moveTo(0)
|
||||
val emd = empiricalModeDecomposition(
|
||||
sConditionThreshold = 1,
|
||||
maxSiftIterations = 15,
|
||||
siftingDelta = 1e-2,
|
||||
nModes = 4
|
||||
).decompose(signal)
|
||||
println("EMD: ${emd.modes.size} modes extracted, terminated because ${emd.terminatedBecause}")
|
||||
|
||||
fun Plot.series(name: String, buffer: Buffer<Double>, block: Scatter.() -> Unit = {}) {
|
||||
this.scatter {
|
||||
this.name = name
|
||||
this.x.numbers = buffer.labels
|
||||
this.y.doubles = buffer.toDoubleArray()
|
||||
block()
|
||||
}
|
||||
}
|
||||
|
||||
Plotly.plot {
|
||||
series("Signal", signal)
|
||||
emd.modes.forEachIndexed { index, it ->
|
||||
series("Mode ${index+1}", it)
|
||||
}
|
||||
}.makeFile(resourceLocation = ResourceLocation.REMOTE)
|
||||
}
|
@ -35,7 +35,7 @@ class StreamDoubleFieldND(override val shape: ShapeND) : FieldND<Double, Float64
|
||||
@OptIn(PerformancePitfall::class)
|
||||
private val StructureND<Double>.buffer: Float64Buffer
|
||||
get() = when {
|
||||
shape != this@StreamDoubleFieldND.shape -> throw ShapeMismatchException(
|
||||
!shape.contentEquals(this@StreamDoubleFieldND.shape) -> throw ShapeMismatchException(
|
||||
this@StreamDoubleFieldND.shape,
|
||||
shape
|
||||
)
|
||||
|
@ -6,6 +6,7 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.contentEquals
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
@ -61,7 +62,7 @@ fun main() {
|
||||
// figure out MSE of approximation
|
||||
fun mse(yTrue: DoubleTensor, yPred: DoubleTensor): Double {
|
||||
require(yTrue.shape.size == 1)
|
||||
require(yTrue.shape == yPred.shape)
|
||||
require(yTrue.shape contentEquals yPred.shape)
|
||||
|
||||
val diff = yTrue - yPred
|
||||
return sqrt(diff.dot(diff)).value()
|
||||
|
@ -6,6 +6,7 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.contentEquals
|
||||
import space.kscience.kmath.operations.asIterable
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.*
|
||||
@ -93,7 +94,7 @@ class Dense(
|
||||
|
||||
// simple accuracy equal to the proportion of correct answers
|
||||
fun accuracy(yPred: DoubleTensor, yTrue: DoubleTensor): Double {
|
||||
check(yPred.shape == yTrue.shape)
|
||||
check(yPred.shape contentEquals yTrue.shape)
|
||||
val n = yPred.shape[0]
|
||||
var correctCnt = 0
|
||||
for (i in 0 until n) {
|
||||
|
@ -32,12 +32,12 @@ public class BufferedLinearSpace<T, out A : Ring<T>>(
|
||||
}
|
||||
|
||||
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra {
|
||||
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
asND().plus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra {
|
||||
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
asND().minus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
|
@ -33,12 +33,12 @@ public object Float64LinearSpace : LinearSpace<Double, Float64Field> {
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
||||
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
asND().plus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
||||
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
asND().minus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,7 @@ public interface AlgebraND<T, out C : Algebra<T>> : Algebra<StructureND<T>> {
|
||||
*/
|
||||
@PerformancePitfall("Very slow on remote execution algebras")
|
||||
public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T> {
|
||||
require(left.shape == right.shape) {
|
||||
require(left.shape.contentEquals(right.shape)) {
|
||||
"Expected left and right of the same shape, but left - ${left.shape} and right - ${right.shape}"
|
||||
}
|
||||
return structureND(left.shape) { index ->
|
||||
|
@ -106,10 +106,10 @@ public class ColumnStrides(override val shape: ShapeND) : Strides() {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is ColumnStrides) return false
|
||||
return shape == other.shape
|
||||
return shape.contentEquals(other.shape)
|
||||
}
|
||||
|
||||
override fun hashCode(): Int = shape.hashCode()
|
||||
override fun hashCode(): Int = shape.contentHashCode()
|
||||
|
||||
|
||||
public companion object
|
||||
@ -156,10 +156,10 @@ public class RowStrides(override val shape: ShapeND) : Strides() {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is RowStrides) return false
|
||||
return shape == other.shape
|
||||
return shape.contentEquals(other.shape)
|
||||
}
|
||||
|
||||
override fun hashCode(): Int = shape.hashCode()
|
||||
override fun hashCode(): Int = shape.contentHashCode()
|
||||
|
||||
public companion object
|
||||
|
||||
|
@ -11,30 +11,19 @@ import kotlin.jvm.JvmInline
|
||||
/**
|
||||
* A read-only ND shape
|
||||
*/
|
||||
public class ShapeND(@PublishedApi internal val array: IntArray) {
|
||||
@JvmInline
|
||||
public value class ShapeND(@PublishedApi internal val array: IntArray) {
|
||||
public val size: Int get() = array.size
|
||||
public operator fun get(index: Int): Int = array[index]
|
||||
override fun toString(): String = array.contentToString()
|
||||
override fun hashCode(): Int = array.contentHashCode()
|
||||
override fun equals(other: Any?): Boolean = other is ShapeND && array.contentEquals(other.array)
|
||||
}
|
||||
|
||||
public inline fun ShapeND.forEach(block: (value: Int) -> Unit): Unit = array.forEach(block)
|
||||
|
||||
public inline fun ShapeND.forEachIndexed(block: (index: Int, value: Int) -> Unit): Unit = array.forEachIndexed(block)
|
||||
|
||||
@Deprecated(
|
||||
message = "ShapeND is made a usual class with correct `equals`. Use it instead of the `contentEquals`.",
|
||||
replaceWith = ReplaceWith("this == other"),
|
||||
level = DeprecationLevel.WARNING,
|
||||
)
|
||||
public infix fun ShapeND.contentEquals(other: ShapeND): Boolean = array.contentEquals(other.array)
|
||||
|
||||
@Deprecated(
|
||||
message = "ShapeND is made a usual class with correct `hashCode`. Use it instead of the `contentHashCode`.",
|
||||
replaceWith = ReplaceWith("this.hashCode()"),
|
||||
level = DeprecationLevel.WARNING,
|
||||
)
|
||||
public fun ShapeND.contentHashCode(): Int = array.contentHashCode()
|
||||
|
||||
public val ShapeND.indices: IntRange get() = array.indices
|
||||
|
@ -48,12 +48,12 @@ public object Float64ParallelLinearSpace : LinearSpace<Double, Float64Field> {
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
||||
require(shape == other.shape) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
asND().plus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = Floa64FieldOpsND {
|
||||
require(shape == other.shape) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
asND().minus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
|
@ -92,7 +92,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
|
||||
require(left.shape == right.shape) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
||||
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
||||
val leftArray = left.asMultik().array
|
||||
val rightArray = right.asMultik().array
|
||||
val data = initMemoryView<T>(leftArray.size, dataType)
|
||||
@ -124,7 +124,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
||||
public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun StructureND<T>.valueOrNull(): T? = if (shape == ShapeND(1)) {
|
||||
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals ShapeND(1)) {
|
||||
get(intArrayOf(0))
|
||||
} else null
|
||||
|
||||
@ -224,7 +224,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
||||
}
|
||||
|
||||
val mt = asMultik().array
|
||||
return if (ShapeND(mt.shape) == shape) {
|
||||
return if (ShapeND(mt.shape).contentEquals(shape)) {
|
||||
mt
|
||||
} else {
|
||||
@OptIn(UnsafeKMathAPI::class)
|
||||
|
@ -63,7 +63,7 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
|
||||
right: StructureND<T>,
|
||||
transform: C.(T, T) -> T,
|
||||
): Nd4jArrayStructure<T> {
|
||||
require(left.shape == right.shape) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
|
||||
require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
|
||||
val new = Nd4j.create(*left.shape.asArray()).wrap()
|
||||
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) }
|
||||
return new
|
||||
|
@ -49,7 +49,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
|
||||
require(left.shape == right.shape)
|
||||
require(left.shape.contentEquals(right.shape))
|
||||
return mutableStructureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) }
|
||||
}
|
||||
|
||||
@ -203,7 +203,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, Float64Field>
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.valueOrNull(): Double? =
|
||||
if (shape == ShapeND(1)) ndArray.getDouble(0) else null
|
||||
if (shape contentEquals ShapeND(1)) ndArray.getDouble(0) else null
|
||||
|
||||
// TODO rewrite
|
||||
override fun diagonalEmbedding(
|
||||
|
@ -14,6 +14,7 @@ kotlin.sourceSets {
|
||||
dependencies {
|
||||
api(projects.kmathCoroutines)
|
||||
//implementation(spclibs.atomicfu)
|
||||
api(project(":kmath-functions"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,288 @@
|
||||
/*
|
||||
* Copyright 2018-2024 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.series
|
||||
|
||||
import space.kscience.kmath.interpolation.SplineInterpolator
|
||||
import space.kscience.kmath.interpolation.interpolate
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.last
|
||||
|
||||
/**
|
||||
* Empirical mode decomposition of a signal represented as a [Series].
|
||||
*
|
||||
* @param seriesAlgebra context to perform operations in.
|
||||
* @param maxSiftIterations number of iterations in the mode extraction (sifting) process after which
|
||||
* the result is returned even if no other conditions are met.
|
||||
* @param sConditionThreshold one of the possible criteria for sifting process termination:
|
||||
* how many times in a row should a proto-mode satisfy the s-condition (the number of zeros differs from
|
||||
* the number of extrema no more than by 1) to be considered an empirical mode.
|
||||
* @param siftingDelta one of the possible criteria for sifting process termination: if relative difference of
|
||||
* two proto-modes obtained in a sequence is less that this number, the last proto-mode is considered
|
||||
* an empirical mode and returned.
|
||||
* @param nModes how many modes should be extracted at most. The algorithm may return fewer modes if it was not
|
||||
* possible to extract more modes from the signal.
|
||||
*/
|
||||
public class EmpiricalModeDecomposition<T: Comparable<T>, A: Field<T>, BA, L: T> (
|
||||
private val seriesAlgebra: SeriesAlgebra<T, A, BA, L>,
|
||||
private val sConditionThreshold: Int = 15,
|
||||
private val maxSiftIterations: Int = 20,
|
||||
private val siftingDelta: T,
|
||||
private val nModes: Int = 6
|
||||
) where BA: BufferAlgebra<T, A>, BA: FieldOps<Buffer<T>> {
|
||||
|
||||
/**
|
||||
* Take a signal, construct an upper and a lower envelopes, find the mean value of two,
|
||||
* represented as a series.
|
||||
* @param signal Signal to compute on
|
||||
*
|
||||
* @return mean [Series] or `null`. `null` is returned in case
|
||||
* the signal does not have enough extrema to construct envelopes.
|
||||
*/
|
||||
private fun findMean(signal: Series<T>): Series<T>? = (seriesAlgebra) {
|
||||
val interpolator = SplineInterpolator(elementAlgebra)
|
||||
val makeBuffer = elementAlgebra.bufferFactory
|
||||
fun generateEnvelope(extrema: List<Int>, paddedExtremeValues: Buffer<T>): Series<T> {
|
||||
val envelopeFunction = interpolator.interpolate(
|
||||
makeBuffer(extrema.size) { signal.labels[extrema[it]] },
|
||||
paddedExtremeValues
|
||||
)
|
||||
return signal.mapWithLabel { _, label ->
|
||||
// For some reason PolynomialInterpolator is exclusive and the right boundary
|
||||
// TODO Notify interpolator authors
|
||||
envelopeFunction(label) ?: paddedExtremeValues.last()
|
||||
// need to make the interpolator yield values outside boundaries?
|
||||
}
|
||||
}
|
||||
// Extrema padding (experimental) TODO padding needs a dedicated function
|
||||
val maxima = listOf(0) + signal.peaks() + (signal.size - 1)
|
||||
val maxValues = makeBuffer(maxima.size) { signal[maxima[it]] }
|
||||
if (maxValues[0] < maxValues[1]) {
|
||||
maxValues[0] = maxValues[1]
|
||||
}
|
||||
if (maxValues.last() < maxValues[maxValues.size - 2]) {
|
||||
maxValues[maxValues.size - 1] = maxValues[maxValues.size - 2]
|
||||
}
|
||||
|
||||
val minima = listOf(0) + signal.troughs() + (signal.size - 1)
|
||||
val minValues = makeBuffer(minima.size) { signal[minima[it]] }
|
||||
if (minValues[0] > minValues[1]) {
|
||||
minValues[0] = minValues[1]
|
||||
}
|
||||
if (minValues.last() > minValues[minValues.size - 2]) {
|
||||
minValues[minValues.size - 1] = minValues[minValues.size - 2]
|
||||
}
|
||||
return if (maxima.size < 3 || minima.size < 3) null else { // maybe make an early return?
|
||||
val upperEnvelope = generateEnvelope(maxima, maxValues)
|
||||
val lowerEnvelope = generateEnvelope(minima, minValues)
|
||||
return (upperEnvelope + lowerEnvelope).map { it * 0.5 }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract a single empirical mode from a signal. This process is called sifting, hence the name.
|
||||
* @param signal Signal to extract a mode from. The first mode is extracted from the initial signal,
|
||||
* subsequent modes are extracted from the residuals between the signal and all previous modes.
|
||||
*
|
||||
* @return [SiftingResult.NotEnoughExtrema] is returned if the signal has too few extrema to extract a mode.
|
||||
* Success of an appropriate type (See [SiftingResult.Success] class) is returned otherwise.
|
||||
*/
|
||||
private fun sift(signal: Series<T>): SiftingResult = siftInner(signal, 1, 0)
|
||||
|
||||
/**
|
||||
* Compute a single iteration of the sifting process.
|
||||
*/
|
||||
private tailrec fun siftInner(
|
||||
prevMode: Series<T>,
|
||||
iterationNumber: Int,
|
||||
sNumber: Int
|
||||
): SiftingResult = (seriesAlgebra) {
|
||||
val mean = findMean(prevMode) ?:
|
||||
return if (iterationNumber == 1) SiftingResult.NotEnoughExtrema
|
||||
else SiftingResult.SignalFlattened(prevMode)
|
||||
val mode = prevMode.zip(mean) { p, m -> p - m }
|
||||
val newSNumber = if (sCondition(mode)) sNumber + 1 else sNumber
|
||||
return when {
|
||||
iterationNumber >= maxSiftIterations -> SiftingResult.MaxIterationsReached(mode)
|
||||
sNumber >= sConditionThreshold -> SiftingResult.SNumberReached(mode)
|
||||
relativeDifference(mode, prevMode) < (elementAlgebra) { siftingDelta * mode.size } ->
|
||||
SiftingResult.DeltaReached(mode)
|
||||
else -> siftInner(mode, iterationNumber + 1, newSNumber)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract [nModes] empirical modes from a signal represented by a time series.
|
||||
* @param signal Signal to decompose.
|
||||
* @return [EMDecompositionResult] with an appropriate reason for algorithm termination
|
||||
* (see [EMDTerminationReason] for possible reasons).
|
||||
* Modes returned in a list which contains as many modes as it was possible
|
||||
* to extract before triggering one of the termination conditions.
|
||||
*/
|
||||
public fun decompose(signal: Series<T>): EMDecompositionResult<T> = (seriesAlgebra) {
|
||||
val modes = mutableListOf<Series<T>>()
|
||||
var residual = signal
|
||||
repeat(nModes) {
|
||||
val nextMode = when(val r = sift(residual)) {
|
||||
SiftingResult.NotEnoughExtrema ->
|
||||
return EMDecompositionResult(
|
||||
if (it == 0) EMDTerminationReason.SIGNAL_TOO_FLAT
|
||||
else EMDTerminationReason.ALL_POSSIBLE_MODES_EXTRACTED,
|
||||
modes,
|
||||
residual
|
||||
)
|
||||
is SiftingResult.Success<*> -> r.result
|
||||
}
|
||||
modes.add(nextMode as Series<T>) // TODO remove unchecked cast
|
||||
residual = residual.zip(nextMode) { l, r -> l - r }
|
||||
}
|
||||
return EMDecompositionResult(EMDTerminationReason.MAX_MODES_REACHED, modes, residual)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Shortcut to retrieve a decomposition factory from a [SeriesAlgebra] scope.
|
||||
* @receiver scope to perform operations in.
|
||||
* @param maxSiftIterations number of iterations in the mode extraction (sifting) process after which
|
||||
* the result is returned even if no other conditions are met.
|
||||
* @param sConditionThreshold one of the possible criteria for sifting process termination:
|
||||
* how many times in a row should a proto-mode satisfy the s-condition (the number of zeros differs from
|
||||
* the number of extrema no more than by 1) to be considered an empirical mode.
|
||||
* @param siftingDelta one of the possible criteria for sifting process termination: if relative difference of
|
||||
* two proto-modes obtained in a sequence is less that this number, the last proto-mode is considered
|
||||
* an empirical mode and returned.
|
||||
* @param nModes how many modes should be extracted at most. The algorithm may return fewer modes if it was not
|
||||
* possible to extract more modes from the signal.
|
||||
*/
|
||||
public fun <T: Comparable<T>, L: T, A: Field<T>, BA> SeriesAlgebra<T, A, BA, L>.empiricalModeDecomposition(
|
||||
sConditionThreshold: Int = 15,
|
||||
maxSiftIterations: Int = 20,
|
||||
siftingDelta: T,
|
||||
nModes: Int = 3
|
||||
): EmpiricalModeDecomposition<T, A, BA, L>
|
||||
where BA: BufferAlgebra<T, A>, BA: FieldOps<Buffer<T>> = EmpiricalModeDecomposition(
|
||||
seriesAlgebra = this,
|
||||
sConditionThreshold = sConditionThreshold,
|
||||
maxSiftIterations = maxSiftIterations,
|
||||
siftingDelta = siftingDelta,
|
||||
nModes = nModes
|
||||
)
|
||||
|
||||
/**
|
||||
* Brute force count all zeros in the series.
|
||||
*/
|
||||
private fun <T: Comparable<T>, A: Ring<T>, BA> SeriesAlgebra<T, A, BA, *>.countZeros(
|
||||
signal: Series<T>
|
||||
): Int where BA: BufferAlgebra<T, A>, BA: FieldOps<Buffer<T>> {
|
||||
require(signal.size >= 2) { "Expected series with at least 2 elements, but got ${signal.size} elements" }
|
||||
data class SignCounter(val prevSign: Int, val zeroCount: Int)
|
||||
fun strictSign(arg: T): Int = if (arg > elementAlgebra.zero) 1 else -1
|
||||
|
||||
return signal.fold(SignCounter(strictSign(signal[0]), 0)) { acc, it ->
|
||||
val currentSign = strictSign(it)
|
||||
if (acc.prevSign != currentSign) SignCounter(currentSign, acc.zeroCount + 1)
|
||||
else SignCounter(currentSign, acc.zeroCount)
|
||||
}.zeroCount
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute relative difference of two series.
|
||||
*/
|
||||
private fun <T, A: Ring<T>, BA> SeriesAlgebra<T, A, BA, *>.relativeDifference(
|
||||
current: Series<T>,
|
||||
previous: Series<T>
|
||||
): T where BA: BufferAlgebra<T, A>, BA: FieldOps<Buffer<T>> = (bufferAlgebra) {
|
||||
((current - previous) * (current - previous))
|
||||
.div(previous * previous)
|
||||
.fold(elementAlgebra.zero) { acc, it -> acc + it}
|
||||
}
|
||||
|
||||
/**
|
||||
* Brute force count all extrema of a series.
|
||||
*/
|
||||
private fun <T: Comparable<T>> Series<T>.countExtrema(): Int {
|
||||
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
|
||||
return peaks().size + troughs().size
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether the numbers of zeroes and extrema of a series differ by no more than 1.
|
||||
* This is a necessary condition of an empirical mode.
|
||||
*/
|
||||
private fun <T: Comparable<T>, A: Ring<T>, BA> SeriesAlgebra<T, A, BA, *>.sCondition(
|
||||
signal: Series<T>
|
||||
): Boolean where BA: BufferAlgebra<T, A>, BA: FieldOps<Buffer<T>> =
|
||||
(signal.countExtrema() - countZeros(signal)) in -1..1
|
||||
|
||||
internal sealed interface SiftingResult {
|
||||
|
||||
/**
|
||||
* Represents a condition when a mode has been successfully
|
||||
* extracted in a sifting process.
|
||||
*/
|
||||
open class Success<T>(val result: Series<T>): SiftingResult
|
||||
|
||||
/**
|
||||
* Returned when no termination condition was reached and the proto-mode
|
||||
* has become too flat (with not enough extrema to build envelopes)
|
||||
* after several sifting iterations.
|
||||
*/
|
||||
class SignalFlattened<T>(result: Series<T>) : Success<T>(result)
|
||||
|
||||
/**
|
||||
* Returned when sifting process has been terminated due to the
|
||||
* S-number condition being reached.
|
||||
*/
|
||||
class SNumberReached<T>(result: Series<T>) : Success<T>(result)
|
||||
|
||||
/**
|
||||
* Returned when sifting process has been terminated due to the
|
||||
* delta condition (Cauchy criterion) being reached.
|
||||
*/
|
||||
class DeltaReached<T>(result: Series<T>) : Success<T>(result)
|
||||
|
||||
/**
|
||||
* Returned when sifting process has been terminated after
|
||||
* executing the maximum number of iterations (specified when creating an instance
|
||||
* of [EmpiricalModeDecomposition]).
|
||||
*/
|
||||
class MaxIterationsReached<T>(result: Series<T>): Success<T>(result)
|
||||
|
||||
/**
|
||||
* Returned when the submitted signal has not enough extrema to build envelopes,
|
||||
* i.e. when [SignalFlattened] condition has already been reached before the first sifting iteration.
|
||||
*/
|
||||
data object NotEnoughExtrema : SiftingResult
|
||||
|
||||
}
|
||||
|
||||
public enum class EMDTerminationReason {
|
||||
/**
|
||||
* Returned when the signal is too flat, i.e. there are too few extrema
|
||||
* to build envelopes necessary to extract modes.
|
||||
*/
|
||||
SIGNAL_TOO_FLAT,
|
||||
|
||||
/**
|
||||
* Returned when there has been extracted as many modes as
|
||||
* specified when creating the instance of [EmpiricalModeDecomposition]
|
||||
*/
|
||||
MAX_MODES_REACHED,
|
||||
|
||||
/**
|
||||
* Returned when the algorithm terminated after finding impossible
|
||||
* to extract more modes from the signal and the maximum number
|
||||
* of modes (specified when creating an instance of [EmpiricalModeDecomposition])
|
||||
* has not yet been reached.
|
||||
*/
|
||||
ALL_POSSIBLE_MODES_EXTRACTED
|
||||
}
|
||||
|
||||
public data class EMDecompositionResult<T>(
|
||||
val terminatedBecause: EMDTerminationReason,
|
||||
val modes: List<Series<T>>,
|
||||
val residual: Series<T>
|
||||
)
|
@ -169,7 +169,7 @@ public open class SeriesAlgebra<T, out A : Ring<T>, out BA : BufferAlgebra<T, A>
|
||||
public inline fun Buffer<T>.mapWithLabel(crossinline transform: A.(arg: T, label: L) -> T): Series<T> {
|
||||
val labels = labels
|
||||
val buf = elementAlgebra.bufferFactory(size) {
|
||||
elementAlgebra.transform(getByOffset(it), labels[it])
|
||||
elementAlgebra.transform(get(it), labels[it])
|
||||
}
|
||||
return buf.moveTo(offsetIndices.first)
|
||||
}
|
||||
|
@ -0,0 +1,89 @@
|
||||
/*
|
||||
* Copyright 2018-2024 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.series
|
||||
|
||||
|
||||
public enum class PlateauEdgePolicy {
|
||||
/**
|
||||
* Midpoints of plateau are returned, edges not belonging to a plateau are ignored.
|
||||
*
|
||||
* A midpoint is the index closest to the average of indices of the left and right edges
|
||||
* of the plateau:
|
||||
*
|
||||
* `val midpoint = ((leftEdge + rightEdge) / 2).toInt`
|
||||
*/
|
||||
AVERAGE,
|
||||
|
||||
/**
|
||||
* Both left and right edges are returned.
|
||||
*/
|
||||
KEEP_ALL_EDGES,
|
||||
|
||||
/**
|
||||
* Only right edges are returned.
|
||||
*/
|
||||
KEEP_RIGHT_EDGES,
|
||||
|
||||
/**
|
||||
* Only left edges are returned.
|
||||
*/
|
||||
KEEP_LEFT_EDGES,
|
||||
|
||||
/**
|
||||
* Ignore plateau, only peaks (troughs) with values strictly greater (less)
|
||||
* than values of the adjacent points are returned.
|
||||
*/
|
||||
IGNORE
|
||||
}
|
||||
|
||||
|
||||
public fun <T: Comparable<T>> Series<T>.peaks(
|
||||
plateauEdgePolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE
|
||||
): List<Int> = findPeaks(plateauEdgePolicy, { other -> this > other }, { other -> this >= other })
|
||||
|
||||
public fun <T: Comparable<T>> Series<T>.troughs(
|
||||
plateauEdgePolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE
|
||||
): List<Int> = findPeaks(plateauEdgePolicy, { other -> this < other }, { other -> this <= other })
|
||||
|
||||
|
||||
private fun <T: Comparable<T>> Series<T>.findPeaks(
|
||||
plateauPolicy: PlateauEdgePolicy = PlateauEdgePolicy.AVERAGE,
|
||||
cmpStrong: T.(T) -> Boolean,
|
||||
cmpWeak: T.(T) -> Boolean
|
||||
): List<Int> {
|
||||
require(size >= 3) { "Expected series with at least 3 elements, but got $size elements" }
|
||||
if (plateauPolicy == PlateauEdgePolicy.AVERAGE) return peaksWithPlateau(cmpStrong)
|
||||
fun peakCriterion(left: T, middle: T, right: T): Boolean = when(plateauPolicy) {
|
||||
PlateauEdgePolicy.KEEP_LEFT_EDGES -> middle.cmpStrong(left) && middle.cmpWeak(right)
|
||||
PlateauEdgePolicy.KEEP_RIGHT_EDGES -> middle.cmpWeak(left) && middle.cmpStrong(right)
|
||||
PlateauEdgePolicy.KEEP_ALL_EDGES ->
|
||||
(middle.cmpStrong(left) && middle.cmpWeak(right)) || (middle.cmpWeak(left) && middle.cmpStrong(right))
|
||||
else -> middle.cmpStrong(right) && middle.cmpStrong(left)
|
||||
}
|
||||
val indices = mutableListOf<Int>()
|
||||
for (index in 1 .. size - 2) {
|
||||
val left = this[index - 1]
|
||||
val middle = this[index]
|
||||
val right = this[index + 1]
|
||||
if (peakCriterion(left, middle, right)) indices.add(index)
|
||||
}
|
||||
return indices
|
||||
}
|
||||
|
||||
|
||||
private fun <T: Comparable<T>> Series<T>.peaksWithPlateau(cmpStrong: T.(T) -> Boolean): List<Int> {
|
||||
val peaks = mutableListOf<Int>()
|
||||
tailrec fun peaksPlateauInner(index: Int) {
|
||||
val nextUnequal = (index + 1 ..< size).firstOrNull { this[it] != this[index] } ?: (size - 1)
|
||||
val newIndex = if (this[index].cmpStrong(this[index - 1]) && this[index].cmpStrong(this[nextUnequal])) {
|
||||
peaks.add((index + nextUnequal) / 2)
|
||||
nextUnequal
|
||||
} else index + 1
|
||||
if (newIndex < size - 1) peaksPlateauInner(newIndex)
|
||||
}
|
||||
peaksPlateauInner(1)
|
||||
return peaks
|
||||
}
|
@ -23,6 +23,7 @@ import space.kscience.kmath.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.asArray
|
||||
import space.kscience.kmath.nd.contentEquals
|
||||
import space.kscience.kmath.operations.Ring
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
@ -104,7 +105,7 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
|
||||
protected abstract fun const(value: T): Constant<TT>
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun StructureND<T>.valueOrNull(): T? = if (shape == ShapeND(1))
|
||||
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals ShapeND(1))
|
||||
get(intArrayOf(0)) else null
|
||||
|
||||
/**
|
||||
|
@ -95,7 +95,7 @@ public open class DoubleTensorAlgebra :
|
||||
|
||||
override fun StructureND<Double>.valueOrNull(): Double? {
|
||||
val dt = asDoubleTensor()
|
||||
return if (dt.shape == ShapeND(1)) dt.source[0] else null
|
||||
return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.value(): Double = valueOrNull()
|
||||
|
@ -89,7 +89,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, Int32Ring> {
|
||||
|
||||
override fun StructureND<Int>.valueOrNull(): Int? {
|
||||
val dt = asIntTensor()
|
||||
return if (dt.shape == ShapeND(1)) dt.source[0] else null
|
||||
return if (dt.shape contentEquals ShapeND(1)) dt.source[0] else null
|
||||
}
|
||||
|
||||
override fun StructureND<Int>.value(): Int = valueOrNull()
|
||||
@ -387,7 +387,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, Int32Ring> {
|
||||
public fun stack(tensors: List<Tensor<Int>>): IntTensor {
|
||||
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
|
||||
val shape = tensors[0].shape
|
||||
check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" }
|
||||
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
|
||||
val resShape = ShapeND(tensors.size) + shape
|
||||
// val resBuffer: List<Int> = tensors.flatMap {
|
||||
// it.asIntTensor().source.array.drop(it.asIntTensor().bufferStart)
|
||||
|
@ -7,6 +7,7 @@ package space.kscience.kmath.tensors.core.internal
|
||||
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.contentEquals
|
||||
import space.kscience.kmath.nd.linearSize
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
@ -31,7 +32,7 @@ internal fun checkBufferShapeConsistency(shape: ShapeND, buffer: DoubleArray) =
|
||||
|
||||
@PublishedApi
|
||||
internal fun <T> checkShapesCompatible(a: StructureND<T>, b: StructureND<T>): Unit =
|
||||
check(a.shape == b.shape) {
|
||||
check(a.shape contentEquals b.shape) {
|
||||
"Incompatible shapes ${a.shape} and ${b.shape} "
|
||||
}
|
||||
|
||||
|
@ -47,7 +47,7 @@ public fun DoubleTensorAlgebra.randomNormalLike(structure: WithShape, seed: Long
|
||||
public fun stack(tensors: List<Tensor<Double>>): DoubleTensor {
|
||||
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
|
||||
val shape = tensors[0].shape
|
||||
check(tensors.all { it.shape == shape }) { "Tensors must have same shapes" }
|
||||
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
|
||||
val resShape = ShapeND(tensors.size) + shape
|
||||
// val resBuffer: List<Double> = tensors.flatMap {
|
||||
// it.asDoubleTensor().source.array.drop(it.asDoubleTensor().bufferStart)
|
||||
@ -91,7 +91,7 @@ public fun DoubleTensorAlgebra.luPivot(
|
||||
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||
checkSquareMatrix(luTensor.shape)
|
||||
check(
|
||||
luTensor.shape.first(luTensor.shape.size - 2) == pivotsTensor.shape.first(pivotsTensor.shape.size - 1) ||
|
||||
luTensor.shape.first(luTensor.shape.size - 2) contentEquals pivotsTensor.shape.first(pivotsTensor.shape.size - 1) ||
|
||||
luTensor.shape.last() == pivotsTensor.shape.last() - 1
|
||||
) { "Inappropriate shapes of input tensors" }
|
||||
|
||||
|
@ -6,31 +6,29 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.contentEquals
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors
|
||||
import space.kscience.kmath.tensors.core.internal.broadcastShapes
|
||||
import space.kscience.kmath.tensors.core.internal.broadcastTensors
|
||||
import space.kscience.kmath.tensors.core.internal.broadcastTo
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
internal class TestBroadcasting {
|
||||
|
||||
@Test
|
||||
fun testBroadcastShapes() = DoubleTensorAlgebra {
|
||||
assertEquals(
|
||||
assertTrue(
|
||||
broadcastShapes(
|
||||
listOf(ShapeND(2, 3), ShapeND(1, 3), ShapeND(1, 1, 1))
|
||||
),
|
||||
ShapeND(1, 2, 3)
|
||||
) contentEquals ShapeND(1, 2, 3)
|
||||
)
|
||||
|
||||
assertEquals(
|
||||
assertTrue(
|
||||
broadcastShapes(
|
||||
listOf(ShapeND(6, 7), ShapeND(5, 6, 1), ShapeND(7), ShapeND(5, 1, 7))
|
||||
),
|
||||
ShapeND(5, 6, 7)
|
||||
) contentEquals ShapeND(5, 6, 7)
|
||||
)
|
||||
}
|
||||
|
||||
@ -40,7 +38,7 @@ internal class TestBroadcasting {
|
||||
val tensor2 = fromArray(ShapeND(1, 3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
|
||||
val res = broadcastTo(tensor2, tensor1.shape)
|
||||
assertTrue(res.shape == ShapeND(2, 3))
|
||||
assertTrue(res.shape contentEquals ShapeND(2, 3))
|
||||
assertTrue(res.source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
}
|
||||
|
||||
@ -52,9 +50,9 @@ internal class TestBroadcasting {
|
||||
|
||||
val res = broadcastTensors(tensor1, tensor2, tensor3)
|
||||
|
||||
assertEquals(res[0].shape, ShapeND(1, 2, 3))
|
||||
assertEquals(res[1].shape, ShapeND(1, 2, 3))
|
||||
assertEquals(res[2].shape, ShapeND(1, 2, 3))
|
||||
assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3))
|
||||
assertTrue(res[1].shape contentEquals ShapeND(1, 2, 3))
|
||||
assertTrue(res[2].shape contentEquals ShapeND(1, 2, 3))
|
||||
|
||||
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
@ -69,9 +67,9 @@ internal class TestBroadcasting {
|
||||
|
||||
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
|
||||
|
||||
assertEquals(res[0].shape, ShapeND(1, 2, 3))
|
||||
assertEquals(res[1].shape, ShapeND(1, 1, 3))
|
||||
assertEquals(res[2].shape, ShapeND(1, 1, 1))
|
||||
assertTrue(res[0].shape contentEquals ShapeND(1, 2, 3))
|
||||
assertTrue(res[1].shape contentEquals ShapeND(1, 1, 3))
|
||||
assertTrue(res[2].shape contentEquals ShapeND(1, 1, 1))
|
||||
|
||||
assertTrue(res[0].source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res[1].source contentEquals doubleArrayOf(10.0, 20.0, 30.0))
|
||||
@ -86,9 +84,9 @@ internal class TestBroadcasting {
|
||||
|
||||
val res = broadcastOuterTensors(tensor1, tensor2, tensor3)
|
||||
|
||||
assertEquals(res[0].shape, ShapeND(4, 2, 5, 3, 2, 3))
|
||||
assertEquals(res[1].shape, ShapeND(4, 2, 5, 3, 3, 3))
|
||||
assertEquals(res[2].shape, ShapeND(4, 2, 5, 3, 1, 1))
|
||||
assertTrue(res[0].shape contentEquals ShapeND(4, 2, 5, 3, 2, 3))
|
||||
assertTrue(res[1].shape contentEquals ShapeND(4, 2, 5, 3, 3, 3))
|
||||
assertTrue(res[2].shape contentEquals ShapeND(4, 2, 5, 3, 1, 1))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -101,16 +99,16 @@ internal class TestBroadcasting {
|
||||
val tensor31 = tensor3 - tensor1
|
||||
val tensor32 = tensor3 - tensor2
|
||||
|
||||
assertEquals(tensor21.shape, ShapeND(2, 3))
|
||||
assertTrue(tensor21.shape contentEquals ShapeND(2, 3))
|
||||
assertTrue(tensor21.source contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||
|
||||
assertEquals(tensor31.shape, ShapeND(1, 2, 3))
|
||||
assertTrue(tensor31.shape contentEquals ShapeND(1, 2, 3))
|
||||
assertTrue(
|
||||
tensor31.source
|
||||
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)
|
||||
)
|
||||
|
||||
assertEquals(tensor32.shape, ShapeND(1, 1, 3))
|
||||
assertTrue(tensor32.shape contentEquals ShapeND(1, 1, 3))
|
||||
assertTrue(tensor32.source contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
||||
}
|
||||
|
||||
|
@ -6,6 +6,7 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.contentEquals
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.tensors.core.internal.svd1d
|
||||
import kotlin.math.abs
|
||||
@ -112,8 +113,8 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
||||
|
||||
val (q, r) = qr(tensor)
|
||||
|
||||
assertEquals(q.shape, shape)
|
||||
assertEquals(r.shape, shape)
|
||||
assertTrue { q.shape contentEquals shape }
|
||||
assertTrue { r.shape contentEquals shape }
|
||||
|
||||
assertTrue((q matmul r).eq(tensor))
|
||||
|
||||
@ -132,9 +133,9 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
||||
|
||||
val (p, l, u) = lu(tensor)
|
||||
|
||||
assertEquals(p.shape, shape)
|
||||
assertEquals(l.shape, shape)
|
||||
assertEquals(u.shape, shape)
|
||||
assertTrue { p.shape contentEquals shape }
|
||||
assertTrue { l.shape contentEquals shape }
|
||||
assertTrue { u.shape contentEquals shape }
|
||||
|
||||
assertTrue((p matmul tensor).eq(l matmul u))
|
||||
}
|
||||
@ -156,7 +157,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
||||
|
||||
val res = svd1d(tensor2)
|
||||
|
||||
assertEquals(res.shape, ShapeND(2))
|
||||
assertTrue(res.shape contentEquals ShapeND(2))
|
||||
assertTrue { abs(abs(res.source[0]) - 0.386) < 0.01 }
|
||||
assertTrue { abs(abs(res.source[1]) - 0.922) < 0.01 }
|
||||
}
|
||||
|
@ -6,8 +6,7 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.get
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.testutils.assertBufferEquals
|
||||
import kotlin.test.Test
|
||||
@ -44,7 +43,7 @@ internal class TestDoubleTensorAlgebra {
|
||||
val res = tensor.transposed(0, 0)
|
||||
|
||||
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(0.0))
|
||||
assertEquals(res.shape, ShapeND(1))
|
||||
assertTrue(res.shape contentEquals ShapeND(1))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -53,7 +52,7 @@ internal class TestDoubleTensorAlgebra {
|
||||
val res = tensor.transposed(1, 0)
|
||||
|
||||
assertTrue(res.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
||||
assertEquals(res.shape, ShapeND(2, 3))
|
||||
assertTrue(res.shape contentEquals ShapeND(2, 3))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -63,9 +62,9 @@ internal class TestDoubleTensorAlgebra {
|
||||
val res02 = tensor.transposed(-3, 2)
|
||||
val res12 = tensor.transposed()
|
||||
|
||||
assertEquals(res01.shape, ShapeND(2, 1, 3))
|
||||
assertEquals(res02.shape, ShapeND(3, 2, 1))
|
||||
assertEquals(res12.shape, ShapeND(1, 3, 2))
|
||||
assertTrue(res01.shape contentEquals ShapeND(2, 1, 3))
|
||||
assertTrue(res02.shape contentEquals ShapeND(3, 2, 1))
|
||||
assertTrue(res12.shape contentEquals ShapeND(1, 3, 2))
|
||||
|
||||
assertTrue(res01.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res02.asDoubleTensor().source contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
@ -115,19 +114,19 @@ internal class TestDoubleTensorAlgebra {
|
||||
|
||||
val res12 = tensor1.dot(tensor2)
|
||||
assertTrue(res12.source contentEquals doubleArrayOf(140.0, 320.0))
|
||||
assertEquals(res12.shape, ShapeND(2))
|
||||
assertTrue(res12.shape contentEquals ShapeND(2))
|
||||
|
||||
val res32 = tensor3.matmul(tensor2)
|
||||
assertTrue(res32.source contentEquals doubleArrayOf(-140.0))
|
||||
assertEquals(res32.shape, ShapeND(1, 1))
|
||||
assertTrue(res32.shape contentEquals ShapeND(1, 1))
|
||||
|
||||
val res22 = tensor2.dot(tensor2)
|
||||
assertTrue(res22.source contentEquals doubleArrayOf(1400.0))
|
||||
assertEquals(res22.shape, ShapeND(1))
|
||||
assertTrue(res22.shape contentEquals ShapeND(1))
|
||||
|
||||
val res11 = tensor1.dot(tensor11)
|
||||
assertTrue(res11.source contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
|
||||
assertEquals(res11.shape, ShapeND(2, 2))
|
||||
assertTrue(res11.shape contentEquals ShapeND(2, 2))
|
||||
|
||||
val res45 = tensor4.matmul(tensor5)
|
||||
assertTrue(
|
||||
@ -136,7 +135,7 @@ internal class TestDoubleTensorAlgebra {
|
||||
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
|
||||
)
|
||||
)
|
||||
assertEquals(res45.shape, ShapeND(2, 3, 3))
|
||||
assertTrue(res45.shape contentEquals ShapeND(2, 3, 3))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -145,35 +144,35 @@ internal class TestDoubleTensorAlgebra {
|
||||
val tensor2 = fromArray(ShapeND(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor3 = zeros(ShapeND(2, 3, 4, 5))
|
||||
|
||||
assertEquals(
|
||||
diagonalEmbedding(tensor3, 0, 3, 4).shape,
|
||||
assertTrue(
|
||||
diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals
|
||||
ShapeND(2, 3, 4, 5, 5)
|
||||
)
|
||||
assertEquals(
|
||||
diagonalEmbedding(tensor3, 1, 3, 4).shape,
|
||||
assertTrue(
|
||||
diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals
|
||||
ShapeND(2, 3, 4, 6, 6)
|
||||
)
|
||||
assertEquals(
|
||||
diagonalEmbedding(tensor3, 2, 0, 3).shape,
|
||||
assertTrue(
|
||||
diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals
|
||||
ShapeND(7, 2, 3, 7, 4)
|
||||
)
|
||||
|
||||
val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0)
|
||||
assertEquals(diagonal1.shape, ShapeND(3, 3))
|
||||
assertTrue(diagonal1.shape contentEquals ShapeND(3, 3))
|
||||
assertTrue(
|
||||
diagonal1.source contentEquals
|
||||
doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0)
|
||||
)
|
||||
|
||||
val diagonal1Offset = diagonalEmbedding(tensor1, 1, 1, 0)
|
||||
assertEquals(diagonal1Offset.shape, ShapeND(4, 4))
|
||||
assertTrue(diagonal1Offset.shape contentEquals ShapeND(4, 4))
|
||||
assertTrue(
|
||||
diagonal1Offset.source contentEquals
|
||||
doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0)
|
||||
)
|
||||
|
||||
val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2)
|
||||
assertEquals(diagonal2.shape, ShapeND(4, 2, 4))
|
||||
assertTrue(diagonal2.shape contentEquals ShapeND(4, 2, 4))
|
||||
assertTrue(
|
||||
diagonal2.source contentEquals
|
||||
doubleArrayOf(
|
||||
@ -203,7 +202,7 @@ internal class TestDoubleTensorAlgebra {
|
||||
val l = tensor.getTensor(0).map { it + 1.0 }
|
||||
val r = tensor.getTensor(1).map { it - 1.0 }
|
||||
val res = l + r
|
||||
assertEquals(ShapeND(5, 5), res.shape)
|
||||
assertTrue { ShapeND(5, 5) contentEquals res.shape }
|
||||
assertEquals(2.0, res[4, 4])
|
||||
}
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ public open class ViktorFieldOpsND :
|
||||
right: StructureND<Double>,
|
||||
transform: Float64Field.(Double, Double) -> Double,
|
||||
): ViktorStructureND {
|
||||
require(left.shape == right.shape)
|
||||
require(left.shape.contentEquals(right.shape))
|
||||
return F64Array(*left.shape.asArray()).apply {
|
||||
ColumnStrides(left.shape).asSequence().forEach { index ->
|
||||
set(value = Float64Field.transform(left[index], right[index]), indices = index)
|
||||
|
Loading…
Reference in New Issue
Block a user