forked from kscience/kmath
fix det
This commit is contained in:
parent
75783bcb03
commit
c7669d4fba
@ -14,8 +14,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
|
|
||||||
override fun DoubleTensor.det(): DoubleTensor = detLU()
|
override fun DoubleTensor.det(): DoubleTensor = detLU()
|
||||||
|
|
||||||
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
internal fun DoubleTensor.luForDet(forDet: Boolean = false): Pair<DoubleTensor, IntTensor> {
|
||||||
|
|
||||||
checkSquareMatrix(shape)
|
checkSquareMatrix(shape)
|
||||||
|
|
||||||
val luTensor = copy()
|
val luTensor = copy()
|
||||||
@ -31,10 +30,22 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
)
|
)
|
||||||
|
|
||||||
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
|
for ((lu, pivots) in luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()))
|
||||||
|
try {
|
||||||
luHelper(lu.as2D(), pivots.as1D(), m)
|
luHelper(lu.as2D(), pivots.as1D(), m)
|
||||||
|
} catch (e: RuntimeException) {
|
||||||
|
if (forDet) {
|
||||||
|
lu.as2D()[intArrayOf(0, 0)] = 0.0
|
||||||
|
} else {
|
||||||
|
throw IllegalStateException("LUP decomposition can't be performed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
return Pair(luTensor, pivotsTensor)
|
return Pair(luTensor, pivotsTensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
|
||||||
|
return luForDet(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun luPivot(
|
override fun luPivot(
|
||||||
@ -69,8 +80,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
override fun DoubleTensor.cholesky(): DoubleTensor {
|
||||||
checkSymmetric(this)
|
checkSymmetric(this)
|
||||||
checkSquareMatrix(shape)
|
checkSquareMatrix(shape)
|
||||||
//TODO("Andrei the det routine has bugs")
|
checkPositiveDefinite(this)
|
||||||
//checkPositiveDefinite(this)
|
|
||||||
|
|
||||||
val n = shape.last()
|
val n = shape.last()
|
||||||
val lTensor = zeroesLike()
|
val lTensor = zeroesLike()
|
||||||
@ -134,7 +144,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
|
|
||||||
public fun DoubleTensor.detLU(): DoubleTensor {
|
public fun DoubleTensor.detLU(): DoubleTensor {
|
||||||
val (luTensor, pivotsTensor) = lu()
|
val (luTensor, pivotsTensor) = luForDet(forDet = true)
|
||||||
val n = shape.size
|
val n = shape.size
|
||||||
|
|
||||||
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
|
val detTensorShape = IntArray(n - 1) { i -> shape[i] }
|
||||||
|
@ -2,6 +2,7 @@ package space.kscience.kmath.tensors.core
|
|||||||
|
|
||||||
import space.kscience.kmath.tensors.TensorAlgebra
|
import space.kscience.kmath.tensors.TensorAlgebra
|
||||||
import space.kscience.kmath.tensors.TensorStructure
|
import space.kscience.kmath.tensors.TensorStructure
|
||||||
|
import kotlin.math.abs
|
||||||
|
|
||||||
|
|
||||||
internal inline fun <T, TensorType : TensorStructure<T>,
|
internal inline fun <T, TensorType : TensorStructure<T>,
|
||||||
@ -69,3 +70,12 @@ internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite(tensor: D
|
|||||||
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
|
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal inline fun DoubleLinearOpsTensorAlgebra.checkNonSingularMatrix(tensor: DoubleTensor): Unit {
|
||||||
|
for( mat in tensor.matrixSequence()) {
|
||||||
|
val detTensor = mat.asTensor().detLU()
|
||||||
|
check(!(detTensor.eq(detTensor.zeroesLike()))){
|
||||||
|
"Tensor contains matrices which are singular"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -62,10 +62,10 @@ internal inline fun dotHelper(
|
|||||||
}
|
}
|
||||||
|
|
||||||
internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>, m: Int) {
|
internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStructure1D<Int>, m: Int) {
|
||||||
for (row in 0 until m) pivots[row] = row
|
for (row in 0..m) pivots[row] = row
|
||||||
|
|
||||||
for (i in 0 until m) {
|
for (i in 0 until m) {
|
||||||
var maxVal = -1.0
|
var maxVal = 0.0
|
||||||
var maxInd = i
|
var maxInd = i
|
||||||
|
|
||||||
for (k in i until m) {
|
for (k in i until m) {
|
||||||
@ -76,7 +76,9 @@ internal inline fun luHelper(lu: MutableStructure2D<Double>, pivots: MutableStru
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//todo check singularity
|
if (abs(maxVal) < 1e-9) {
|
||||||
|
throw RuntimeException()
|
||||||
|
}
|
||||||
|
|
||||||
if (maxInd != i) {
|
if (maxInd != i) {
|
||||||
|
|
||||||
@ -158,6 +160,9 @@ internal inline fun choleskyHelper(
|
|||||||
internal inline fun luMatrixDet(luTensor: MutableStructure2D<Double>, pivotsTensor: MutableStructure1D<Int>): Double {
|
internal inline fun luMatrixDet(luTensor: MutableStructure2D<Double>, pivotsTensor: MutableStructure1D<Int>): Double {
|
||||||
val lu = luTensor.as2D()
|
val lu = luTensor.as2D()
|
||||||
val pivots = pivotsTensor.as1D()
|
val pivots = pivotsTensor.as1D()
|
||||||
|
if (lu[0, 0] == 0.0) {
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
val m = lu.shape[0]
|
val m = lu.shape[0]
|
||||||
val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
val sign = if ((pivots[m] - m) % 2 == 0) 1.0 else -1.0
|
||||||
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
|
return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right }
|
||||||
|
@ -36,6 +36,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testDet() = DoubleLinearOpsTensorAlgebra {
|
fun testDet() = DoubleLinearOpsTensorAlgebra {
|
||||||
|
val expectedValue = 0.019827417
|
||||||
val m = fromArray(
|
val m = fromArray(
|
||||||
intArrayOf(3, 3), doubleArrayOf(
|
intArrayOf(3, 3), doubleArrayOf(
|
||||||
2.1843, 1.4391, -0.4845,
|
2.1843, 1.4391, -0.4845,
|
||||||
@ -43,8 +44,20 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
-0.4845, 0.4055, 0.7519
|
-0.4845, 0.4055, 0.7519
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
println(m.det().value())
|
|
||||||
println(0.0197) //expected value
|
assertTrue { abs(m.det().value() - expectedValue) < 1e-5}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDetSingle() = DoubleLinearOpsTensorAlgebra {
|
||||||
|
val expectedValue = 48.151623
|
||||||
|
val m = fromArray(
|
||||||
|
intArrayOf(1, 1), doubleArrayOf(
|
||||||
|
expectedValue
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assertTrue { abs(m.det().value() - expectedValue) < 1e-5}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
32
kmath-tensors/src/jvmMain/kotlin/andMain.kt
Normal file
32
kmath-tensors/src/jvmMain/kotlin/andMain.kt
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import space.kscience.kmath.nd.*
|
||||||
|
import space.kscience.kmath.tensors.core.DoubleLinearOpsTensorAlgebra
|
||||||
|
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||||
|
import space.kscience.kmath.tensors.core.array
|
||||||
|
import kotlin.math.abs
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
|
||||||
|
DoubleTensorAlgebra {
|
||||||
|
val tensor = fromArray(
|
||||||
|
intArrayOf(2, 2, 2),
|
||||||
|
doubleArrayOf(
|
||||||
|
1.0, 3.0,
|
||||||
|
1.0, 2.0,
|
||||||
|
1.5, 1.0,
|
||||||
|
10.0, 2.0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
val tensor2 = fromArray(
|
||||||
|
intArrayOf(2, 2),
|
||||||
|
doubleArrayOf(
|
||||||
|
0.0, 0.0,
|
||||||
|
0.0, 0.0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
DoubleLinearOpsTensorAlgebra {
|
||||||
|
println(tensor2.det().value())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user