forked from kscience/kmath
Enum class and some code corrections.
This commit is contained in:
parent
11cbf7cdc7
commit
d5c3aa563e
@ -9,23 +9,27 @@ import space.kscience.kmath.PerformancePitfall
|
|||||||
import space.kscience.kmath.nd.*
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.nd.DoubleBufferND
|
import space.kscience.kmath.nd.DoubleBufferND
|
||||||
import space.kscience.kmath.nd.ShapeND
|
import space.kscience.kmath.nd.ShapeND
|
||||||
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Offset constants which will be used later. Added them for avoiding "magical numbers" problem.
|
* Offset constants which will be used later. Added them for avoiding "magical numbers" problem.
|
||||||
*/
|
*/
|
||||||
internal const val LEFT_OFFSET : Int = -1
|
internal enum class DtwOffset {
|
||||||
internal const val BOTTOM_OFFSET : Int = 1
|
LEFT,
|
||||||
internal const val DIAGONAL_OFFSET : Int = 0
|
BOTTOM,
|
||||||
|
DIAGONAL
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Public class to store result of method. Class contains total penalty cost for series alignment.
|
* Public class to store result of method. Class contains total penalty cost for series alignment.
|
||||||
* Also, this class contains align matrix (which point of the first series matches to point of the other series).
|
* Also, this class contains align matrix (which point of the first series matches to point of the other series).
|
||||||
*/
|
*/
|
||||||
public data class DynamicTimeWarpingData(val totalCost : Double = 0.0,
|
public data class DynamicTimeWarpingData(
|
||||||
val alignMatrix : DoubleBufferND = DoubleFieldOpsND.structureND(ShapeND(0, 0)) { (_, _) -> 0.0})
|
val totalCost : Double = 0.0,
|
||||||
|
val alignMatrix : DoubleBufferND = DoubleFieldOpsND.structureND(ShapeND(0, 0)) { (_, _) -> 0.0}
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* PathIndices class for better code perceptibility.
|
* PathIndices class for better code perceptibility.
|
||||||
@ -33,15 +37,14 @@ public data class DynamicTimeWarpingData(val totalCost : Double = 0.0,
|
|||||||
* is flags for bottom, diagonal or left offsets respectively.
|
* is flags for bottom, diagonal or left offsets respectively.
|
||||||
*/
|
*/
|
||||||
internal data class PathIndices (var id_x: Int, var id_y: Int) {
|
internal data class PathIndices (var id_x: Int, var id_y: Int) {
|
||||||
fun moveOption (direction: Int) {
|
fun moveOption (direction: DtwOffset) {
|
||||||
when(direction) {
|
when(direction) {
|
||||||
BOTTOM_OFFSET -> id_x--
|
DtwOffset.BOTTOM -> id_x--
|
||||||
DIAGONAL_OFFSET -> {
|
DtwOffset.DIAGONAL -> {
|
||||||
id_x--
|
id_x--
|
||||||
id_y--
|
id_y--
|
||||||
}
|
}
|
||||||
LEFT_OFFSET -> id_y--
|
DtwOffset.LEFT -> id_y--
|
||||||
else -> throw Exception("There is no such offset flag!")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -51,7 +54,7 @@ internal data class PathIndices (var id_x: Int, var id_y: Int) {
|
|||||||
* for two series comparing and penalty for this alignment.
|
* for two series comparing and penalty for this alignment.
|
||||||
*/
|
*/
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : Series<Double>, series2 : Series<Double>) : DynamicTimeWarpingData {
|
public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : DoubleBuffer, series2 : DoubleBuffer) : DynamicTimeWarpingData {
|
||||||
var cost = 0.0
|
var cost = 0.0
|
||||||
var pathLength = 0
|
var pathLength = 0
|
||||||
// Special matrix of costs alignment for two series.
|
// Special matrix of costs alignment for two series.
|
||||||
@ -76,7 +79,7 @@ public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : Series<Double>, series2
|
|||||||
// alignMatrix contains non-zero values at position where two points from series matches
|
// alignMatrix contains non-zero values at position where two points from series matches
|
||||||
// Values are penalty for concatenation of current points.
|
// Values are penalty for concatenation of current points.
|
||||||
val alignMatrix = structureND(ShapeND(series1.size, series2.size)) {(_, _) -> 0.0}
|
val alignMatrix = structureND(ShapeND(series1.size, series2.size)) {(_, _) -> 0.0}
|
||||||
val indexes = PathIndices(alignMatrix.indices.shape.first() - 1, alignMatrix.indices.shape.last() - 1)
|
val indexes = PathIndices(series1.size - 1, series2.size - 1)
|
||||||
|
|
||||||
with(indexes) {
|
with(indexes) {
|
||||||
alignMatrix[id_x, id_y] = costMatrix[id_x, id_y]
|
alignMatrix[id_x, id_y] = costMatrix[id_x, id_y]
|
||||||
@ -85,13 +88,13 @@ public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : Series<Double>, series2
|
|||||||
while (id_x != 0 || id_y != 0) {
|
while (id_x != 0 || id_y != 0) {
|
||||||
when {
|
when {
|
||||||
id_x == 0 || costMatrix[id_x, id_y] == costMatrix[id_x, id_y - 1] + abs(series1[id_x] - series2[id_y]) -> {
|
id_x == 0 || costMatrix[id_x, id_y] == costMatrix[id_x, id_y - 1] + abs(series1[id_x] - series2[id_y]) -> {
|
||||||
moveOption(LEFT_OFFSET)
|
moveOption(DtwOffset.LEFT)
|
||||||
}
|
}
|
||||||
id_y == 0 || costMatrix[id_x, id_y] == costMatrix[id_x - 1, id_y] + abs(series1[id_x] - series2[id_y]) -> {
|
id_y == 0 || costMatrix[id_x, id_y] == costMatrix[id_x - 1, id_y] + abs(series1[id_x] - series2[id_y]) -> {
|
||||||
moveOption(BOTTOM_OFFSET)
|
moveOption(DtwOffset.BOTTOM)
|
||||||
}
|
}
|
||||||
costMatrix[id_x, id_y] == costMatrix[id_x - 1, id_y - 1] + abs(series1[id_x] - series2[id_y]) -> {
|
costMatrix[id_x, id_y] == costMatrix[id_x - 1, id_y - 1] + abs(series1[id_x] - series2[id_y]) -> {
|
||||||
moveOption(DIAGONAL_OFFSET)
|
moveOption(DtwOffset.DIAGONAL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
alignMatrix[id_x, id_y] = costMatrix[id_x, id_y]
|
alignMatrix[id_x, id_y] = costMatrix[id_x, id_y]
|
||||||
@ -103,3 +106,4 @@ public fun DoubleFieldOpsND.dynamicTimeWarping(series1 : Series<Double>, series2
|
|||||||
return DynamicTimeWarpingData(cost, alignMatrix)
|
return DynamicTimeWarpingData(cost, alignMatrix)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,8 +20,8 @@ class DTWTest {
|
|||||||
val firstSequence: DoubleArray = doubleArrayOf(0.0, 2.0, 3.0, 1.0, 3.0, 0.1, 0.0, 1.0)
|
val firstSequence: DoubleArray = doubleArrayOf(0.0, 2.0, 3.0, 1.0, 3.0, 0.1, 0.0, 1.0)
|
||||||
val secondSequence: DoubleArray = doubleArrayOf(1.0, 0.0, 3.0, 0.0, 0.0, 3.0, 2.0, 0.0, 2.0)
|
val secondSequence: DoubleArray = doubleArrayOf(1.0, 0.0, 3.0, 0.0, 0.0, 3.0, 2.0, 0.0, 2.0)
|
||||||
|
|
||||||
val seriesOne: Series<Double> = firstSequence.asBuffer().moveTo(0)
|
val seriesOne = firstSequence.asBuffer()
|
||||||
val seriesTwo: Series<Double> = secondSequence.asBuffer().moveTo(0)
|
val seriesTwo = secondSequence.asBuffer()
|
||||||
|
|
||||||
val result = DoubleFieldOpsND.dynamicTimeWarping(seriesOne, seriesTwo)
|
val result = DoubleFieldOpsND.dynamicTimeWarping(seriesOne, seriesTwo)
|
||||||
println("Total penalty coefficient: ${result.totalCost}")
|
println("Total penalty coefficient: ${result.totalCost}")
|
||||||
|
Loading…
Reference in New Issue
Block a user