KMP library for tensors #300
@ -3,8 +3,10 @@ package space.kscience.kmath.tensors.core
|
|||||||
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
import space.kscience.kmath.tensors.LinearOpsTensorAlgebra
|
||||||
import space.kscience.kmath.nd.as1D
|
import space.kscience.kmath.nd.as1D
|
||||||
import space.kscience.kmath.nd.as2D
|
import space.kscience.kmath.nd.as2D
|
||||||
|
import space.kscience.kmath.structures.toList
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
|
import kotlin.math.sign
|
||||||
|
|
||||||
public class DoubleLinearOpsTensorAlgebra :
|
public class DoubleLinearOpsTensorAlgebra :
|
||||||
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
||||||
@ -113,11 +115,14 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
return Triple(resU.transpose(), resS, resV.transpose())
|
return Triple(resU.transpose(), resS, resV.transpose())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html
|
||||||
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
|
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
|
||||||
checkSymmetric(this)
|
checkSymmetric(this)
|
||||||
//http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html
|
val (u, s, v) = this.svd()
|
||||||
//see the last point
|
val shp = s.shape + intArrayOf(1)
|
||||||
TODO("maybe use SVD")
|
val utv = (u.transpose() dot v).map { if (abs(it) < 0.99) 0.0 else sign(it) }
|
||||||
|
val eig = (utv dot s.view(shp)).view(s.shape)
|
||||||
|
return Pair(eig, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DoubleTensor.detLU(): DoubleTensor {
|
public fun DoubleTensor.detLU(): DoubleTensor {
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.kmath.structures.toList
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
import kotlin.test.Ignore
|
import kotlin.test.Ignore
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -137,13 +138,12 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
|
||||||
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
|
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
|
||||||
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
|
val tensor = randNormal(shape = intArrayOf(5, 3, 3), 0)
|
||||||
val tensorSigma = tensor + tensor.transpose()
|
val tensorSigma = tensor + tensor.transpose()
|
||||||
val (tensorS, tensorV) = tensorSigma.symEig()
|
val (tensorS, tensorV) = tensorSigma.symEig()
|
||||||
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
||||||
assertTrue(tensorSigma.eq(tensorSigmaCalc))
|
assertTrue(tensorSigma.eq(tensorSigmaCalc, 0.01))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user