Add proper test for symmetric matrices eigenValueDecomposition
This commit is contained in:
parent
b818a8981f
commit
1619a49017
@ -0,0 +1,40 @@
|
||||
/*
|
||||
* Copyright 2018-2024 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.linear
|
||||
|
||||
import space.kscience.kmath.commons.linear.CMLinearSpace
|
||||
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.operations.algebra
|
||||
import space.kscience.kmath.structures.Float64
|
||||
import kotlin.random.Random
|
||||
|
||||
fun main() {
|
||||
val dim = 46
|
||||
|
||||
val random = Random(123)
|
||||
|
||||
val u = Float64.algebra.linearSpace.buildMatrix(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||
|
||||
listOf(CMLinearSpace, EjmlLinearSpaceDDRM).forEach { algebra ->
|
||||
with(algebra) {
|
||||
//create a simmetric matrix
|
||||
val matrix = buildMatrix(dim, dim) { row, col ->
|
||||
if (row >= col) u[row, col] else u[col, row]
|
||||
}
|
||||
val eigen = matrix.getOrComputeAttribute(EIG) ?: error("Failed to compute eigenvalue decomposition")
|
||||
check(
|
||||
StructureND.contentEquals(
|
||||
matrix,
|
||||
eigen.v dot eigen.d dot eigen.v.transposed(),
|
||||
1e-4
|
||||
)
|
||||
) { "$algebra decomposition failed" }
|
||||
println("$algebra eigenvalue decomposition complete and checked" )
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -7,6 +7,7 @@ package space.kscience.kmath.linear
|
||||
|
||||
import space.kscience.attributes.PolymorphicAttribute
|
||||
import space.kscience.attributes.safeTypeOf
|
||||
import space.kscience.kmath.UnstableKMathAPI
|
||||
|
||||
public interface EigenDecomposition<T> {
|
||||
/**
|
||||
@ -24,5 +25,6 @@ public class EigenDecompositionAttribute<T> :
|
||||
PolymorphicAttribute<EigenDecomposition<T>>(safeTypeOf()),
|
||||
MatrixAttribute<EigenDecomposition<T>>
|
||||
|
||||
@UnstableKMathAPI
|
||||
public val <T> MatrixScope<T>.EIG: EigenDecompositionAttribute<T>
|
||||
get() = EigenDecompositionAttribute()
|
||||
|
@ -1,15 +1,22 @@
|
||||
plugins {
|
||||
id("space.kscience.gradle.jvm")
|
||||
id("space.kscience.gradle.mpp")
|
||||
}
|
||||
|
||||
val ejmlVerision = "0.43.1"
|
||||
|
||||
dependencies {
|
||||
kscience {
|
||||
jvm()
|
||||
jvmMain {
|
||||
api(projects.kmathCore)
|
||||
api(projects.kmathComplex)
|
||||
api("org.ejml:ejml-all:$ejmlVerision")
|
||||
}
|
||||
|
||||
jvmTest {
|
||||
implementation(projects.testUtils)
|
||||
}
|
||||
}
|
||||
|
||||
readme {
|
||||
maturity = space.kscience.gradle.Maturity.PROTOTYPE
|
||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||
@ -29,11 +36,3 @@ readme {
|
||||
ref = "src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt"
|
||||
) { "LinearSpace implementations." }
|
||||
}
|
||||
|
||||
//kotlin.sourceSets.main {
|
||||
// val codegen by tasks.creating {
|
||||
// ejmlCodegen(kotlin.srcDirs.first().absolutePath + "/space/kscience/kmath/ejml/_generated.kt")
|
||||
// }
|
||||
//
|
||||
// kotlin.srcDirs(files().builtBy(codegen))
|
||||
//}
|
||||
|
@ -278,8 +278,8 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, Float64Field, DMatri
|
||||
}
|
||||
|
||||
override val v: Matrix<Float64> by lazy {
|
||||
val eigenvectors = List(origin.numRows) { cmEigen.getEigenVector(it) }
|
||||
buildMatrix(origin.numRows, origin.numCols) { row, column ->
|
||||
val eigenvectors = List(origin.numRows) { cmEigen.getEigenVector(it) }.filterNotNull()
|
||||
buildMatrix(eigenvectors.size, origin.numCols) { row, column ->
|
||||
eigenvectors[row][column]
|
||||
}
|
||||
}
|
@ -18,12 +18,15 @@ import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.toArray
|
||||
import space.kscience.kmath.operations.algebra
|
||||
import space.kscience.kmath.structures.Float64
|
||||
import space.kscience.kmath.testutils.assertStructureEquals
|
||||
import kotlin.random.Random
|
||||
import kotlin.random.asJavaRandom
|
||||
import kotlin.test.*
|
||||
|
||||
internal fun <T : Any> assertMatrixEquals(expected: StructureND<T>, actual: StructureND<T>) {
|
||||
assertTrue { StructureND.contentEquals(expected, actual) }
|
||||
expected.elements().forEach { (index, value) ->
|
||||
assertEquals(value, actual[index], "Structure element with index ${index.toList()} should be equal to $value but is ${actual[index]}")
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
@ -108,8 +111,12 @@ internal class EjmlMatrixTest {
|
||||
|
||||
@Test
|
||||
fun eigenValueDecomposition() = EjmlLinearSpaceDDRM {
|
||||
val matrix = EjmlDoubleMatrix(randomMatrix)
|
||||
val dim = 46
|
||||
val u = buildMatrix(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||
val matrix = buildMatrix(dim, dim) { row, col ->
|
||||
if (row >= col) u[row, col] else u[col, row]
|
||||
}
|
||||
val eigen = matrix.getOrComputeAttribute(EIG) ?: fail()
|
||||
assertMatrixEquals(matrix, eigen.v dot eigen.d dot eigen.v.transposed())
|
||||
assertStructureEquals(matrix, eigen.v dot eigen.d dot eigen.v.transposed())
|
||||
}
|
||||
}
|
@ -5,6 +5,8 @@
|
||||
|
||||
package space.kscience.kmath.testutils
|
||||
|
||||
import space.kscience.kmath.PerformancePitfall
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.Float64
|
||||
import space.kscience.kmath.structures.indices
|
||||
@ -19,3 +21,11 @@ public fun assertBufferEquals(expected: Buffer<Float64>, result: Buffer<Float64>
|
||||
assertEquals(expected[it], result[it], tolerance)
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
public fun assertStructureEquals(expected: StructureND<Float64>, result: StructureND<Float64>, tolerance: Double = 1e-4) {
|
||||
assertEquals(expected.shape, result.shape, "Structure shape mismatch")
|
||||
expected.indices.forEach {
|
||||
assertEquals(expected[it], result[it], tolerance)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user