v0.3.0-dev-9 #324
@ -67,9 +67,10 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
override fun DoubleTensor.cholesky(): DoubleTensor {
|
||||||
//positive definite check
|
|
||||||
checkSymmetric(this)
|
checkSymmetric(this)
|
||||||
checkSquareMatrix(shape)
|
checkSquareMatrix(shape)
|
||||||
|
//TODO("Andrei the det routine has bugs")
|
||||||
|
//checkPositiveDefinite(this)
|
||||||
|
|
||||||
val n = shape.last()
|
val n = shape.last()
|
||||||
val lTensor = zeroesLike()
|
val lTensor = zeroesLike()
|
||||||
|
@ -62,3 +62,10 @@ internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, eps
|
|||||||
check(tensor.eq(tensor.transpose(), epsilon)) {
|
check(tensor.eq(tensor.transpose(), epsilon)) {
|
||||||
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
|
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit {
|
||||||
|
for( mat in tensor.matrixSequence())
|
||||||
|
check(mat.asTensor().detLU().value() > 0.0){
|
||||||
|
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
|
||||||
|
}
|
||||||
|
}
|
@ -112,6 +112,19 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
assertTrue((p dot tensor).eq(l dot u))
|
assertTrue((p dot tensor).eq(l dot u))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testCholesky() = DoubleLinearOpsTensorAlgebra {
|
||||||
|
val tensor = randNormal(intArrayOf(2, 5, 5), 0)
|
||||||
|
val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding(
|
||||||
|
fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
|
||||||
|
)
|
||||||
|
//checkPositiveDefinite(sigma) sigma must be positive definite
|
||||||
|
val low = sigma.cholesky()
|
||||||
|
val sigmChol = low dot low.transpose()
|
||||||
|
assertTrue(sigma.eq(sigmChol))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSVD1D() = DoubleLinearOpsTensorAlgebra {
|
fun testSVD1D() = DoubleLinearOpsTensorAlgebra {
|
||||||
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
Loading…
Reference in New Issue
Block a user