forked from kscience/kmath
LU and det refactored
This commit is contained in:
parent
2092cc9af4
commit
b46e8c5fe2
@ -10,43 +10,15 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
||||||
DoubleTensorAlgebra() {
|
DoubleTensorAlgebra() {
|
||||||
|
|
||||||
override fun DoubleTensor.inv(): DoubleTensor = invLU()
|
override fun DoubleTensor.inv(): DoubleTensor = invLU(1e-9)
|
||||||
|
|
||||||
override fun DoubleTensor.det(): DoubleTensor = detLU()
|
override fun DoubleTensor.det(): DoubleTensor = detLU(1e-9)
|
||||||
|
|
||||||
internal fun DoubleTensor.luForDet(forDet: Boolean = false): Pair<DoubleTensor, IntTensor> {
|
public fun DoubleTensor.lu(epsilon: Double): Pair<DoubleTensor, IntTensor> =
|
||||||
checkSquareMatrix(shape)
|
computeLU(this, epsilon) ?:
|
||||||
|
throw RuntimeException("Tensor contains matrices which are singular at precision $epsilon")
|
||||||
|
|
||||||
val luTensor = copy()
|
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> = lu(1e-9)
|
||||||
|
|
||||||
val n = shape.size
|
|
||||||
val m = shape.last()
|
|
||||||
val pivotsShape = IntArray(n - 1) { i -> shape[i] }
|
|
||||||
pivotsShape[n - 2] = m + 1
|
|
||||||
|
|
||||||
val pivotsTensor = IntTensor(
|
|
||||||
pivotsShape,
|
|
||||||
IntArray(pivotsShape.reduce(Int::times)) { 0 }
|
|
||||||
)
|
|
||||||
|
|
||||||
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
|
|
||||||
try {
|
|
||||||
luHelper(lu.as2D(), pivots.as1D(), m)
|
|
||||||
} catch (e: RuntimeException) {
|
|
||||||
if (forDet) {
|
|
||||||
lu.as2D()[intArrayOf(0, 0)] = 0.0
|
|
||||||
} else {
|
|
||||||
throw IllegalStateException("LUP decomposition can't be performed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
return Pair(luTensor, pivotsTensor)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
|
||||||
return luForDet(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun luPivot(
|
override fun luPivot(
|
||||||
luTensor: DoubleTensor,
|
luTensor: DoubleTensor,
|
||||||
@ -79,9 +51,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
|
|
||||||
public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor {
|
public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor {
|
||||||
checkSquareMatrix(shape)
|
checkSquareMatrix(shape)
|
||||||
checkPositiveDefinite(this)
|
checkPositiveDefinite(this, epsilon)
|
||||||
//checkPositiveDefinite(this, epsilon)
|
|
||||||
|
|
||||||
|
|
||||||
val n = shape.last()
|
val n = shape.last()
|
||||||
val lTensor = zeroesLike()
|
val lTensor = zeroesLike()
|
||||||
@ -146,8 +116,12 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
return Pair(eig, v)
|
return Pair(eig, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DoubleTensor.detLU(): DoubleTensor {
|
public fun DoubleTensor.detLU(epsilon: Double = 1e-9): DoubleTensor {
|
||||||
val (luTensor, pivotsTensor) = luForDet(forDet = true)
|
|
||||||
|
checkSquareMatrix(this.shape)
|
||||||
|
val luTensor = this.copy()
|
||||||
|
val pivotsTensor = this.setUpPivots()
|
||||||
|
|
||||||
val n = shape.size
|
val n = shape.size
|
||||||
|
|
||||||
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
|
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
|
||||||
@ -160,15 +134,15 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
)
|
)
|
||||||
|
|
||||||
luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).forEachIndexed { index, (lu, pivots) ->
|
luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).forEachIndexed { index, (lu, pivots) ->
|
||||||
resBuffer[index] = luMatrixDet(lu.as2D(), pivots.as1D())
|
resBuffer[index] = if (luHelper(lu.as2D(), pivots.as1D(), epsilon))
|
||||||
|
0.0 else luMatrixDet(lu.as2D(), pivots.as1D())
|
||||||
}
|
}
|
||||||
|
|
||||||
return detTensor
|
return detTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DoubleTensor.invLU(): DoubleTensor {
|
public fun DoubleTensor.invLU(epsilon: Double = 1e-9): DoubleTensor {
|
||||||
//TODO("Andrei the det is non-zero")
|
val (luTensor, pivotsTensor) = lu(epsilon)
|
||||||
val (luTensor, pivotsTensor) = lu()
|
|
||||||
val invTensor = luTensor.zeroesLike()
|
val invTensor = luTensor.zeroesLike()
|
||||||
|
|
||||||
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
|
val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence())
|
||||||
|
@ -61,7 +61,13 @@ internal inline fun dotHelper(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>, m: Int) {
|
internal inline fun luHelper(
|
||||||
|
lu: MutableStructure2D<Double>,
|
||||||
|
pivots: MutableStructure1D<Int>,
|
||||||
|
epsilon: Double): Boolean {
|
||||||
|
|
||||||
|
val m = lu.rowNum
|
||||||
|
|
||||||
for (row in 0..m) pivots[row] = row
|
for (row in 0..m) pivots[row] = row
|
||||||
|
|
||||||
for (i in 0 until m) {
|
for (i in 0 until m) {
|
||||||
@ -69,16 +75,15 @@ internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStru
|
|||||||
var maxInd = i
|
var maxInd = i
|
||||||
|
|
||||||
for (k in i until m) {
|
for (k in i until m) {
|
||||||
val absA = kotlin.math.abs(lu[k, i])
|
val absA = abs(lu[k, i])
|
||||||
if (absA > maxVal) {
|
if (absA > maxVal) {
|
||||||
maxVal = absA
|
maxVal = absA
|
||||||
maxInd = k
|
maxInd = k
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (abs(maxVal) < 1e-9) {
|
if (abs(maxVal) < epsilon)
|
||||||
throw RuntimeException()
|
return true // matrix is singular
|
||||||
}
|
|
||||||
|
|
||||||
if (maxInd != i) {
|
if (maxInd != i) {
|
||||||
|
|
||||||
@ -103,6 +108,34 @@ internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStru
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <T> BufferedTensor<T>.setUpPivots(): IntTensor {
|
||||||
|
val n = this.shape.size
|
||||||
|
val m = this.shape.last()
|
||||||
|
val pivotsShape = IntArray(n - 1) { i -> this.shape[i] }
|
||||||
|
pivotsShape[n - 2] = m + 1
|
||||||
|
|
||||||
|
return IntTensor(
|
||||||
|
pivotsShape,
|
||||||
|
IntArray(pivotsShape.reduce(Int::times)) { 0 }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun DoubleLinearOpsTensorAlgebra.computeLU(
|
||||||
|
tensor: DoubleTensor,
|
||||||
|
epsilon: Double): Pair<DoubleTensor, IntTensor>? {
|
||||||
|
|
||||||
|
checkSquareMatrix(tensor.shape)
|
||||||
|
val luTensor = tensor.copy()
|
||||||
|
val pivotsTensor = tensor.setUpPivots()
|
||||||
|
|
||||||
|
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
|
||||||
|
if(luHelper(lu.as2D(), pivots.as1D(), epsilon))
|
||||||
|
return null
|
||||||
|
|
||||||
|
return Pair(luTensor, pivotsTensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
internal inline fun pivInit(
|
internal inline fun pivInit(
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
import space.kscience.kmath.structures.toList
|
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
import kotlin.test.Ignore
|
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
@ -144,11 +142,9 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding(
|
val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding(
|
||||||
fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
|
fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
|
||||||
)
|
)
|
||||||
//checkPositiveDefinite(sigma) sigma must be positive definite
|
|
||||||
val low = sigma.cholesky()
|
val low = sigma.cholesky()
|
||||||
val sigmChol = low dot low.transpose()
|
val sigmChol = low dot low.transpose()
|
||||||
assertTrue(sigma.eq(sigmChol))
|
assertTrue(sigma.eq(sigmChol))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
Loading…
Reference in New Issue
Block a user