diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt index 3d1625a50..d980c510f 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt @@ -14,7 +14,7 @@ public interface LinearOpsTensorAlgebra, Inde public fun TensorType.cholesky(): TensorType //https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr - public fun TensorType.qr(): TensorType + public fun TensorType.qr(): Pair //https://pytorch.org/docs/stable/generated/torch.lu.html public fun TensorType.lu(): Pair diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt index b81cf0364..46ab5f80f 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt @@ -42,7 +42,8 @@ public class DoubleLinearOpsTensorAlgebra : //todo checks checkSquareMatrix(luTensor.shape) check( - luTensor.shape.dropLast(1).toIntArray() contentEquals pivotsTensor.shape + luTensor.shape.dropLast(2).toIntArray() contentEquals pivotsTensor.shape.dropLast(1).toIntArray() || + luTensor.shape.last() == pivotsTensor.shape.last() - 1 ) { "Bed shapes ((" } //todo rewrite val n = luTensor.shape.last() @@ -76,8 +77,56 @@ public class DoubleLinearOpsTensorAlgebra : return lTensor } - override fun DoubleTensor.qr(): DoubleTensor { - TODO("ANDREI") + private fun MutableStructure1D.dot(other: MutableStructure1D): Double { + var res = 0.0 + for (i in 0 until size) { + res += this[i] * other[i] + } + return res + } + + private fun MutableStructure1D.l2Norm(): Double { + var squareSum = 0.0 + for (i in 0 until size) { + squareSum += this[i] * this[i] + } + return sqrt(squareSum) + } + + fun qrHelper( + matrix: MutableStructure2D, + q: MutableStructure2D, + r: MutableStructure2D + ) { + //todo check square + val n = matrix.colNum + for (j in 0 until n) { + val v = matrix.columns[j] + if (j > 0) { + for (i in 0 until j) { + r[i, j] = q.columns[i].dot(matrix.columns[j]) + for (k in 0 until n) { + v[k] = v[k] - r[i, j] * q.columns[i][k] + } + } + } + r[j, j] = v.l2Norm() + for (i in 0 until n) { + q[i, j] = v[i] / r[j, j] + } + } + } + + override fun DoubleTensor.qr(): Pair { + checkSquareMatrix(shape) + val qTensor = zeroesLike() + val rTensor = zeroesLike() + val seq = matrixSequence().zip((qTensor.matrixSequence().zip(rTensor.matrixSequence()))) + for ((matrix, qr) in seq) { + val (q, r) = qr + qrHelper(matrix.as2D(), q.as2D(), r.as2D()) + } + return Pair(qTensor, rTensor) } override fun DoubleTensor.svd(): Triple { diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt index d2d4ffb67..b79c54dd1 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt @@ -76,4 +76,50 @@ class TestDoubleLinearOpsTensorAlgebra { val b = fromArray(intArrayOf(3), doubleArrayOf(5.5, 2.6, 6.4)) assertEquals(a.dot(b).value(), 59.92) } + + @Test + fun testQR() = DoubleLinearOpsTensorAlgebra { + val shape = intArrayOf(2, 2, 2) + val buffer = doubleArrayOf( + 1.0, 3.0, + 1.0, 2.0, + 1.5, 1.0, + 10.0, 2.0 + ) + + val tensor = fromArray(shape, buffer) + + val (q, r) = tensor.qr() + + assertTrue { q.shape contentEquals shape } + assertTrue { r.shape contentEquals shape } + + assertTrue { q.dot(r).buffer.array().epsEqual(buffer) } + + //todo check orthogonality/upper triang. + } + + @Test + fun testLU() = DoubleLinearOpsTensorAlgebra { + val shape = intArrayOf(2, 2, 2) + val buffer = doubleArrayOf( + 1.0, 3.0, + 1.0, 2.0, + 1.5, 1.0, + 10.0, 2.0 + ) + val tensor = fromArray(shape, buffer) + + val (lu, pivots) = tensor.lu() + + // todo check lu + + val (p, l, u) = luPivot(lu, pivots) + + assertTrue { p.shape contentEquals shape } + assertTrue { l.shape contentEquals shape } + assertTrue { u.shape contentEquals shape } + + assertTrue { p.dot(tensor).buffer.array().epsEqual(l.dot(u).buffer.array()) } + } }