added function solve

This commit is contained in:
Margarita Lashina 2023-05-03 21:14:29 +03:00
parent a02085918a
commit 10f84bd630
2 changed files with 20 additions and 0 deletions

View File

@ -5,8 +5,15 @@
package space.kscience.kmath.tensors.api
import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.operations.Field
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.dot
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.map
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.transposed
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
/**
* Common linear algebra operations. Operates on [Tensor].
@ -103,4 +110,11 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
*/
public fun symEig(structureND: StructureND<T>): Pair<StructureND<T>, StructureND<T>>
/** Returns the solution to the equation Ax = B for the square matrix A as `input1` and
* for the square matrix B as `input2`.
*
* @receiver the `input1` and the `input2`.
* @return the square matrix x which is the solution of the equation.
*/
public fun solve(a: MutableStructure2D<Double>, b: MutableStructure2D<Double>): MutableStructure2D<Double>
}

View File

@ -711,6 +711,12 @@ public open class DoubleTensorAlgebra :
override fun symEig(structureND: StructureND<Double>): Pair<DoubleTensor, DoubleTensor> =
symEigJacobi(structureND = structureND, maxIteration = 50, epsilon = 1e-15)
override fun solve(a: MutableStructure2D<Double>, b: MutableStructure2D<Double>): MutableStructure2D<Double> {
val aSvd = DoubleTensorAlgebra.svd(a)
val s = BroadcastDoubleTensorAlgebra.diagonalEmbedding(aSvd.second.map {1.0 / it})
val aInverse = aSvd.third.dot(s).dot(aSvd.first.transposed())
return aInverse.dot(b).as2D()
}
}
public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra get() = DoubleTensorAlgebra