From edd3022aaccb49aed040547ada66c92dd2eeb0eb Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 10 Sep 2020 05:53:44 +0700 Subject: [PATCH] Add dynamic operations and add documentations --- .../scientifik/kmath/linear/MatrixContext.kt | 33 +++++++ .../kmath/ejml/EjmlMatrixContext.kt | 91 +++++++++++-------- 2 files changed, 87 insertions(+), 37 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt index 763bb1615..b7e79f1bc 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixContext.kt @@ -18,12 +18,45 @@ interface MatrixContext : SpaceOperations> { */ fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix + override fun binaryOperation(operation: String, left: Matrix, right: Matrix): Matrix = when (operation) { + "dot" -> left dot right + else -> super.binaryOperation(operation, left, right) + } + + /** + * Computes the dot product of this matrix and another one. + * + * @receiver the multiplicand. + * @param other the multiplier. + * @return the dot product. + */ infix fun Matrix.dot(other: Matrix): Matrix + /** + * Computes the dot product of this matrix and a vector. + * + * @receiver the multiplicand. + * @param vector the multiplier. + * @return the dot product. + */ infix fun Matrix.dot(vector: Point): Point + /** + * Multiplies a matrix by its element. + * + * @receiver the multiplicand. + * @param value the multiplier. + * @receiver the product. + */ operator fun Matrix.times(value: T): Matrix + /** + * Multiplies an element by a matrix of it. + * + * @receiver the multiplicand. + * @param value the multiplier. + * @receiver the product. + */ operator fun T.times(m: Matrix): Matrix = m * this companion object { diff --git a/kmath-ejml/src/main/kotlin/scientifik/kmath/ejml/EjmlMatrixContext.kt b/kmath-ejml/src/main/kotlin/scientifik/kmath/ejml/EjmlMatrixContext.kt index 142f1bee3..ee44cd686 100644 --- a/kmath-ejml/src/main/kotlin/scientifik/kmath/ejml/EjmlMatrixContext.kt +++ b/kmath-ejml/src/main/kotlin/scientifik/kmath/ejml/EjmlMatrixContext.kt @@ -11,6 +11,59 @@ import scientifik.kmath.structures.Matrix * Represents context of basic operations operating with [EjmlMatrix]. */ class EjmlMatrixContext(private val space: Space) : MatrixContext { + /** + * Solves for X in the following equation: x = a^-1*b, where 'a' is base matrix and 'b' is an n by p matrix. + * + * @param a the base matrix. + * @param b n by p matrix. + * @return the solution for 'x' that is n by p. + */ + fun solve(a: Matrix, b: Matrix): EjmlMatrix = + EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin)) + + /** + * Solves for X in the following equation: x = a^(-1)*b, where 'a' is base matrix and 'b' is an n by p matrix. + * + * @param a the base matrix. + * @param b n by p vector. + * @return the solution for 'x' that is n by p. + */ + fun solve(a: Matrix, b: Point): EjmlVector = + EjmlVector(a.toEjml().origin.solve(b.toEjml().origin)) + + /** + * Returns the inverse of given matrix: b = a^(-1). + * + * @param a the matrix. + * @return the inverse of this matrix. + */ + fun inverse(a: Matrix): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert()) + + /** + * Converts this matrix to EJML one. + */ + fun Matrix.toEjml(): EjmlMatrix = + if (this is EjmlMatrix) this else produce(rowNum, colNum) { i, j -> get(i, j) } + + /** + * Converts this vector to EJML one. + */ + fun Point.toEjml(): EjmlVector = + if (this is EjmlVector) this else EjmlVector(SimpleMatrix(size, 1).also { + (0 until it.numRows()).forEach { row -> it[row, 0] = get(row) } + }) + + override fun unaryOperation(operation: String, arg: Matrix): Matrix = when (operation) { + "inverse" -> inverse(arg) + else -> super.unaryOperation(operation, arg) + } + + override fun binaryOperation(operation: String, left: Matrix, right: Matrix): Matrix = + when (operation) { + "solve" -> solve(left, right) + else -> super.binaryOperation(operation, left, right) + } + override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): EjmlMatrix = EjmlMatrix(SimpleMatrix(rows, columns).also { (0 until it.numRows()).forEach { row -> @@ -18,14 +71,6 @@ class EjmlMatrixContext(private val space: Space) : MatrixContext.toEjml(): EjmlMatrix = - if (this is EjmlMatrix) this else produce(rowNum, colNum) { i, j -> get(i, j) } - - fun Point.toEjml(): EjmlVector = - if (this is EjmlVector) this else EjmlVector(SimpleMatrix(size, 1).also { - (0 until it.numRows()).forEach { row -> it[row, 0] = get(row) } - }) - override fun Matrix.dot(other: Matrix): EjmlMatrix = EjmlMatrix(toEjml().origin.mult(other.toEjml().origin)) @@ -38,38 +83,10 @@ class EjmlMatrixContext(private val space: Space) : MatrixContext.minus(b: Matrix): EjmlMatrix = EjmlMatrix(toEjml().origin - b.toEjml().origin) - override fun multiply(a: Matrix, k: Number): Matrix = + override fun multiply(a: Matrix, k: Number): EjmlMatrix = produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } } override operator fun Matrix.times(value: Double): EjmlMatrix = EjmlMatrix(toEjml().origin.scale(value)) companion object } - -/** - * Solves for X in the following equation: x = a^-1*b, where 'a' is base matrix and 'b' is an n by p matrix. - * - * @param a the base matrix. - * @param b n by p matrix. - * @return the solution for 'x' that is n by p. - */ -fun EjmlMatrixContext.solve(a: Matrix, b: Matrix): EjmlMatrix = - EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin)) - -/** - * Solves for X in the following equation: x = a^(-1)*b, where 'a' is base matrix and 'b' is an n by p matrix. - * - * @param a the base matrix. - * @param b n by p vector. - * @return the solution for 'x' that is n by p. - */ -fun EjmlMatrixContext.solve(a: Matrix, b: Point): EjmlVector = - EjmlVector(a.toEjml().origin.solve(b.toEjml().origin)) - -/** - * Returns the inverse of given matrix: b = a^(-1). - * - * @param a the matrix. - * @return the inverse of this matrix. - */ -fun EjmlMatrixContext.inverse(a: Matrix): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert())