v0.3.0-dev-9 #324

Merged
altavir merged 265 commits from dev into master 2021-05-08 17:16:29 +03:00
2 changed files with 11 additions and 6 deletions
Showing only changes of commit 3f0dff3ce9 - Show all commits

View File

@ -3,8 +3,10 @@ package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.structures.toList
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.min import kotlin.math.min
import kotlin.math.sign
public class DoubleLinearOpsTensorAlgebra : public class DoubleLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>, LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
@ -113,11 +115,14 @@ public class DoubleLinearOpsTensorAlgebra :
return Triple(resU.transpose(), resS, resV.transpose()) return Triple(resU.transpose(), resS, resV.transpose())
} }
//http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> { override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
checkSymmetric(this) checkSymmetric(this)
//http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html val (u, s, v) = this.svd()
//see the last point val shp = s.shape + intArrayOf(1)
TODO("maybe use SVD") val utv = (u.transpose() dot v).map { if (abs(it) < 0.99) 0.0 else sign(it) }
val eig = (utv dot s.view(shp)).view(s.shape)
return Pair(eig, v)
} }
public fun DoubleTensor.detLU(): DoubleTensor { public fun DoubleTensor.detLU(): DoubleTensor {

View File

@ -1,5 +1,6 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.toList
import kotlin.math.abs import kotlin.math.abs
import kotlin.test.Ignore import kotlin.test.Ignore
import kotlin.test.Test import kotlin.test.Test
@ -137,13 +138,12 @@ class TestDoubleLinearOpsTensorAlgebra {
} }
@Test @Test
@Ignore
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra { fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0) val tensor = randNormal(shape = intArrayOf(5, 3, 3), 0)
val tensorSigma = tensor + tensor.transpose() val tensorSigma = tensor + tensor.transpose()
val (tensorS, tensorV) = tensorSigma.symEig() val (tensorS, tensorV) = tensorSigma.symEig()
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose()) val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
assertTrue(tensorSigma.eq(tensorSigmaCalc)) assertTrue(tensorSigma.eq(tensorSigmaCalc, 0.01))
} }