This commit is contained in:
Andrei Kislitsyn 2021-04-14 23:10:42 +03:00
parent 75783bcb03
commit aeb71b5d27
4 changed files with 56 additions and 14 deletions

View File

@ -14,8 +14,7 @@ public class DoubleLinearOpsTensorAlgebra :
override fun DoubleTensor.det(): DoubleTensor = detLU() override fun DoubleTensor.det(): DoubleTensor = detLU()
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> { internal fun DoubleTensor.luForDet(forDet: Boolean = false): Pair<DoubleTensor, IntTensor> {
checkSquareMatrix(shape) checkSquareMatrix(shape)
val luTensor = copy() val luTensor = copy()
@ -31,10 +30,22 @@ public class DoubleLinearOpsTensorAlgebra :
) )
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence())) for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
luHelper(lu.as2D(), pivots.as1D(), m) try {
luHelper(lu.as2D(), pivots.as1D(), m)
} catch (e: RuntimeException) {
if (forDet) {
lu.as2D()[intArrayOf(0, 0)] = 0.0
} else {
throw IllegalStateException("LUP decomposition can't be performed")
}
}
return Pair(luTensor, pivotsTensor) return Pair(luTensor, pivotsTensor)
}
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
return luForDet(false)
} }
override fun luPivot( override fun luPivot(
@ -66,11 +77,10 @@ public class DoubleLinearOpsTensorAlgebra :
} }
override fun DoubleTensor.cholesky(): DoubleTensor { public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor {
checkSymmetric(this)
checkSquareMatrix(shape) checkSquareMatrix(shape)
//TODO("Andrei the det routine has bugs") checkPositiveDefinite(this)
//checkPositiveDefinite(this) //checkPositiveDefinite(this, epsilon)
val n = shape.last() val n = shape.last()
val lTensor = zeroesLike() val lTensor = zeroesLike()
@ -81,6 +91,8 @@ public class DoubleLinearOpsTensorAlgebra :
return lTensor return lTensor
} }
override fun DoubleTensor.cholesky(): DoubleTensor = cholesky(1e-6)
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> { override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
checkSquareMatrix(shape) checkSquareMatrix(shape)
val qTensor = zeroesLike() val qTensor = zeroesLike()
@ -134,7 +146,7 @@ public class DoubleLinearOpsTensorAlgebra :
} }
public fun DoubleTensor.detLU(): DoubleTensor { public fun DoubleTensor.detLU(): DoubleTensor {
val (luTensor, pivotsTensor) = lu() val (luTensor, pivotsTensor) = luForDet(forDet = true)
val n = shape.size val n = shape.size
val detTensorShape = IntArray(n - 1) { i -> shape[i] } val detTensorShape = IntArray(n - 1) { i -> shape[i] }

View File

@ -2,6 +2,7 @@ package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.TensorAlgebra import space.kscience.kmath.tensors.TensorAlgebra
import space.kscience.kmath.tensors.TensorStructure import space.kscience.kmath.tensors.TensorStructure
import kotlin.math.abs
internal inline fun <T, TensorType : TensorStructure<T>, internal inline fun <T, TensorType : TensorStructure<T>,
@ -63,9 +64,20 @@ internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, eps
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon" "Tensor is not symmetric about the last 2 dimensions at precision $epsilon"
} }
internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor): Unit { internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(
tensor: DoubleTensor, epsilon: Double = 1e-6): Unit {
checkSymmetric(tensor, epsilon)
for( mat in tensor.matrixSequence()) for( mat in tensor.matrixSequence())
check(mat.asTensor().detLU().value() > 0.0){ check(mat.asTensor().detLU().value() > 0.0){
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}" "Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
} }
} }
internal inline fun DoubleLinearOpsTensorAlgebra.checkNonSingularMatrix(tensor: DoubleTensor): Unit {
for( mat in tensor.matrixSequence()) {
val detTensor = mat.asTensor().detLU()
check(!(detTensor.eq(detTensor.zeroesLike()))){
"Tensor contains matrices which are singular"
}
}
}

View File

@ -62,10 +62,10 @@ internal inline fun dotHelper(
} }
internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>, m: Int) { internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>, m: Int) {
for (row in 0 until m) pivots[row] = row for (row in 0..m) pivots[row] = row
for (i in 0 until m) { for (i in 0 until m) {
var maxVal = -1.0 var maxVal = 0.0
var maxInd = i var maxInd = i
for (k in i until m) { for (k in i until m) {
@ -76,7 +76,9 @@ internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStru
} }
} }
//todo check singularity if (abs(maxVal) < 1e-9) {
throw RuntimeException()
}
if (maxInd != i) { if (maxInd != i) {
@ -158,6 +160,9 @@ internal inline fun choleskyHelper(
internal inline fun luMatrixDet(luTensor: MutableStructure2D<Double>, pivotsTensor: MutableStructure1D<Int>): Double { internal inline fun luMatrixDet(luTensor: MutableStructure2D<Double>, pivotsTensor: MutableStructure1D<Int>): Double {
val lu = luTensor.as2D() val lu = luTensor.as2D()
val pivots = pivotsTensor.as1D() val pivots = pivotsTensor.as1D()
if (lu[0, 0] == 0.0) {
return 0.0
}
val m = lu.shape[0] val m = lu.shape[0]
val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0 val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right } return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }

View File

@ -36,6 +36,7 @@ class TestDoubleLinearOpsTensorAlgebra {
@Test @Test
fun testDet() = DoubleLinearOpsTensorAlgebra { fun testDet() = DoubleLinearOpsTensorAlgebra {
val expectedValue = 0.019827417
val m = fromArray( val m = fromArray(
intArrayOf(3, 3), doubleArrayOf( intArrayOf(3, 3), doubleArrayOf(
2.1843, 1.4391, -0.4845, 2.1843, 1.4391, -0.4845,
@ -43,8 +44,20 @@ class TestDoubleLinearOpsTensorAlgebra {
-0.4845, 0.4055, 0.7519 -0.4845, 0.4055, 0.7519
) )
) )
println(m.det().value())
println(0.0197) //expected value assertTrue { abs(m.det().value() - expectedValue) < 1e-5}
}
@Test
fun testDetSingle() = DoubleLinearOpsTensorAlgebra {
val expectedValue = 48.151623
val m = fromArray(
intArrayOf(1, 1), doubleArrayOf(
expectedValue
)
)
assertTrue { abs(m.det().value() - expectedValue) < 1e-5}
} }
@Test @Test