From 10f84bd630ae6654fb03898e1fd2d50300976db7 Mon Sep 17 00:00:00 2001 From: Margarita Lashina Date: Wed, 3 May 2023 21:14:29 +0300 Subject: [PATCH] added function solve --- .../kmath/tensors/api/LinearOpsTensorAlgebra.kt | 14 ++++++++++++++ .../kmath/tensors/core/DoubleTensorAlgebra.kt | 6 ++++++ 2 files changed, 20 insertions(+) diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt index faff2eb80..60865ecba 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt @@ -5,8 +5,15 @@ package space.kscience.kmath.tensors.api +import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.nd.as2D import space.kscience.kmath.operations.Field +import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra +import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.dot +import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.map +import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.transposed +import space.kscience.kmath.tensors.core.DoubleTensorAlgebra /** * Common linear algebra operations. Operates on [Tensor]. @@ -103,4 +110,11 @@ public interface LinearOpsTensorAlgebra> : TensorPartialDivision */ public fun symEig(structureND: StructureND): Pair, StructureND> + /** Returns the solution to the equation Ax = B for the square matrix A as `input1` and + * for the square matrix B as `input2`. + * + * @receiver the `input1` and the `input2`. + * @return the square matrix x which is the solution of the equation. + */ + public fun solve(a: MutableStructure2D, b: MutableStructure2D): MutableStructure2D } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 70a3ef7e2..1571eb7f1 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -711,6 +711,12 @@ public open class DoubleTensorAlgebra : override fun symEig(structureND: StructureND): Pair = symEigJacobi(structureND = structureND, maxIteration = 50, epsilon = 1e-15) + override fun solve(a: MutableStructure2D, b: MutableStructure2D): MutableStructure2D { + val aSvd = DoubleTensorAlgebra.svd(a) + val s = BroadcastDoubleTensorAlgebra.diagonalEmbedding(aSvd.second.map {1.0 / it}) + val aInverse = aSvd.third.dot(s).dot(aSvd.first.transposed()) + return aInverse.dot(b).as2D() + } } public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra get() = DoubleTensorAlgebra