KMP library for tensors #300
@ -67,9 +67,10 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
}
|
||||
|
||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
||||
//positive definite check
|
||||
checkSymmetric(this)
|
||||
checkSquareMatrix(shape)
|
||||
//TODO("Andrei the det routine has bugs")
|
||||
//checkPositiveDefinite(this)
|
||||
|
||||
val n = shape.last()
|
||||
val lTensor = zeroesLike()
|
||||
|
@ -62,3 +62,10 @@ internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, eps
|
||||
check(tensor.eq(tensor.transpose(), epsilon)) {
|
||||
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
|
||||
}
|
||||
|
||||
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit {
|
||||
for( mat in tensor.matrixSequence())
|
||||
check(mat.asTensor().detLU().value() > 0.0){
|
||||
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
|
||||
}
|
||||
}
|
@ -112,6 +112,19 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
assertTrue((p dot tensor).eq(l dot u))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCholesky() = DoubleLinearOpsTensorAlgebra {
|
||||
val tensor = randNormal(intArrayOf(2, 5, 5), 0)
|
||||
val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding(
|
||||
fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
|
||||
)
|
||||
//checkPositiveDefinite(sigma) sigma must be positive definite
|
||||
val low = sigma.cholesky()
|
||||
val sigmChol = low dot low.transpose()
|
||||
assertTrue(sigma.eq(sigmChol))
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSVD1D() = DoubleLinearOpsTensorAlgebra {
|
||||
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
|
Loading…
Reference in New Issue
Block a user