From 1f19ac88aee34ed5b790356a81f762b21c1752f3 Mon Sep 17 00:00:00 2001 From: Andrei Kislitsyn Date: Fri, 26 Mar 2021 17:43:08 +0300 Subject: [PATCH] qr fix --- .../kmath/tensors/LinearOpsTensorAlgebra.kt | 2 +- .../tensors/core/DoubleLinearOpsTensorAlgebra.kt | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt index 3d1625a50..d980c510f 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt @@ -14,7 +14,7 @@ public interface LinearOpsTensorAlgebra, Inde public fun TensorType.cholesky(): TensorType //https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr - public fun TensorType.qr(): TensorType + public fun TensorType.qr(): Pair //https://pytorch.org/docs/stable/generated/torch.lu.html public fun TensorType.lu(): Pair diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt index 0b9d824f0..d49fc941a 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt @@ -1,8 +1,10 @@ package space.kscience.kmath.tensors.core +import space.kscience.kmath.linear.Matrix import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.Structure1D import space.kscience.kmath.nd.Structure2D +import space.kscience.kmath.nd.asND import space.kscience.kmath.tensors.LinearOpsTensorAlgebra import kotlin.math.sqrt @@ -141,7 +143,15 @@ public class DoubleLinearOpsTensorAlgebra : return lTensor } - override fun DoubleTensor.qr(): DoubleTensor { + private fun matrixQR( + matrix: Structure2D, + q: MutableStructure2D, + r: MutableStructure2D + ) { + + } + + override fun DoubleTensor.qr(): Pair { TODO("ANDREI") } @@ -154,6 +164,7 @@ public class DoubleLinearOpsTensorAlgebra : } private fun luMatrixDet(lu: Structure2D, pivots: Structure1D): Double { + // todo check val m = lu.shape[0] val sign = if((pivots[m] - m) % 2 == 0) 1.0 else -1.0 return (0 until m).asSequence().map { lu[it, it] }.fold(sign) { left, right -> left * right } @@ -184,6 +195,7 @@ public class DoubleLinearOpsTensorAlgebra : pivots: Structure1D, invMatrix : MutableStructure2D ): Unit { + //todo check val m = lu.shape[0] for (j in 0 until m) {