Approaching SymEig through SVD
This commit is contained in:
parent
a09a1c7adc
commit
8c1131dd58
@ -67,7 +67,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.cholesky(): DoubleTensor {
|
override fun DoubleTensor.cholesky(): DoubleTensor {
|
||||||
// todo checks
|
checkSymmetric(this)
|
||||||
checkSquareMatrix(shape)
|
checkSquareMatrix(shape)
|
||||||
|
|
||||||
val n = shape.last()
|
val n = shape.last()
|
||||||
@ -113,7 +113,9 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
|
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
|
||||||
TODO()
|
checkSymmetric(this)
|
||||||
|
val svd = this.svd()
|
||||||
|
TODO("U might have some columns negative to V")
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DoubleTensor.detLU(): DoubleTensor {
|
public fun DoubleTensor.detLU(): DoubleTensor {
|
||||||
|
@ -58,3 +58,7 @@ internal inline fun <T, TensorType : TensorStructure<T>,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor): Unit =
|
||||||
|
check(tensor.eq(tensor.transpose())){
|
||||||
|
"Tensor is not symmetric about the last 2 dimensions"
|
||||||
|
}
|
@ -140,9 +140,9 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
@Ignore
|
@Ignore
|
||||||
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
|
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
|
||||||
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
|
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
|
||||||
val tensorSigma = tensor + tensor.transpose(1, 2)
|
val tensorSigma = tensor + tensor.transpose()
|
||||||
val (tensorS, tensorV) = tensorSigma.symEig()
|
val (tensorS, tensorV) = tensorSigma.symEig()
|
||||||
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose(1, 2))
|
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
||||||
assertTrue(tensorSigma.eq(tensorSigmaCalc))
|
assertTrue(tensorSigma.eq(tensorSigmaCalc))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user