diff --git a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/DynamicTimeWarping.kt b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/DynamicTimeWarping.kt index 387ad942b..55d583fc0 100644 --- a/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/DynamicTimeWarping.kt +++ b/kmath-stat/src/commonMain/kotlin/space/kscience/kmath/series/DynamicTimeWarping.kt @@ -9,23 +9,27 @@ import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.nd.* import space.kscience.kmath.nd.DoubleBufferND import space.kscience.kmath.nd.ShapeND +import space.kscience.kmath.structures.DoubleBuffer import kotlin.math.abs /** * Offset constants which will be used later. Added them for avoiding "magical numbers" problem. */ -internal const val LEFT_OFFSET : Int = -1 -internal const val BOTTOM_OFFSET : Int = 1 -internal const val DIAGONAL_OFFSET : Int = 0 - +internal enum class DtwOffset { + LEFT, + BOTTOM, + DIAGONAL +} /** * Public class to store result of method. Class contains total penalty cost for series alignment. * Also, this class contains align matrix (which point of the first series matches to point of the other series). */ -public data class DynamicTimeWarpingData(val totalCost : Double = 0.0, - val alignMatrix : DoubleBufferND = DoubleFieldOpsND.structureND(ShapeND(0, 0)) { (_, _) -> 0.0}) +public data class DynamicTimeWarpingData( + val totalCost : Double = 0.0, + val alignMatrix : DoubleBufferND = DoubleFieldOpsND.structureND(ShapeND(0, 0)) { (_, _) -> 0.0} +) /** * PathIndices class for better code perceptibility. @@ -33,15 +37,14 @@ public data class DynamicTimeWarpingData(val totalCost : Double = 0.0, * is flags for bottom, diagonal or left offsets respectively. */ internal data class PathIndices (var id_x: Int, var id_y: Int) { - fun moveOption (direction: Int) { + fun moveOption (direction: DtwOffset) { when(direction) { - BOTTOM_OFFSET -> id_x-- - DIAGONAL_OFFSET -> { + DtwOffset.BOTTOM -> id_x-- + DtwOffset.DIAGONAL -> { id_x-- id_y-- } - LEFT_OFFSET -> id_y-- - else -> throw Exception("There is no such offset flag!") + DtwOffset.LEFT -> id_y-- } } } @@ -51,7 +54,7 @@ internal data class PathIndices (var id_x: Int, var id_y: Int) { * for two series comparing and penalty for this alignment. */ @OptIn(PerformancePitfall::class) -public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : Series, series2 : Series) : DynamicTimeWarpingData { +public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : DoubleBuffer, series2 : DoubleBuffer) : DynamicTimeWarpingData { var cost = 0.0 var pathLength = 0 // Special matrix of costs alignment for two series. @@ -76,7 +79,7 @@ public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : Series, series2 // alignMatrix contains non-zero values at position where two points from series matches // Values are penalty for concatenation of current points. val alignMatrix = structureND(ShapeND(series1.size, series2.size)) {(_, _) -> 0.0} - val indexes = PathIndices(alignMatrix.indices.shape.first() - 1, alignMatrix.indices.shape.last() - 1) + val indexes = PathIndices(series1.size - 1, series2.size - 1) with(indexes) { alignMatrix[id_x, id_y] = costMatrix[id_x, id_y] @@ -85,13 +88,13 @@ public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : Series, series2 while (id_x != 0 || id_y != 0) { when { id_x == 0 || costMatrix[id_x, id_y] == costMatrix[id_x, id_y - 1] + abs(series1[id_x] - series2[id_y]) -> { - moveOption(LEFT_OFFSET) + moveOption(DtwOffset.LEFT) } id_y == 0 || costMatrix[id_x, id_y] == costMatrix[id_x - 1, id_y] + abs(series1[id_x] - series2[id_y]) -> { - moveOption(BOTTOM_OFFSET) + moveOption(DtwOffset.BOTTOM) } costMatrix[id_x, id_y] == costMatrix[id_x - 1, id_y - 1] + abs(series1[id_x] - series2[id_y]) -> { - moveOption(DIAGONAL_OFFSET) + moveOption(DtwOffset.DIAGONAL) } } alignMatrix[id_x, id_y] = costMatrix[id_x, id_y] @@ -103,3 +106,4 @@ public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : Series, series2 return DynamicTimeWarpingData(cost, alignMatrix) } + diff --git a/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/series/DTWTest.kt b/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/series/DTWTest.kt index b97dd1df9..292fa5da2 100644 --- a/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/series/DTWTest.kt +++ b/kmath-stat/src/commonTest/kotlin/space/kscience/kmath/series/DTWTest.kt @@ -20,8 +20,8 @@ class DTWTest { val firstSequence: DoubleArray = doubleArrayOf(0.0, 2.0, 3.0, 1.0, 3.0, 0.1, 0.0, 1.0) val secondSequence: DoubleArray = doubleArrayOf(1.0, 0.0, 3.0, 0.0, 0.0, 3.0, 2.0, 0.0, 2.0) - val seriesOne: Series = firstSequence.asBuffer().moveTo(0) - val seriesTwo: Series = secondSequence.asBuffer().moveTo(0) + val seriesOne = firstSequence.asBuffer() + val seriesTwo = secondSequence.asBuffer() val result = DoubleFieldOpsND.dynamicTimeWarping(seriesOne, seriesTwo) println("Total penalty coefficient: ${result.totalCost}")