forked from kscience/kmath
added function solve
This commit is contained in:
parent
a02085918a
commit
10f84bd630
@ -5,8 +5,15 @@
|
||||
|
||||
package space.kscience.kmath.tensors.api
|
||||
|
||||
import space.kscience.kmath.nd.MutableStructure2D
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.operations.Field
|
||||
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.dot
|
||||
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.map
|
||||
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra.transposed
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
|
||||
/**
|
||||
* Common linear algebra operations. Operates on [Tensor].
|
||||
@ -103,4 +110,11 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
|
||||
*/
|
||||
public fun symEig(structureND: StructureND<T>): Pair<StructureND<T>, StructureND<T>>
|
||||
|
||||
/** Returns the solution to the equation Ax = B for the square matrix A as `input1` and
|
||||
* for the square matrix B as `input2`.
|
||||
*
|
||||
* @receiver the `input1` and the `input2`.
|
||||
* @return the square matrix x which is the solution of the equation.
|
||||
*/
|
||||
public fun solve(a: MutableStructure2D<Double>, b: MutableStructure2D<Double>): MutableStructure2D<Double>
|
||||
}
|
||||
|
@ -711,6 +711,12 @@ public open class DoubleTensorAlgebra :
|
||||
override fun symEig(structureND: StructureND<Double>): Pair<DoubleTensor, DoubleTensor> =
|
||||
symEigJacobi(structureND = structureND, maxIteration = 50, epsilon = 1e-15)
|
||||
|
||||
override fun solve(a: MutableStructure2D<Double>, b: MutableStructure2D<Double>): MutableStructure2D<Double> {
|
||||
val aSvd = DoubleTensorAlgebra.svd(a)
|
||||
val s = BroadcastDoubleTensorAlgebra.diagonalEmbedding(aSvd.second.map {1.0 / it})
|
||||
val aInverse = aSvd.third.dot(s).dot(aSvd.first.transposed())
|
||||
return aInverse.dot(b).as2D()
|
||||
}
|
||||
}
|
||||
|
||||
public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra get() = DoubleTensorAlgebra
|
||||
|
Loading…
Reference in New Issue
Block a user