diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt index d980c510f..87c459f35 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt @@ -1,7 +1,7 @@ package space.kscience.kmath.tensors -public interface LinearOpsTensorAlgebra, IndexTensorType: TensorStructure> : +public interface LinearOpsTensorAlgebra, IndexTensorType : TensorStructure> : TensorPartialDivisionAlgebra { //https://pytorch.org/docs/stable/linalg.html#torch.linalg.det @@ -26,6 +26,6 @@ public interface LinearOpsTensorAlgebra, Inde public fun TensorType.svd(): Triple //https://pytorch.org/docs/stable/generated/torch.symeig.html - public fun TensorType.symEig(eigenvectors: Boolean = true): Pair + public fun TensorType.symEig(): Pair } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt index 268f2fe55..95b668917 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt @@ -94,8 +94,10 @@ public class DoubleLinearOpsTensorAlgebra : return qTensor to rTensor } + override fun DoubleTensor.svd(): Triple = + svd(epsilon = 1e-10) - override fun DoubleTensor.svd(): Triple { + public fun DoubleTensor.svd(epsilon: Double): Triple { val size = this.linearStructure.dim val commonShape = this.shape.sliceArray(0 until size - 2) val (n, m) = this.shape.sliceArray(size - 2 until size) @@ -110,17 +112,20 @@ public class DoubleLinearOpsTensorAlgebra : matrix.shape, matrix.buffer.array().slice(matrix.bufferStart until matrix.bufferStart + size).toDoubleArray() ) - svdHelper(curMatrix, USV, m, n) + svdHelper(curMatrix, USV, m, n, epsilon) } return Triple(resU.transpose(), resS, resV.transpose()) } + override fun DoubleTensor.symEig(): Pair = + symEig(epsilon = 1e-15) + //http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html - override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair { - checkSymmetric(this) - val (u, s, v) = this.svd() + public fun DoubleTensor.symEig(epsilon: Double): Pair { + checkSymmetric(this, epsilon) + val (u, s, v) = this.svd(epsilon) val shp = s.shape + intArrayOf(1) - val utv = (u.transpose() dot v).map { if (abs(it) < 0.99) 0.0 else sign(it) } + val utv = (u.transpose() dot v).map { if (abs(it) < 0.9) 0.0 else sign(it) } val eig = (utv dot s.view(shp)).view(s.shape) return Pair(eig, v) } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt index 1f70eb300..1dde2ea56 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt @@ -58,7 +58,7 @@ internal inline fun , } } -internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor): Unit = - check(tensor.eq(tensor.transpose())){ - "Tensor is not symmetric about the last 2 dimensions" +internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, epsilon: Double = 1e-6): Unit = + check(tensor.eq(tensor.transpose(), epsilon)) { + "Tensor is not symmetric about the last 2 dimensions at precision $epsilon" } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt index 1bdfee2d5..685c16a1b 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt @@ -246,7 +246,7 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svd1d(a: DoubleTensor, epsilon: internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper( matrix: DoubleTensor, USV: Pair, Pair, BufferedTensor>>, - m: Int, n: Int + m: Int, n: Int, epsilon: Double ): Unit { val res = ArrayList>(0) val (matrixU, SV) = USV @@ -267,12 +267,12 @@ internal inline fun DoubleLinearOpsTensorAlgebra.svdHelper( var u: DoubleTensor var norm: Double if (n > m) { - v = svd1d(a) + v = svd1d(a, epsilon) u = matrix.dot(v) norm = DoubleAnalyticTensorAlgebra { (u dot u).sqrt().value() } u = u.times(1.0 / norm) } else { - u = svd1d(a) + u = svd1d(a, epsilon) v = matrix.transpose(0, 1).dot(u) norm = DoubleAnalyticTensorAlgebra { (v dot v).sqrt().value() } v = v.times(1.0 / norm) diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt index c96fcb1c0..0997a9b86 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt @@ -143,7 +143,7 @@ class TestDoubleLinearOpsTensorAlgebra { val tensorSigma = tensor + tensor.transpose() val (tensorS, tensorV) = tensorSigma.symEig() val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose()) - assertTrue(tensorSigma.eq(tensorSigmaCalc, 0.01)) + assertTrue(tensorSigma.eq(tensorSigmaCalc)) }