v0.3.0-dev-9 #324

Merged
altavir merged 265 commits from dev into master 2021-05-08 17:16:29 +03:00
2 changed files with 11 additions and 6 deletions
Showing only changes of commit 3f0dff3ce9 - Show all commits

View File

@ -3,8 +3,10 @@ package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.structures.toList
import kotlin.math.abs
import kotlin.math.min
import kotlin.math.sign
public class DoubleLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
@ -113,11 +115,14 @@ public class DoubleLinearOpsTensorAlgebra :
return Triple(resU.transpose(), resS, resV.transpose())
}
//http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
checkSymmetric(this)
//http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html
//see the last point
TODO("maybe use SVD")
val (u, s, v) = this.svd()
val shp = s.shape + intArrayOf(1)
val utv = (u.transpose() dot v).map { if (abs(it) < 0.99) 0.0 else sign(it) }
val eig = (utv dot s.view(shp)).view(s.shape)
return Pair(eig, v)
}
public fun DoubleTensor.detLU(): DoubleTensor {

View File

@ -1,5 +1,6 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.toList
import kotlin.math.abs
import kotlin.test.Ignore
import kotlin.test.Test
@ -137,13 +138,12 @@ class TestDoubleLinearOpsTensorAlgebra {
}
@Test
@Ignore
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
val tensor = randNormal(shape = intArrayOf(5, 3, 3), 0)
val tensorSigma = tensor + tensor.transpose()
val (tensorS, tensorV) = tensorSigma.symEig()
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
assertTrue(tensorSigma.eq(tensorSigmaCalc))
assertTrue(tensorSigma.eq(tensorSigmaCalc, 0.01))
}