fix det
This commit is contained in:
parent
75783bcb03
commit
aeb71b5d27
@ -14,8 +14,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
|
||||
override fun DoubleTensor.det(): DoubleTensor = detLU()
|
||||
|
||||
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
||||
|
||||
internal fun DoubleTensor.luForDet(forDet: Boolean = false): Pair<DoubleTensor, IntTensor> {
|
||||
checkSquareMatrix(shape)
|
||||
|
||||
val luTensor = copy()
|
||||
@ -31,10 +30,22 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
)
|
||||
|
||||
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
|
||||
try {
|
||||
luHelper(lu.as2D(), pivots.as1D(), m)
|
||||
} catch (e: RuntimeException) {
|
||||
if (forDet) {
|
||||
lu.as2D()[intArrayOf(0, 0)] = 0.0
|
||||
} else {
|
||||
throw IllegalStateException("LUP decomposition can't be performed")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return Pair(luTensor, pivotsTensor)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
||||
return luForDet(false)
|
||||
}
|
||||
|
||||
override fun luPivot(
|
||||
@ -66,11 +77,10 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
|
||||
}
|
||||
|
||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
||||
checkSymmetric(this)
|
||||
public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor {
|
||||
checkSquareMatrix(shape)
|
||||
//TODO("Andrei the det routine has bugs")
|
||||
//checkPositiveDefinite(this)
|
||||
checkPositiveDefinite(this)
|
||||
//checkPositiveDefinite(this, epsilon)
|
||||
|
||||
val n = shape.last()
|
||||
val lTensor = zeroesLike()
|
||||
@ -81,6 +91,8 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
return lTensor
|
||||
}
|
||||
|
||||
override fun DoubleTensor.cholesky(): DoubleTensor = cholesky(1e-6)
|
||||
|
||||
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
|
||||
checkSquareMatrix(shape)
|
||||
val qTensor = zeroesLike()
|
||||
@ -134,7 +146,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
}
|
||||
|
||||
public fun DoubleTensor.detLU(): DoubleTensor {
|
||||
val (luTensor, pivotsTensor) = lu()
|
||||
val (luTensor, pivotsTensor) = luForDet(forDet = true)
|
||||
val n = shape.size
|
||||
|
||||
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
|
||||
|
@ -2,6 +2,7 @@ package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.tensors.TensorAlgebra
|
||||
import space.kscience.kmath.tensors.TensorStructure
|
||||
import kotlin.math.abs
|
||||
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
@ -63,9 +64,20 @@ internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, eps
|
||||
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
|
||||
}
|
||||
|
||||
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit {
|
||||
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(
|
||||
tensor: DoubleTensor, epsilon: Double = 1e-6): Unit {
|
||||
checkSymmetric(tensor, epsilon)
|
||||
for( mat in tensor.matrixSequence())
|
||||
check(mat.asTensor().detLU().value() > 0.0){
|
||||
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
|
||||
}
|
||||
}
|
||||
|
||||
internal inline fun DoubleLinearOpsTensorAlgebra.checkNonSingularMatrix(tensor: DoubleTensor): Unit {
|
||||
for( mat in tensor.matrixSequence()) {
|
||||
val detTensor = mat.asTensor().detLU()
|
||||
check(!(detTensor.eq(detTensor.zeroesLike()))){
|
||||
"Tensor contains matrices which are singular"
|
||||
}
|
||||
}
|
||||
}
|
@ -62,10 +62,10 @@ internal inline fun dotHelper(
|
||||
}
|
||||
|
||||
internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>, m: Int) {
|
||||
for (row in 0 until m) pivots[row] = row
|
||||
for (row in 0..m) pivots[row] = row
|
||||
|
||||
for (i in 0 until m) {
|
||||
var maxVal = -1.0
|
||||
var maxVal = 0.0
|
||||
var maxInd = i
|
||||
|
||||
for (k in i until m) {
|
||||
@ -76,7 +76,9 @@ internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStru
|
||||
}
|
||||
}
|
||||
|
||||
//todo check singularity
|
||||
if (abs(maxVal) < 1e-9) {
|
||||
throw RuntimeException()
|
||||
}
|
||||
|
||||
if (maxInd != i) {
|
||||
|
||||
@ -158,6 +160,9 @@ internal inline fun choleskyHelper(
|
||||
internal inline fun luMatrixDet(luTensor: MutableStructure2D<Double>, pivotsTensor: MutableStructure1D<Int>): Double {
|
||||
val lu = luTensor.as2D()
|
||||
val pivots = pivotsTensor.as1D()
|
||||
if (lu[0, 0] == 0.0) {
|
||||
return 0.0
|
||||
}
|
||||
val m = lu.shape[0]
|
||||
val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
||||
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
|
||||
|
@ -36,6 +36,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
|
||||
@Test
|
||||
fun testDet() = DoubleLinearOpsTensorAlgebra {
|
||||
val expectedValue = 0.019827417
|
||||
val m = fromArray(
|
||||
intArrayOf(3, 3), doubleArrayOf(
|
||||
2.1843, 1.4391, -0.4845,
|
||||
@ -43,8 +44,20 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
-0.4845, 0.4055, 0.7519
|
||||
)
|
||||
)
|
||||
println(m.det().value())
|
||||
println(0.0197) //expected value
|
||||
|
||||
assertTrue { abs(m.det().value() - expectedValue) < 1e-5}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDetSingle() = DoubleLinearOpsTensorAlgebra {
|
||||
val expectedValue = 48.151623
|
||||
val m = fromArray(
|
||||
intArrayOf(1, 1), doubleArrayOf(
|
||||
expectedValue
|
||||
)
|
||||
)
|
||||
|
||||
assertTrue { abs(m.det().value() - expectedValue) < 1e-5}
|
||||
}
|
||||
|
||||
@Test
|
||||
|
Loading…
Reference in New Issue
Block a user