This commit is contained in:
Andrei Kislitsyn 2021-04-14 23:10:42 +03:00
parent 75783bcb03
commit aeb71b5d27
4 changed files with 56 additions and 14 deletions

View File

@ -14,8 +14,7 @@ public class DoubleLinearOpsTensorAlgebra :
override fun DoubleTensor.det(): DoubleTensor = detLU()
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
internal fun DoubleTensor.luForDet(forDet: Boolean = false): Pair<DoubleTensor, IntTensor> {
checkSquareMatrix(shape)
val luTensor = copy()
@ -31,10 +30,22 @@ public class DoubleLinearOpsTensorAlgebra :
)
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
try {
luHelper(lu.as2D(), pivots.as1D(), m)
} catch (e: RuntimeException) {
if (forDet) {
lu.as2D()[intArrayOf(0, 0)] = 0.0
} else {
throw IllegalStateException("LUP decomposition can't be performed")
}
}
return Pair(luTensor, pivotsTensor)
}
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
return luForDet(false)
}
override fun luPivot(
@ -66,11 +77,10 @@ public class DoubleLinearOpsTensorAlgebra :
}
override fun DoubleTensor.cholesky(): DoubleTensor {
checkSymmetric(this)
public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor {
checkSquareMatrix(shape)
//TODO("Andrei the det routine has bugs")
//checkPositiveDefinite(this)
checkPositiveDefinite(this)
//checkPositiveDefinite(this, epsilon)
val n = shape.last()
val lTensor = zeroesLike()
@ -81,6 +91,8 @@ public class DoubleLinearOpsTensorAlgebra :
return lTensor
}
override fun DoubleTensor.cholesky(): DoubleTensor = cholesky(1e-6)
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
checkSquareMatrix(shape)
val qTensor = zeroesLike()
@ -134,7 +146,7 @@ public class DoubleLinearOpsTensorAlgebra :
}
public fun DoubleTensor.detLU(): DoubleTensor {
val (luTensor, pivotsTensor) = lu()
val (luTensor, pivotsTensor) = luForDet(forDet = true)
val n = shape.size
val detTensorShape = IntArray(n - 1) { i -> shape[i] }

View File

@ -2,6 +2,7 @@ package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.TensorAlgebra
import space.kscience.kmath.tensors.TensorStructure
import kotlin.math.abs
internal inline fun <T, TensorType : TensorStructure<T>,
@ -63,9 +64,20 @@ internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, eps
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
}
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit {
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(
tensor: DoubleTensor, epsilon: Double = 1e-6): Unit {
checkSymmetric(tensor, epsilon)
for( mat in tensor.matrixSequence())
check(mat.asTensor().detLU().value() > 0.0){
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
}
}
internal inline fun DoubleLinearOpsTensorAlgebra.checkNonSingularMatrix(tensor: DoubleTensor): Unit {
for( mat in tensor.matrixSequence()) {
val detTensor = mat.asTensor().detLU()
check(!(detTensor.eq(detTensor.zeroesLike()))){
"Tensor contains matrices which are singular"
}
}
}

View File

@ -62,10 +62,10 @@ internal inline fun dotHelper(
}
internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>, m: Int) {
for (row in 0 until m) pivots[row] = row
for (row in 0..m) pivots[row] = row
for (i in 0 until m) {
var maxVal = -1.0
var maxVal = 0.0
var maxInd = i
for (k in i until m) {
@ -76,7 +76,9 @@ internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStru
}
}
//todo check singularity
if (abs(maxVal) < 1e-9) {
throw RuntimeException()
}
if (maxInd != i) {
@ -158,6 +160,9 @@ internal inline fun choleskyHelper(
internal inline fun luMatrixDet(luTensor: MutableStructure2D<Double>, pivotsTensor: MutableStructure1D<Int>): Double {
val lu = luTensor.as2D()
val pivots = pivotsTensor.as1D()
if (lu[0, 0] == 0.0) {
return 0.0
}
val m = lu.shape[0]
val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }

View File

@ -36,6 +36,7 @@ class TestDoubleLinearOpsTensorAlgebra {
@Test
fun testDet() = DoubleLinearOpsTensorAlgebra {
val expectedValue = 0.019827417
val m = fromArray(
intArrayOf(3, 3), doubleArrayOf(
2.1843, 1.4391, -0.4845,
@ -43,8 +44,20 @@ class TestDoubleLinearOpsTensorAlgebra {
-0.4845, 0.4055, 0.7519
)
)
println(m.det().value())
println(0.0197) //expected value
assertTrue { abs(m.det().value() - expectedValue) < 1e-5}
}
@Test
fun testDetSingle() = DoubleLinearOpsTensorAlgebra {
val expectedValue = 48.151623
val m = fromArray(
intArrayOf(1, 1), doubleArrayOf(
expectedValue
)
)
assertTrue { abs(m.det().value() - expectedValue) < 1e-5}
}
@Test