LU and det refactored

This commit is contained in:
Roland Grinis 2021-04-14 22:13:54 +01:00
parent 2092cc9af4
commit b46e8c5fe2
3 changed files with 59 additions and 56 deletions

View File

@ -10,43 +10,15 @@ public class DoubleLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>, LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
DoubleTensorAlgebra() { DoubleTensorAlgebra() {
override fun DoubleTensor.inv(): DoubleTensor = invLU() override fun DoubleTensor.inv(): DoubleTensor = invLU(1e-9)
override fun DoubleTensor.det(): DoubleTensor = detLU() override fun DoubleTensor.det(): DoubleTensor = detLU(1e-9)
internal fun DoubleTensor.luForDet(forDet: Boolean = false): Pair<DoubleTensor, IntTensor> { public fun DoubleTensor.lu(epsilon: Double): Pair<DoubleTensor, IntTensor> =
checkSquareMatrix(shape) computeLU(this, epsilon) ?:
throw RuntimeException("Tensor contains matrices which are singular at precision $epsilon")
val luTensor = copy() override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> = lu(1e-9)
val n = shape.size
val m = shape.last()
val pivotsShape = IntArray(n - 1) { i -> shape[i] }
pivotsShape[n - 2] = m + 1
val pivotsTensor = IntTensor(
pivotsShape,
IntArray(pivotsShape.reduce(Int::times)) { 0 }
)
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
try {
luHelper(lu.as2D(), pivots.as1D(), m)
} catch (e: RuntimeException) {
if (forDet) {
lu.as2D()[intArrayOf(0, 0)] = 0.0
} else {
throw IllegalStateException("LUP decomposition can't be performed")
}
}
return Pair(luTensor, pivotsTensor)
}
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
return luForDet(false)
}
override fun luPivot( override fun luPivot(
luTensor: DoubleTensor, luTensor: DoubleTensor,
@ -79,9 +51,7 @@ public class DoubleLinearOpsTensorAlgebra :
public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor { public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor {
checkSquareMatrix(shape) checkSquareMatrix(shape)
checkPositiveDefinite(this) checkPositiveDefinite(this, epsilon)
//checkPositiveDefinite(this, epsilon)
val n = shape.last() val n = shape.last()
val lTensor = zeroesLike() val lTensor = zeroesLike()
@ -146,8 +116,12 @@ public class DoubleLinearOpsTensorAlgebra :
return Pair(eig, v) return Pair(eig, v)
} }
public fun DoubleTensor.detLU(): DoubleTensor { public fun DoubleTensor.detLU(epsilon: Double = 1e-9): DoubleTensor {
val (luTensor, pivotsTensor) = luForDet(forDet = true)
checkSquareMatrix(this.shape)
val luTensor = this.copy()
val pivotsTensor = this.setUpPivots()
val n = shape.size val n = shape.size
val detTensorShape = IntArray(n - 1) { i -> shape[i] } val detTensorShape = IntArray(n - 1) { i -> shape[i] }
@ -160,15 +134,15 @@ public class DoubleLinearOpsTensorAlgebra :
) )
luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).forEachIndexed { index, (lu, pivots) -> luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).forEachIndexed { index, (lu, pivots) ->
resBuffer[index] = luMatrixDet(lu.as2D(), pivots.as1D()) resBuffer[index] = if (luHelper(lu.as2D(), pivots.as1D(), epsilon))
0.0 else luMatrixDet(lu.as2D(), pivots.as1D())
} }
return detTensor return detTensor
} }
public fun DoubleTensor.invLU(): DoubleTensor { public fun DoubleTensor.invLU(epsilon: Double = 1e-9): DoubleTensor {
//TODO("Andrei the det is non-zero") val (luTensor, pivotsTensor) = lu(epsilon)
val (luTensor, pivotsTensor) = lu()
val invTensor = luTensor.zeroesLike() val invTensor = luTensor.zeroesLike()
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence()) val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())

View File

@ -61,7 +61,13 @@ internal inline fun dotHelper(
} }
} }
internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>, m: Int) { internal inline fun luHelper(
lu: MutableStructure2D<Double>,
pivots: MutableStructure1D<Int>,
epsilon: Double): Boolean {
val m = lu.rowNum
for (row in 0..m) pivots[row] = row for (row in 0..m) pivots[row] = row
for (i in 0 until m) { for (i in 0 until m) {
@ -69,16 +75,15 @@ internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStru
var maxInd = i var maxInd = i
for (k in i until m) { for (k in i until m) {
val absA = kotlin.math.abs(lu[k, i]) val absA = abs(lu[k, i])
if (absA > maxVal) { if (absA > maxVal) {
maxVal = absA maxVal = absA
maxInd = k maxInd = k
} }
} }
if (abs(maxVal) < 1e-9) { if (abs(maxVal) < epsilon)
throw RuntimeException() return true // matrix is singular
}
if (maxInd != i) { if (maxInd != i) {
@ -103,6 +108,34 @@ internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStru
} }
} }
} }
return false
}
internal inline fun <T> BufferedTensor<T>.setUpPivots(): IntTensor {
val n = this.shape.size
val m = this.shape.last()
val pivotsShape = IntArray(n - 1) { i -> this.shape[i] }
pivotsShape[n - 2] = m + 1
return IntTensor(
pivotsShape,
IntArray(pivotsShape.reduce(Int::times)) { 0 }
)
}
internal inline fun DoubleLinearOpsTensorAlgebra.computeLU(
tensor: DoubleTensor,
epsilon: Double): Pair<DoubleTensor, IntTensor>? {
checkSquareMatrix(tensor.shape)
val luTensor = tensor.copy()
val pivotsTensor = tensor.setUpPivots()
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
if(luHelper(lu.as2D(), pivots.as1D(), epsilon))
return null
return Pair(luTensor, pivotsTensor)
} }
internal inline fun pivInit( internal inline fun pivInit(

View File

@ -1,8 +1,6 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.toList
import kotlin.math.abs import kotlin.math.abs
import kotlin.test.Ignore
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
@ -144,11 +142,9 @@ class TestDoubleLinearOpsTensorAlgebra {
val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding( val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding(
fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 }) fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
) )
//checkPositiveDefinite(sigma) sigma must be positive definite
val low = sigma.cholesky() val low = sigma.cholesky()
val sigmChol = low dot low.transpose() val sigmChol = low dot low.transpose()
assertTrue(sigma.eq(sigmChol)) assertTrue(sigma.eq(sigmChol))
} }
@Test @Test