Move matrix solving and inverting to extensions because of consistency

This commit is contained in:
Iaroslav Postovalov 2020-09-12 09:23:47 +07:00
parent edd3022aac
commit d088fdf77c
No known key found for this signature in database
GPG Key ID: 70D5F4DCB0972F1B

View File

@ -11,34 +11,6 @@ import scientifik.kmath.structures.Matrix
* Represents context of basic operations operating with [EjmlMatrix]. * Represents context of basic operations operating with [EjmlMatrix].
*/ */
class EjmlMatrixContext(private val space: Space<Double>) : MatrixContext<Double> { class EjmlMatrixContext(private val space: Space<Double>) : MatrixContext<Double> {
/**
* Solves for X in the following equation: x = a^-1*b, where 'a' is base matrix and 'b' is an n by p matrix.
*
* @param a the base matrix.
* @param b n by p matrix.
* @return the solution for 'x' that is n by p.
*/
fun solve(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin))
/**
* Solves for X in the following equation: x = a^(-1)*b, where 'a' is base matrix and 'b' is an n by p matrix.
*
* @param a the base matrix.
* @param b n by p vector.
* @return the solution for 'x' that is n by p.
*/
fun solve(a: Matrix<Double>, b: Point<Double>): EjmlVector =
EjmlVector(a.toEjml().origin.solve(b.toEjml().origin))
/**
* Returns the inverse of given matrix: b = a^(-1).
*
* @param a the matrix.
* @return the inverse of this matrix.
*/
fun inverse(a: Matrix<Double>): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert())
/** /**
* Converts this matrix to EJML one. * Converts this matrix to EJML one.
*/ */
@ -53,17 +25,6 @@ class EjmlMatrixContext(private val space: Space<Double>) : MatrixContext<Double
(0 until it.numRows()).forEach { row -> it[row, 0] = get(row) } (0 until it.numRows()).forEach { row -> it[row, 0] = get(row) }
}) })
override fun unaryOperation(operation: String, arg: Matrix<Double>): Matrix<Double> = when (operation) {
"inverse" -> inverse(arg)
else -> super.unaryOperation(operation, arg)
}
override fun binaryOperation(operation: String, left: Matrix<Double>, right: Matrix<Double>): Matrix<Double> =
when (operation) {
"solve" -> solve(left, right)
else -> super.binaryOperation(operation, left, right)
}
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): EjmlMatrix = override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): EjmlMatrix =
EjmlMatrix(SimpleMatrix(rows, columns).also { EjmlMatrix(SimpleMatrix(rows, columns).also {
(0 until it.numRows()).forEach { row -> (0 until it.numRows()).forEach { row ->
@ -90,3 +51,31 @@ class EjmlMatrixContext(private val space: Space<Double>) : MatrixContext<Double
companion object companion object
} }
/**
* Solves for X in the following equation: x = a^-1*b, where 'a' is base matrix and 'b' is an n by p matrix.
*
* @param a the base matrix.
* @param b n by p matrix.
* @return the solution for 'x' that is n by p.
*/
fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin))
/**
* Solves for X in the following equation: x = a^(-1)*b, where 'a' is base matrix and 'b' is an n by p matrix.
*
* @param a the base matrix.
* @param b n by p vector.
* @return the solution for 'x' that is n by p.
*/
fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Point<Double>): EjmlVector =
EjmlVector(a.toEjml().origin.solve(b.toEjml().origin))
/**
* Returns the inverse of given matrix: b = a^(-1).
*
* @param a the matrix.
* @return the inverse of this matrix.
*/
fun EjmlMatrixContext.inverse(a: Matrix<Double>): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert())