diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/margarita.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/margarita.kt index 3629b6a47..c376d6d0f 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/margarita.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/margarita.kt @@ -11,28 +11,9 @@ import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.nd.as2D import space.kscience.kmath.tensors.core.* -import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.dot -import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.mapIndexed -import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.zeros -import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.minus -import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.sum import space.kscience.kmath.tensors.core.tensorAlgebra import kotlin.math.* -fun DoubleArray.fmap(transform: (Double) -> Double): DoubleArray { - return this.map(transform).toDoubleArray() -} - -fun scalarProduct(v1: Structure2D, v2: Structure2D): Double { - return v1.mapIndexed { index, d -> d * v2[index] }.sum() -} - -internal fun diagonal(shape: IntArray, v: Double) : DoubleTensor { - val matrix = zeros(shape) - return matrix.mapIndexed { index, _ -> if (index.component1() == index.component2()) v else 0.0 } -} - - fun MutableStructure2D.print() { val n = this.shape.component1() val m = this.shape.component2() @@ -63,11 +44,22 @@ fun main(): Unit = Double.tensorAlgebra.withBroadcast { ) val tensor = fromArray(shape, buffer).as2D() val v = fromArray(intArrayOf(3, 3), buffer2).as2D() + val w_shape = intArrayOf(3, 1) + var w_buffer = doubleArrayOf(0.000000) + for (i in 0 until 3 - 1) { + w_buffer += doubleArrayOf(0.000000) + } + val w = BroadcastDoubleTensorAlgebra.fromArray(w_shape, w_buffer).as2D() tensor.print() - tensor.svdcmp(v) - - + var ans = Pair(w, v) + tensor.svdGolabKahan(v, w) + println("u") + tensor.print() + println("w") + w.print() + println("v") + v.print() } diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/svdcmp.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/svdcmp.kt index e7a424d74..8572bdee9 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/svdcmp.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/svdcmp.kt @@ -1,7 +1,6 @@ package space.kscience.kmath.tensors import space.kscience.kmath.nd.* -import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra import kotlin.math.abs import kotlin.math.max import kotlin.math.min @@ -34,10 +33,12 @@ fun SIGN(a: Double, b: Double): Double { return -abs(a) } -internal fun MutableStructure2D.svdcmp(v: MutableStructure2D) { +// matrix v is not transposed at the output + +internal fun MutableStructure2D.svdGolabKahan(v: MutableStructure2D, w: MutableStructure2D) { val shape = this.shape - val n = shape.component2() val m = shape.component1() + val n = shape.component2() var f = 0.0 val rv1 = DoubleArray(n) var s = 0.0 @@ -45,12 +46,6 @@ internal fun MutableStructure2D.svdcmp(v: MutableStructure2D) { var anorm = 0.0 var g = 0.0 var l = 0 - val w_shape = intArrayOf(n, 1) - var w_buffer = doubleArrayOf(0.000000) - for (i in 0 until n - 1) { - w_buffer += doubleArrayOf(0.000000) - } - val w = BroadcastDoubleTensorAlgebra.fromArray(w_shape, w_buffer).as2D() for (i in 0 until n) { /* left-hand reduction */ l = i + 1 @@ -212,6 +207,9 @@ internal fun MutableStructure2D.svdcmp(v: MutableStructure2D) { this[3, 2] = -0.297540 this[4, 2] = 0.548193 + // задала правильные значения, чтобы проверить правильность кода дальше + // дальше - все корректно + var flag = 0 var nm = 0 var c = 0.0 @@ -268,9 +266,10 @@ internal fun MutableStructure2D.svdcmp(v: MutableStructure2D) { break } - if (its == 30) { - return - } +// надо придумать, что сделать - выкинуть ошибку? +// if (its == 30) { +// return +// } x = w[l, 0] nm = k - 1 @@ -326,11 +325,4 @@ internal fun MutableStructure2D.svdcmp(v: MutableStructure2D) { w[k, 0] = x } } - - println("u") - this.print() - println("w") - w.print() - println("v") - v.print() } \ No newline at end of file