Approaching SymEig through SVD
This commit is contained in:
parent
a09a1c7adc
commit
8c1131dd58
@ -67,7 +67,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
}
|
||||
|
||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
||||
// todo checks
|
||||
checkSymmetric(this)
|
||||
checkSquareMatrix(shape)
|
||||
|
||||
val n = shape.last()
|
||||
@ -113,7 +113,9 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
}
|
||||
|
||||
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
|
||||
TODO()
|
||||
checkSymmetric(this)
|
||||
val svd = this.svd()
|
||||
TODO("U might have some columns negative to V")
|
||||
}
|
||||
|
||||
public fun DoubleTensor.detLU(): DoubleTensor {
|
||||
|
@ -58,3 +58,7 @@ internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
}
|
||||
}
|
||||
|
||||
internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor): Unit =
|
||||
check(tensor.eq(tensor.transpose())){
|
||||
"Tensor is not symmetric about the last 2 dimensions"
|
||||
}
|
@ -140,9 +140,9 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
@Ignore
|
||||
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
|
||||
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
|
||||
val tensorSigma = tensor + tensor.transpose(1, 2)
|
||||
val tensorSigma = tensor + tensor.transpose()
|
||||
val (tensorS, tensorV) = tensorSigma.symEig()
|
||||
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose(1, 2))
|
||||
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
||||
assertTrue(tensorSigma.eq(tensorSigmaCalc))
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user