Golub-Kahan SVD algorithm for KMP tensors #499
@ -837,7 +837,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
return this.svdGolubKahan()
|
return this.svdGolubKahan()
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun StructureND<Double>.svdGolubKahan(iterations: Int = 30): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
public fun StructureND<Double>.svdGolubKahan(iterations: Int = 30, epsilon: Double = 1e-10): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
val size = tensor.dimension
|
val size = tensor.dimension
|
||||||
val commonShape = tensor.shape.sliceArray(0 until size - 2)
|
val commonShape = tensor.shape.sliceArray(0 until size - 2)
|
||||||
val (n, m) = tensor.shape.sliceArray(size - 2 until size)
|
val (n, m) = tensor.shape.sliceArray(size - 2 until size)
|
||||||
@ -859,7 +859,8 @@ public open class DoubleTensorAlgebra :
|
|||||||
.slice(matrix.bufferStart until matrix.bufferStart + matrixSize)
|
.slice(matrix.bufferStart until matrix.bufferStart + matrixSize)
|
||||||
.toDoubleArray()
|
.toDoubleArray()
|
||||||
)
|
)
|
||||||
curMatrix.as2D().svdGolubKahanHelper(uTensors[index].as2D(), sTensorVectors[index], vTensors[index].as2D(), iterations)
|
curMatrix.as2D().svdGolubKahanHelper(uTensors[index].as2D(), sTensorVectors[index], vTensors[index].as2D(),
|
||||||
|
iterations, epsilon)
|
||||||
}
|
}
|
||||||
|
|
||||||
return Triple(uTensor.transpose(), sTensor, vTensor)
|
return Triple(uTensor.transpose(), sTensor, vTensor)
|
||||||
|
@ -373,7 +373,7 @@ private fun SIGN(a: Double, b: Double): Double {
|
|||||||
return -abs(a)
|
return -abs(a)
|
||||||
}
|
}
|
||||||
internal fun MutableStructure2D<Double>.svdGolubKahanHelper(u: MutableStructure2D<Double>, w: BufferedTensor<Double>,
|
internal fun MutableStructure2D<Double>.svdGolubKahanHelper(u: MutableStructure2D<Double>, w: BufferedTensor<Double>,
|
||||||
v: MutableStructure2D<Double>, iterations: Int) {
|
v: MutableStructure2D<Double>, iterations: Int, epsilon: Double) {
|
||||||
val shape = this.shape
|
val shape = this.shape
|
||||||
val m = shape.component1()
|
val m = shape.component1()
|
||||||
val n = shape.component2()
|
val n = shape.component2()
|
||||||
@ -384,7 +384,6 @@ internal fun MutableStructure2D<Double>.svdGolubKahanHelper(u: MutableStructure2
|
|||||||
var anorm = 0.0
|
var anorm = 0.0
|
||||||
var g = 0.0
|
var g = 0.0
|
||||||
var l = 0
|
var l = 0
|
||||||
val epsilon = 1e-10
|
|
||||||
|
|
||||||
val wStart = w.bufferStart
|
val wStart = w.bufferStart
|
||||||
val wBuffer = w.mutableBuffer
|
val wBuffer = w.mutableBuffer
|
||||||
|
Loading…
Reference in New Issue
Block a user