diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/LinearSystemSolvingWithLUP.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/LinearSystemSolvingWithLUP.kt index 526b6781f..38d8c1437 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/LinearSystemSolvingWithLUP.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/LinearSystemSolvingWithLUP.kt @@ -42,8 +42,7 @@ fun main () { // solve `Ax = b` system using LUP decomposition // get P, L, U such that PA = LU - val (lu, pivots) = a.lu() - val (p, l, u) = luPivot(lu, pivots) + val (p, l, u) = a.lu() // check that P is permutation matrix println("P:\n$p") diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt index 7a19c5d5a..bcbb52a1b 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt @@ -60,23 +60,17 @@ public interface LinearOpsTensorAlgebra : public fun TensorStructure.qr(): Pair, TensorStructure> /** - * TODO('Andrew') - * For more information: https://pytorch.org/docs/stable/generated/torch.lu.html + * LUP decomposition * - * @return ... - */ - public fun TensorStructure.lu(): Pair, TensorStructure> - - /** - * TODO('Andrew') - * For more information: https://pytorch.org/docs/stable/generated/torch.lu_unpack.html + * Computes the LUP decomposition of a matrix or a batch of matrices. + * Given a tensor `input`, return tensors (P, L, U) satisfying ``P * input = L * U``, + * with `P` being a permutation matrix or batch of matrices, + * `L` being a lower triangular matrix or batch of matrices, + * `U` being an upper triangular matrix or batch of matrices. * - * @param luTensor ... - * @param pivotsTensor ... - * @return ... + * * @return triple of P, L and U tensors */ - public fun luPivot(luTensor: TensorStructure, pivotsTensor: TensorStructure): - Triple, TensorStructure, TensorStructure> + public fun TensorStructure.lu(): Triple, TensorStructure, TensorStructure> /** * Singular Value Decomposition. diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt index 62629f3db..97eed289a 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt @@ -29,13 +29,13 @@ public class DoubleLinearOpsTensorAlgebra : override fun TensorStructure.det(): DoubleTensor = detLU(1e-9) - public fun TensorStructure.lu(epsilon: Double): Pair = + public fun TensorStructure.luFactor(epsilon: Double): Pair = computeLU(tensor, epsilon) ?: throw RuntimeException("Tensor contains matrices which are singular at precision $epsilon") - override fun TensorStructure.lu(): Pair = lu(1e-9) + public fun TensorStructure.luFactor(): Pair = luFactor(1e-9) - override fun luPivot( + public fun luPivot( luTensor: TensorStructure, pivotsTensor: TensorStructure ): Triple { @@ -156,7 +156,7 @@ public class DoubleLinearOpsTensorAlgebra : } public fun TensorStructure.invLU(epsilon: Double = 1e-9): DoubleTensor { - val (luTensor, pivotsTensor) = lu(epsilon) + val (luTensor, pivotsTensor) = luFactor(epsilon) val invTensor = luTensor.zeroesLike() val seq = luTensor.matrixSequence().zip(pivotsTensor.vectorSequence()).zip(invTensor.matrixSequence()) @@ -167,6 +167,15 @@ public class DoubleLinearOpsTensorAlgebra : return invTensor } + + public fun TensorStructure.lu(epsilon: Double = 1e-9): Triple { + val (lu, pivots) = this.luFactor(epsilon) + return luPivot(lu, pivots) + } + + override fun TensorStructure.lu(): Triple = lu(1e-9) + + } public inline fun DoubleLinearOpsTensorAlgebra(block: DoubleLinearOpsTensorAlgebra.() -> R): R = diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt index 6120f0e4a..1f7e955d4 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleLinearOpsAlgebra.kt @@ -124,9 +124,7 @@ class TestDoubleLinearOpsTensorAlgebra { ) val tensor = fromArray(shape, buffer) - val (lu, pivots) = tensor.lu() - - val (p, l, u) = luPivot(lu, pivots) + val (p, l, u) = tensor.lu() assertTrue { p.shape contentEquals shape } assertTrue { l.shape contentEquals shape }