KMP library for tensors #300

Merged
grinisrit merged 215 commits from feature/tensor-algebra into dev 2021-05-08 09:48:04 +03:00
3 changed files with 23 additions and 2 deletions
Showing only changes of commit 1e8da7a87b - Show all commits

View File

@ -67,9 +67,10 @@ public class DoubleLinearOpsTensorAlgebra :
} }
override fun DoubleTensor.cholesky(): DoubleTensor { override fun DoubleTensor.cholesky(): DoubleTensor {
//positive definite check
checkSymmetric(this) checkSymmetric(this)
checkSquareMatrix(shape) checkSquareMatrix(shape)
//TODO("Andrei the det routine has bugs")
//checkPositiveDefinite(this)
val n = shape.last() val n = shape.last()
val lTensor = zeroesLike() val lTensor = zeroesLike()

View File

@ -61,4 +61,11 @@ internal inline fun <T, TensorType : TensorStructure<T>,
internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, epsilon: Double = 1e-6): Unit = internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, epsilon: Double = 1e-6): Unit =
check(tensor.eq(tensor.transpose(), epsilon)) { check(tensor.eq(tensor.transpose(), epsilon)) {
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon" "Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
} }
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit {
for( mat in tensor.matrixSequence())
check(mat.asTensor().detLU().value() > 0.0){
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
}
}

View File

@ -112,6 +112,19 @@ class TestDoubleLinearOpsTensorAlgebra {
assertTrue((p dot tensor).eq(l dot u)) assertTrue((p dot tensor).eq(l dot u))
} }
@Test
fun testCholesky() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(intArrayOf(2, 5, 5), 0)
val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding(
fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
)
//checkPositiveDefinite(sigma) sigma must be positive definite
val low = sigma.cholesky()
val sigmChol = low dot low.transpose()
assertTrue(sigma.eq(sigmChol))
}
@Test @Test
fun testSVD1D() = DoubleLinearOpsTensorAlgebra { fun testSVD1D() = DoubleLinearOpsTensorAlgebra {
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))