v0.3.0-dev-9 #324

Merged
altavir merged 265 commits from dev into master 2021-05-08 17:16:29 +03:00
3 changed files with 10 additions and 4 deletions
Showing only changes of commit 8c1131dd58 - Show all commits

View File

@ -67,7 +67,7 @@ public class DoubleLinearOpsTensorAlgebra :
} }
override fun DoubleTensor.cholesky(): DoubleTensor { override fun DoubleTensor.cholesky(): DoubleTensor {
// todo checks checkSymmetric(this)
checkSquareMatrix(shape) checkSquareMatrix(shape)
val n = shape.last() val n = shape.last()
@ -113,7 +113,9 @@ public class DoubleLinearOpsTensorAlgebra :
} }
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> { override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
TODO() checkSymmetric(this)
val svd = this.svd()
TODO("U might have some columns negative to V")
} }
public fun DoubleTensor.detLU(): DoubleTensor { public fun DoubleTensor.detLU(): DoubleTensor {

View File

@ -58,3 +58,7 @@ internal inline fun <T, TensorType : TensorStructure<T>,
} }
} }
internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor): Unit =
check(tensor.eq(tensor.transpose())){
"Tensor is not symmetric about the last 2 dimensions"
}

View File

@ -140,9 +140,9 @@ class TestDoubleLinearOpsTensorAlgebra {
@Ignore @Ignore
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra { fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0) val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
val tensorSigma = tensor + tensor.transpose(1, 2) val tensorSigma = tensor + tensor.transpose()
val (tensorS, tensorV) = tensorSigma.symEig() val (tensorS, tensorV) = tensorSigma.symEig()
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose(1, 2)) val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
assertTrue(tensorSigma.eq(tensorSigmaCalc)) assertTrue(tensorSigma.eq(tensorSigmaCalc))
} }