Fix det #276
@ -14,8 +14,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
|
||||
override fun DoubleTensor.det(): DoubleTensor = detLU()
|
||||
|
||||
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
||||
|
||||
internal fun DoubleTensor.luForDet(forDet: Boolean = false): Pair<DoubleTensor, IntTensor> {
|
||||
checkSquareMatrix(shape)
|
||||
|
||||
val luTensor = copy()
|
||||
@ -31,10 +30,22 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
)
|
||||
|
||||
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
|
||||
luHelper(lu.as2D(), pivots.as1D(), m)
|
||||
try {
|
||||
luHelper(lu.as2D(), pivots.as1D(), m)
|
||||
} catch (e: RuntimeException) {
|
||||
if (forDet) {
|
||||
lu.as2D()[intArrayOf(0, 0)] = 0.0
|
||||
} else {
|
||||
throw IllegalStateException("LUP decomposition can't be performed")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return Pair(luTensor, pivotsTensor)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
||||
return luForDet(false)
|
||||
}
|
||||
|
||||
override fun luPivot(
|
||||
@ -69,8 +80,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
||||
checkSymmetric(this)
|
||||
checkSquareMatrix(shape)
|
||||
//TODO("Andrei the det routine has bugs")
|
||||
//checkPositiveDefinite(this)
|
||||
checkPositiveDefinite(this)
|
||||
|
||||
val n = shape.last()
|
||||
val lTensor = zeroesLike()
|
||||
@ -134,7 +144,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
}
|
||||
|
||||
public fun DoubleTensor.detLU(): DoubleTensor {
|
||||
val (luTensor, pivotsTensor) = lu()
|
||||
val (luTensor, pivotsTensor) = luForDet(forDet = true)
|
||||
val n = shape.size
|
||||
|
||||
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
|
||||
|
@ -2,6 +2,7 @@ package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.tensors.TensorAlgebra
|
||||
import space.kscience.kmath.tensors.TensorStructure
|
||||
import kotlin.math.abs
|
||||
|
||||
|
||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
@ -69,3 +70,12 @@ internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: D
|
||||
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
|
||||
}
|
||||
}
|
||||
|
||||
internal inline fun DoubleLinearOpsTensorAlgebra.checkNonSingularMatrix(tensor: DoubleTensor): Unit {
|
||||
for( mat in tensor.matrixSequence()) {
|
||||
val detTensor = mat.asTensor().detLU()
|
||||
check(!(detTensor.eq(detTensor.zeroesLike()))){
|
||||
"Tensor contains matrices which are singular"
|
||||
}
|
||||
}
|
||||
}
|
@ -62,10 +62,10 @@ internal inline fun dotHelper(
|
||||
}
|
||||
|
||||
internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>, m: Int) {
|
||||
for (row in 0 until m) pivots[row] = row
|
||||
for (row in 0..m) pivots[row] = row
|
||||
|
||||
for (i in 0 until m) {
|
||||
var maxVal = -1.0
|
||||
var maxVal = 0.0
|
||||
var maxInd = i
|
||||
|
||||
for (k in i until m) {
|
||||
@ -76,7 +76,9 @@ internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStru
|
||||
}
|
||||
}
|
||||
|
||||
//todo check singularity
|
||||
if (abs(maxVal) < 1e-9) {
|
||||
throw RuntimeException()
|
||||
}
|
||||
|
||||
if (maxInd != i) {
|
||||
|
||||
@ -158,6 +160,9 @@ internal inline fun choleskyHelper(
|
||||
internal inline fun luMatrixDet(luTensor: MutableStructure2D<Double>, pivotsTensor: MutableStructure1D<Int>): Double {
|
||||
val lu = luTensor.as2D()
|
||||
val pivots = pivotsTensor.as1D()
|
||||
if (lu[0, 0] == 0.0) {
|
||||
return 0.0
|
||||
}
|
||||
val m = lu.shape[0]
|
||||
val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
||||
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
|
||||
|
@ -36,6 +36,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
|
||||
@Test
|
||||
fun testDet() = DoubleLinearOpsTensorAlgebra {
|
||||
val expectedValue = 0.019827417
|
||||
val m = fromArray(
|
||||
intArrayOf(3, 3), doubleArrayOf(
|
||||
2.1843, 1.4391, -0.4845,
|
||||
@ -43,8 +44,20 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
-0.4845, 0.4055, 0.7519
|
||||
)
|
||||
)
|
||||
println(m.det().value())
|
||||
println(0.0197) //expected value
|
||||
|
||||
assertTrue { abs(m.det().value() - expectedValue) < 1e-5}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDetSingle() = DoubleLinearOpsTensorAlgebra {
|
||||
val expectedValue = 48.151623
|
||||
val m = fromArray(
|
||||
intArrayOf(1, 1), doubleArrayOf(
|
||||
expectedValue
|
||||
)
|
||||
)
|
||||
|
||||
assertTrue { abs(m.det().value() - expectedValue) < 1e-5}
|
||||
}
|
||||
|
||||
@Test
|
||||
|
32
kmath-tensors/src/jvmMain/kotlin/andMain.kt
Normal file
32
kmath-tensors/src/jvmMain/kotlin/andMain.kt
Normal file
@ -0,0 +1,32 @@
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.tensors.core.DoubleLinearOpsTensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.array
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.sqrt
|
||||
|
||||
|
||||
fun main() {
|
||||
|
||||
DoubleTensorAlgebra {
|
||||
val tensor = fromArray(
|
||||
intArrayOf(2, 2, 2),
|
||||
doubleArrayOf(
|
||||
1.0, 3.0,
|
||||
1.0, 2.0,
|
||||
1.5, 1.0,
|
||||
10.0, 2.0
|
||||
)
|
||||
)
|
||||
val tensor2 = fromArray(
|
||||
intArrayOf(2, 2),
|
||||
doubleArrayOf(
|
||||
0.0, 0.0,
|
||||
0.0, 0.0
|
||||
)
|
||||
)
|
||||
DoubleLinearOpsTensorAlgebra {
|
||||
println(tensor2.det().value())
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user