forked from kscience/kmath
Add proper test for symmetric matrices eigenValueDecomposition
This commit is contained in:
parent
b818a8981f
commit
1619a49017
@ -0,0 +1,40 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2024 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.linear
|
||||||
|
|
||||||
|
import space.kscience.kmath.commons.linear.CMLinearSpace
|
||||||
|
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
|
||||||
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.operations.algebra
|
||||||
|
import space.kscience.kmath.structures.Float64
|
||||||
|
import kotlin.random.Random
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
val dim = 46
|
||||||
|
|
||||||
|
val random = Random(123)
|
||||||
|
|
||||||
|
val u = Float64.algebra.linearSpace.buildMatrix(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||||
|
|
||||||
|
listOf(CMLinearSpace, EjmlLinearSpaceDDRM).forEach { algebra ->
|
||||||
|
with(algebra) {
|
||||||
|
//create a simmetric matrix
|
||||||
|
val matrix = buildMatrix(dim, dim) { row, col ->
|
||||||
|
if (row >= col) u[row, col] else u[col, row]
|
||||||
|
}
|
||||||
|
val eigen = matrix.getOrComputeAttribute(EIG) ?: error("Failed to compute eigenvalue decomposition")
|
||||||
|
check(
|
||||||
|
StructureND.contentEquals(
|
||||||
|
matrix,
|
||||||
|
eigen.v dot eigen.d dot eigen.v.transposed(),
|
||||||
|
1e-4
|
||||||
|
)
|
||||||
|
) { "$algebra decomposition failed" }
|
||||||
|
println("$algebra eigenvalue decomposition complete and checked" )
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -7,6 +7,7 @@ package space.kscience.kmath.linear
|
|||||||
|
|
||||||
import space.kscience.attributes.PolymorphicAttribute
|
import space.kscience.attributes.PolymorphicAttribute
|
||||||
import space.kscience.attributes.safeTypeOf
|
import space.kscience.attributes.safeTypeOf
|
||||||
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
|
|
||||||
public interface EigenDecomposition<T> {
|
public interface EigenDecomposition<T> {
|
||||||
/**
|
/**
|
||||||
@ -24,5 +25,6 @@ public class EigenDecompositionAttribute<T> :
|
|||||||
PolymorphicAttribute<EigenDecomposition<T>>(safeTypeOf()),
|
PolymorphicAttribute<EigenDecomposition<T>>(safeTypeOf()),
|
||||||
MatrixAttribute<EigenDecomposition<T>>
|
MatrixAttribute<EigenDecomposition<T>>
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
public val <T> MatrixScope<T>.EIG: EigenDecompositionAttribute<T>
|
public val <T> MatrixScope<T>.EIG: EigenDecompositionAttribute<T>
|
||||||
get() = EigenDecompositionAttribute()
|
get() = EigenDecompositionAttribute()
|
||||||
|
@ -1,15 +1,22 @@
|
|||||||
plugins {
|
plugins {
|
||||||
id("space.kscience.gradle.jvm")
|
id("space.kscience.gradle.mpp")
|
||||||
}
|
}
|
||||||
|
|
||||||
val ejmlVerision = "0.43.1"
|
val ejmlVerision = "0.43.1"
|
||||||
|
|
||||||
dependencies {
|
kscience {
|
||||||
|
jvm()
|
||||||
|
jvmMain {
|
||||||
api(projects.kmathCore)
|
api(projects.kmathCore)
|
||||||
api(projects.kmathComplex)
|
api(projects.kmathComplex)
|
||||||
api("org.ejml:ejml-all:$ejmlVerision")
|
api("org.ejml:ejml-all:$ejmlVerision")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
jvmTest {
|
||||||
|
implementation(projects.testUtils)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
readme {
|
readme {
|
||||||
maturity = space.kscience.gradle.Maturity.PROTOTYPE
|
maturity = space.kscience.gradle.Maturity.PROTOTYPE
|
||||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||||
@ -29,11 +36,3 @@ readme {
|
|||||||
ref = "src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt"
|
ref = "src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt"
|
||||||
) { "LinearSpace implementations." }
|
) { "LinearSpace implementations." }
|
||||||
}
|
}
|
||||||
|
|
||||||
//kotlin.sourceSets.main {
|
|
||||||
// val codegen by tasks.creating {
|
|
||||||
// ejmlCodegen(kotlin.srcDirs.first().absolutePath + "/space/kscience/kmath/ejml/_generated.kt")
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// kotlin.srcDirs(files().builtBy(codegen))
|
|
||||||
//}
|
|
||||||
|
@ -278,8 +278,8 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, Float64Field, DMatri
|
|||||||
}
|
}
|
||||||
|
|
||||||
override val v: Matrix<Float64> by lazy {
|
override val v: Matrix<Float64> by lazy {
|
||||||
val eigenvectors = List(origin.numRows) { cmEigen.getEigenVector(it) }
|
val eigenvectors = List(origin.numRows) { cmEigen.getEigenVector(it) }.filterNotNull()
|
||||||
buildMatrix(origin.numRows, origin.numCols) { row, column ->
|
buildMatrix(eigenvectors.size, origin.numCols) { row, column ->
|
||||||
eigenvectors[row][column]
|
eigenvectors[row][column]
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -18,12 +18,15 @@ import space.kscience.kmath.nd.StructureND
|
|||||||
import space.kscience.kmath.nd.toArray
|
import space.kscience.kmath.nd.toArray
|
||||||
import space.kscience.kmath.operations.algebra
|
import space.kscience.kmath.operations.algebra
|
||||||
import space.kscience.kmath.structures.Float64
|
import space.kscience.kmath.structures.Float64
|
||||||
|
import space.kscience.kmath.testutils.assertStructureEquals
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.random.asJavaRandom
|
import kotlin.random.asJavaRandom
|
||||||
import kotlin.test.*
|
import kotlin.test.*
|
||||||
|
|
||||||
internal fun <T : Any> assertMatrixEquals(expected: StructureND<T>, actual: StructureND<T>) {
|
internal fun <T : Any> assertMatrixEquals(expected: StructureND<T>, actual: StructureND<T>) {
|
||||||
assertTrue { StructureND.contentEquals(expected, actual) }
|
expected.elements().forEach { (index, value) ->
|
||||||
|
assertEquals(value, actual[index], "Structure element with index ${index.toList()} should be equal to $value but is ${actual[index]}")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
@ -108,8 +111,12 @@ internal class EjmlMatrixTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun eigenValueDecomposition() = EjmlLinearSpaceDDRM {
|
fun eigenValueDecomposition() = EjmlLinearSpaceDDRM {
|
||||||
val matrix = EjmlDoubleMatrix(randomMatrix)
|
val dim = 46
|
||||||
|
val u = buildMatrix(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||||
|
val matrix = buildMatrix(dim, dim) { row, col ->
|
||||||
|
if (row >= col) u[row, col] else u[col, row]
|
||||||
|
}
|
||||||
val eigen = matrix.getOrComputeAttribute(EIG) ?: fail()
|
val eigen = matrix.getOrComputeAttribute(EIG) ?: fail()
|
||||||
assertMatrixEquals(matrix, eigen.v dot eigen.d dot eigen.v.transposed())
|
assertStructureEquals(matrix, eigen.v dot eigen.d dot eigen.v.transposed())
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.testutils
|
package space.kscience.kmath.testutils
|
||||||
|
|
||||||
|
import space.kscience.kmath.PerformancePitfall
|
||||||
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.Float64
|
import space.kscience.kmath.structures.Float64
|
||||||
import space.kscience.kmath.structures.indices
|
import space.kscience.kmath.structures.indices
|
||||||
@ -19,3 +21,11 @@ public fun assertBufferEquals(expected: Buffer<Float64>, result: Buffer<Float64>
|
|||||||
assertEquals(expected[it], result[it], tolerance)
|
assertEquals(expected[it], result[it], tolerance)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(PerformancePitfall::class)
|
||||||
|
public fun assertStructureEquals(expected: StructureND<Float64>, result: StructureND<Float64>, tolerance: Double = 1e-4) {
|
||||||
|
assertEquals(expected.shape, result.shape, "Structure shape mismatch")
|
||||||
|
expected.indices.forEach {
|
||||||
|
assertEquals(expected[it], result[it], tolerance)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user