diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 742a3d7a7..d83c70bd9 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -837,7 +837,7 @@ public open class DoubleTensorAlgebra : return this.svdGolubKahan() } - public fun StructureND.svdGolubKahan(iterations: Int = 30): Triple { + public fun StructureND.svdGolubKahan(iterations: Int = 30, epsilon: Double = 1e-10): Triple { val size = tensor.dimension val commonShape = tensor.shape.sliceArray(0 until size - 2) val (n, m) = tensor.shape.sliceArray(size - 2 until size) @@ -859,7 +859,8 @@ public open class DoubleTensorAlgebra : .slice(matrix.bufferStart until matrix.bufferStart + matrixSize) .toDoubleArray() ) - curMatrix.as2D().svdGolubKahanHelper(uTensors[index].as2D(), sTensorVectors[index], vTensors[index].as2D(), iterations) + curMatrix.as2D().svdGolubKahanHelper(uTensors[index].as2D(), sTensorVectors[index], vTensors[index].as2D(), + iterations, epsilon) } return Triple(uTensor.transpose(), sTensor, vTensor) diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt index f6cb82bfd..9724fc335 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/linUtils.kt @@ -373,7 +373,7 @@ private fun SIGN(a: Double, b: Double): Double { return -abs(a) } internal fun MutableStructure2D.svdGolubKahanHelper(u: MutableStructure2D, w: BufferedTensor, - v: MutableStructure2D, iterations: Int) { + v: MutableStructure2D, iterations: Int, epsilon: Double) { val shape = this.shape val m = shape.component1() val n = shape.component2() @@ -384,7 +384,6 @@ internal fun MutableStructure2D.svdGolubKahanHelper(u: MutableStructure2 var anorm = 0.0 var g = 0.0 var l = 0 - val epsilon = 1e-10 val wStart = w.bufferStart val wBuffer = w.mutableBuffer