Enum class and some code corrections.

This commit is contained in:
E––jenY-Poltavchiny 2023-06-14 04:13:39 +03:00
parent 11cbf7cdc7
commit d5c3aa563e
2 changed files with 22 additions and 18 deletions

View File

@ -9,23 +9,27 @@ import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.*
import space.kscience.kmath.nd.DoubleBufferND import space.kscience.kmath.nd.DoubleBufferND
import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.structures.DoubleBuffer
import kotlin.math.abs import kotlin.math.abs
/** /**
* Offset constants which will be used later. Added them for avoiding "magical numbers" problem. * Offset constants which will be used later. Added them for avoiding "magical numbers" problem.
*/ */
internal const val LEFT_OFFSET : Int = -1 internal enum class DtwOffset {
internal const val BOTTOM_OFFSET : Int = 1 LEFT,
internal const val DIAGONAL_OFFSET : Int = 0 BOTTOM,
DIAGONAL
}
/** /**
* Public class to store result of method. Class contains total penalty cost for series alignment. * Public class to store result of method. Class contains total penalty cost for series alignment.
* Also, this class contains align matrix (which point of the first series matches to point of the other series). * Also, this class contains align matrix (which point of the first series matches to point of the other series).
*/ */
public data class DynamicTimeWarpingData(val totalCost : Double = 0.0, public data class DynamicTimeWarpingData(
val alignMatrix : DoubleBufferND = DoubleFieldOpsND.structureND(ShapeND(0, 0)) { (_, _) -> 0.0}) val totalCost : Double = 0.0,
val alignMatrix : DoubleBufferND = DoubleFieldOpsND.structureND(ShapeND(0, 0)) { (_, _) -> 0.0}
)
/** /**
* PathIndices class for better code perceptibility. * PathIndices class for better code perceptibility.
@ -33,15 +37,14 @@ public data class DynamicTimeWarpingData(val totalCost : Double = 0.0,
* is flags for bottom, diagonal or left offsets respectively. * is flags for bottom, diagonal or left offsets respectively.
*/ */
internal data class PathIndices (var id_x: Int, var id_y: Int) { internal data class PathIndices (var id_x: Int, var id_y: Int) {
fun moveOption (direction: Int) { fun moveOption (direction: DtwOffset) {
when(direction) { when(direction) {
BOTTOM_OFFSET -> id_x-- DtwOffset.BOTTOM -> id_x--
DIAGONAL_OFFSET -> { DtwOffset.DIAGONAL -> {
id_x-- id_x--
id_y-- id_y--
} }
LEFT_OFFSET -> id_y-- DtwOffset.LEFT -> id_y--
else -> throw Exception("There is no such offset flag!")
} }
} }
} }
@ -51,7 +54,7 @@ internal data class PathIndices (var id_x: Int, var id_y: Int) {
* for two series comparing and penalty for this alignment. * for two series comparing and penalty for this alignment.
*/ */
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : Series<Double>, series2 : Series<Double>) : DynamicTimeWarpingData { public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : DoubleBuffer, series2 : DoubleBuffer) : DynamicTimeWarpingData {
var cost = 0.0 var cost = 0.0
var pathLength = 0 var pathLength = 0
// Special matrix of costs alignment for two series. // Special matrix of costs alignment for two series.
@ -76,7 +79,7 @@ public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : Series<Double>, series2
// alignMatrix contains non-zero values at position where two points from series matches // alignMatrix contains non-zero values at position where two points from series matches
// Values are penalty for concatenation of current points. // Values are penalty for concatenation of current points.
val alignMatrix = structureND(ShapeND(series1.size, series2.size)) {(_, _) -> 0.0} val alignMatrix = structureND(ShapeND(series1.size, series2.size)) {(_, _) -> 0.0}
val indexes = PathIndices(alignMatrix.indices.shape.first() - 1, alignMatrix.indices.shape.last() - 1) val indexes = PathIndices(series1.size - 1, series2.size - 1)
with(indexes) { with(indexes) {
alignMatrix[id_x, id_y] = costMatrix[id_x, id_y] alignMatrix[id_x, id_y] = costMatrix[id_x, id_y]
@ -85,13 +88,13 @@ public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : Series<Double>, series2
while (id_x != 0 || id_y != 0) { while (id_x != 0 || id_y != 0) {
when { when {
id_x == 0 || costMatrix[id_x, id_y] == costMatrix[id_x, id_y - 1] + abs(series1[id_x] - series2[id_y]) -> { id_x == 0 || costMatrix[id_x, id_y] == costMatrix[id_x, id_y - 1] + abs(series1[id_x] - series2[id_y]) -> {
moveOption(LEFT_OFFSET) moveOption(DtwOffset.LEFT)
} }
id_y == 0 || costMatrix[id_x, id_y] == costMatrix[id_x - 1, id_y] + abs(series1[id_x] - series2[id_y]) -> { id_y == 0 || costMatrix[id_x, id_y] == costMatrix[id_x - 1, id_y] + abs(series1[id_x] - series2[id_y]) -> {
moveOption(BOTTOM_OFFSET) moveOption(DtwOffset.BOTTOM)
} }
costMatrix[id_x, id_y] == costMatrix[id_x - 1, id_y - 1] + abs(series1[id_x] - series2[id_y]) -> { costMatrix[id_x, id_y] == costMatrix[id_x - 1, id_y - 1] + abs(series1[id_x] - series2[id_y]) -> {
moveOption(DIAGONAL_OFFSET) moveOption(DtwOffset.DIAGONAL)
} }
} }
alignMatrix[id_x, id_y] = costMatrix[id_x, id_y] alignMatrix[id_x, id_y] = costMatrix[id_x, id_y]
@ -103,3 +106,4 @@ public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : Series<Double>, series2
return DynamicTimeWarpingData(cost, alignMatrix) return DynamicTimeWarpingData(cost, alignMatrix)
} }

View File

@ -20,8 +20,8 @@ class DTWTest {
val firstSequence: DoubleArray = doubleArrayOf(0.0, 2.0, 3.0, 1.0, 3.0, 0.1, 0.0, 1.0) val firstSequence: DoubleArray = doubleArrayOf(0.0, 2.0, 3.0, 1.0, 3.0, 0.1, 0.0, 1.0)
val secondSequence: DoubleArray = doubleArrayOf(1.0, 0.0, 3.0, 0.0, 0.0, 3.0, 2.0, 0.0, 2.0) val secondSequence: DoubleArray = doubleArrayOf(1.0, 0.0, 3.0, 0.0, 0.0, 3.0, 2.0, 0.0, 2.0)
val seriesOne: Series<Double> = firstSequence.asBuffer().moveTo(0) val seriesOne = firstSequence.asBuffer()
val seriesTwo: Series<Double> = secondSequence.asBuffer().moveTo(0) val seriesTwo = secondSequence.asBuffer()
val result = DoubleFieldOpsND.dynamicTimeWarping(seriesOne, seriesTwo) val result = DoubleFieldOpsND.dynamicTimeWarping(seriesOne, seriesTwo)
println("Total penalty coefficient: ${result.totalCost}") println("Total penalty coefficient: ${result.totalCost}")