Merge pull request #136 from mipt-npm/ejml
Drop koma module, implement kmath-ejml module copying it, but for EJML SimpleMatrix
This commit is contained in:
commit
0e6448cd3e
@ -7,6 +7,7 @@
|
|||||||
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140).
|
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140).
|
||||||
- Automatic README generation for features (#139)
|
- Automatic README generation for features (#139)
|
||||||
- Native support for `memory`, `core` and `dimensions`
|
- Native support for `memory`, `core` and `dimensions`
|
||||||
|
- `kmath-ejml` to supply EJML SimpleMatrix wrapper.
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
- Package changed from `scientifik` to `kscience.kmath`.
|
- Package changed from `scientifik` to `kscience.kmath`.
|
||||||
|
@ -53,7 +53,9 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
|
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
|
||||||
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
||||||
to submit a feature request if you want something to be done first.
|
to submit a feature request if you want something to be done first.
|
||||||
|
|
||||||
|
* **EJML wrapper** Provides EJML `SimpleMatrix` wrapper consistent with the core matrix structures.
|
||||||
|
|
||||||
## Planned features
|
## Planned features
|
||||||
|
|
||||||
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
|
* **Messaging** A mathematical notation to support multi-language and multi-node communication for mathematical tasks.
|
||||||
|
@ -26,6 +26,7 @@ dependencies {
|
|||||||
implementation(project(":kmath-prob"))
|
implementation(project(":kmath-prob"))
|
||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
implementation(project(":kmath-dimensions"))
|
implementation(project(":kmath-dimensions"))
|
||||||
|
implementation(project(":kmath-ejml"))
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11")
|
implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11")
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
|
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
|
||||||
implementation("org.slf4j:slf4j-simple:1.7.30")
|
implementation("org.slf4j:slf4j-simple:1.7.30")
|
||||||
|
@ -0,0 +1,50 @@
|
|||||||
|
package kscience.kmath.linear
|
||||||
|
|
||||||
|
import kscience.kmath.commons.linear.CMMatrixContext
|
||||||
|
import kscience.kmath.commons.linear.inverse
|
||||||
|
import kscience.kmath.commons.linear.toCM
|
||||||
|
import kscience.kmath.ejml.EjmlMatrixContext
|
||||||
|
import kscience.kmath.ejml.inverse
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kscience.kmath.operations.invoke
|
||||||
|
import kscience.kmath.structures.Matrix
|
||||||
|
import kotlin.random.Random
|
||||||
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
val random = Random(1224)
|
||||||
|
val dim = 100
|
||||||
|
//creating invertible matrix
|
||||||
|
val u = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||||
|
val l = Matrix.real(dim, dim) { i, j -> if (i >= j) random.nextDouble() else 0.0 }
|
||||||
|
val matrix = l dot u
|
||||||
|
|
||||||
|
val n = 5000 // iterations
|
||||||
|
|
||||||
|
MatrixContext.real {
|
||||||
|
repeat(50) { inverse(matrix) }
|
||||||
|
val inverseTime = measureTimeMillis { repeat(n) { inverse(matrix) } }
|
||||||
|
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
|
||||||
|
}
|
||||||
|
|
||||||
|
//commons-math
|
||||||
|
|
||||||
|
val commonsTime = measureTimeMillis {
|
||||||
|
CMMatrixContext {
|
||||||
|
val cm = matrix.toCM() //avoid overhead on conversion
|
||||||
|
repeat(n) { inverse(cm) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
println("[commons-math] Inversion of $n matrices $dim x $dim finished in $commonsTime millis")
|
||||||
|
|
||||||
|
val ejmlTime = measureTimeMillis {
|
||||||
|
(EjmlMatrixContext(RealField)) {
|
||||||
|
val km = matrix.toEjml() //avoid overhead on conversion
|
||||||
|
repeat(n) { inverse(km) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println("[ejml] Inversion of $n matrices $dim x $dim finished in $ejmlTime millis")
|
||||||
|
}
|
@ -0,0 +1,38 @@
|
|||||||
|
package kscience.kmath.linear
|
||||||
|
|
||||||
|
import kscience.kmath.commons.linear.CMMatrixContext
|
||||||
|
import kscience.kmath.commons.linear.toCM
|
||||||
|
import kscience.kmath.ejml.EjmlMatrixContext
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kscience.kmath.operations.invoke
|
||||||
|
import kscience.kmath.structures.Matrix
|
||||||
|
import kotlin.random.Random
|
||||||
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
val random = Random(12224)
|
||||||
|
val dim = 1000
|
||||||
|
//creating invertible matrix
|
||||||
|
val matrix1 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||||
|
val matrix2 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 }
|
||||||
|
|
||||||
|
// //warmup
|
||||||
|
// matrix1 dot matrix2
|
||||||
|
|
||||||
|
CMMatrixContext {
|
||||||
|
val cmMatrix1 = matrix1.toCM()
|
||||||
|
val cmMatrix2 = matrix2.toCM()
|
||||||
|
val cmTime = measureTimeMillis { cmMatrix1 dot cmMatrix2 }
|
||||||
|
println("CM implementation time: $cmTime")
|
||||||
|
}
|
||||||
|
|
||||||
|
(EjmlMatrixContext(RealField)) {
|
||||||
|
val ejmlMatrix1 = matrix1.toEjml()
|
||||||
|
val ejmlMatrix2 = matrix2.toEjml()
|
||||||
|
val ejmlTime = measureTimeMillis { ejmlMatrix1 dot ejmlMatrix2 }
|
||||||
|
println("EJML implementation time: $ejmlTime")
|
||||||
|
}
|
||||||
|
|
||||||
|
val genericTime = measureTimeMillis { val res = matrix1 dot matrix2 }
|
||||||
|
println("Generic implementation time: $genericTime")
|
||||||
|
}
|
@ -18,20 +18,52 @@ public interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
|
|||||||
*/
|
*/
|
||||||
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
||||||
|
|
||||||
|
public override fun binaryOperation(operation: String, left: Matrix<T>, right: Matrix<T>): Matrix<T> = when (operation) {
|
||||||
|
"dot" -> left dot right
|
||||||
|
else -> super.binaryOperation(operation, left, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the dot product of this matrix and another one.
|
||||||
|
*
|
||||||
|
* @receiver the multiplicand.
|
||||||
|
* @param other the multiplier.
|
||||||
|
* @return the dot product.
|
||||||
|
*/
|
||||||
public infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T>
|
public infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Computes the dot product of this matrix and a vector.
|
||||||
|
*
|
||||||
|
* @receiver the multiplicand.
|
||||||
|
* @param vector the multiplier.
|
||||||
|
* @return the dot product.
|
||||||
|
*/
|
||||||
public infix fun Matrix<T>.dot(vector: Point<T>): Point<T>
|
public infix fun Matrix<T>.dot(vector: Point<T>): Point<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplies a matrix by its element.
|
||||||
|
*
|
||||||
|
* @receiver the multiplicand.
|
||||||
|
* @param value the multiplier.
|
||||||
|
* @receiver the product.
|
||||||
|
*/
|
||||||
public operator fun Matrix<T>.times(value: T): Matrix<T>
|
public operator fun Matrix<T>.times(value: T): Matrix<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplies an element by a matrix of it.
|
||||||
|
*
|
||||||
|
* @receiver the multiplicand.
|
||||||
|
* @param value the multiplier.
|
||||||
|
* @receiver the product.
|
||||||
|
*/
|
||||||
public operator fun T.times(m: Matrix<T>): Matrix<T> = m * this
|
public operator fun T.times(m: Matrix<T>): Matrix<T> = m * this
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* Non-boxing double matrix
|
* Non-boxing double matrix
|
||||||
*/
|
*/
|
||||||
public val real: RealMatrixContext
|
public val real: RealMatrixContext = RealMatrixContext
|
||||||
get() = RealMatrixContext
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A structured matrix with custom buffer
|
* A structured matrix with custom buffer
|
||||||
@ -60,7 +92,7 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
*/
|
*/
|
||||||
public fun point(size: Int, initializer: (Int) -> T): Point<T>
|
public fun point(size: Int, initializer: (Int) -> T): Point<T>
|
||||||
|
|
||||||
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
public override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
||||||
//TODO add typed error
|
//TODO add typed error
|
||||||
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||||
|
|
||||||
@ -71,7 +103,7 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
public override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
||||||
//TODO add typed error
|
//TODO add typed error
|
||||||
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
|
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
|
||||||
|
|
||||||
@ -81,10 +113,10 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Matrix<T>.unaryMinus(): Matrix<T> =
|
public override operator fun Matrix<T>.unaryMinus(): Matrix<T> =
|
||||||
produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } }
|
produce(rowNum, colNum) { i, j -> elementContext { -get(i, j) } }
|
||||||
|
|
||||||
override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
public override fun add(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||||
require(a.rowNum == b.rowNum && a.colNum == b.colNum) {
|
require(a.rowNum == b.rowNum && a.colNum == b.colNum) {
|
||||||
"Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]"
|
"Matrix operation dimension mismatch. [${a.rowNum},${a.colNum}] + [${b.rowNum},${b.colNum}]"
|
||||||
}
|
}
|
||||||
@ -92,7 +124,7 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } }
|
return produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] + b[i, j] } }
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
public override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
||||||
require(rowNum == b.rowNum && colNum == b.colNum) {
|
require(rowNum == b.rowNum && colNum == b.colNum) {
|
||||||
"Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]"
|
"Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]"
|
||||||
}
|
}
|
||||||
@ -100,11 +132,11 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
|||||||
return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) + b[i, j] } }
|
return produce(rowNum, colNum) { i, j -> elementContext { get(i, j) + b[i, j] } }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
public override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
||||||
produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }
|
produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }
|
||||||
|
|
||||||
public operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
|
public operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
|
||||||
|
|
||||||
override operator fun Matrix<T>.times(value: T): Matrix<T> =
|
public override operator fun Matrix<T>.times(value: T): Matrix<T> =
|
||||||
produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }
|
produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }
|
||||||
}
|
}
|
||||||
|
8
kmath-ejml/build.gradle.kts
Normal file
8
kmath-ejml/build.gradle.kts
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
plugins {
|
||||||
|
id("ru.mipt.npm.jvm")
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation("org.ejml:ejml-simple:0.39")
|
||||||
|
implementation(project(":kmath-core"))
|
||||||
|
}
|
71
kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt
Normal file
71
kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package kscience.kmath.ejml
|
||||||
|
|
||||||
|
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import kscience.kmath.linear.DeterminantFeature
|
||||||
|
import kscience.kmath.linear.FeaturedMatrix
|
||||||
|
import kscience.kmath.linear.LUPDecompositionFeature
|
||||||
|
import kscience.kmath.linear.MatrixFeature
|
||||||
|
import kscience.kmath.structures.NDStructure
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents featured matrix over EJML [SimpleMatrix].
|
||||||
|
*
|
||||||
|
* @property origin the underlying [SimpleMatrix].
|
||||||
|
* @author Iaroslav Postovalov
|
||||||
|
*/
|
||||||
|
public class EjmlMatrix(public val origin: SimpleMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
|
||||||
|
public override val rowNum: Int
|
||||||
|
get() = origin.numRows()
|
||||||
|
|
||||||
|
public override val colNum: Int
|
||||||
|
get() = origin.numCols()
|
||||||
|
|
||||||
|
public override val shape: IntArray
|
||||||
|
get() = intArrayOf(origin.numRows(), origin.numCols())
|
||||||
|
|
||||||
|
public override val features: Set<MatrixFeature> = setOf(
|
||||||
|
object : LUPDecompositionFeature<Double>, DeterminantFeature<Double> {
|
||||||
|
override val determinant: Double
|
||||||
|
get() = origin.determinant()
|
||||||
|
|
||||||
|
private val lup by lazy {
|
||||||
|
val ludecompositionF64 = DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols())
|
||||||
|
.also { it.decompose(origin.ddrm.copy()) }
|
||||||
|
|
||||||
|
Triple(
|
||||||
|
EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))),
|
||||||
|
EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))),
|
||||||
|
EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override val l: FeaturedMatrix<Double>
|
||||||
|
get() = lup.second
|
||||||
|
|
||||||
|
override val u: FeaturedMatrix<Double>
|
||||||
|
get() = lup.third
|
||||||
|
|
||||||
|
override val p: FeaturedMatrix<Double>
|
||||||
|
get() = lup.first
|
||||||
|
}
|
||||||
|
) union features.orEmpty()
|
||||||
|
|
||||||
|
public override fun suggestFeature(vararg features: MatrixFeature): EjmlMatrix =
|
||||||
|
EjmlMatrix(origin, this.features + features)
|
||||||
|
|
||||||
|
public override operator fun get(i: Int, j: Int): Double = origin[i, j]
|
||||||
|
|
||||||
|
public override fun equals(other: Any?): Boolean {
|
||||||
|
if (other is EjmlMatrix) return origin.isIdentical(other.origin, 0.0)
|
||||||
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun hashCode(): Int {
|
||||||
|
var result = origin.hashCode()
|
||||||
|
result = 31 * result + features.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun toString(): String = "EjmlMatrix(origin=$origin, features=$features)"
|
||||||
|
}
|
@ -0,0 +1,86 @@
|
|||||||
|
package kscience.kmath.ejml
|
||||||
|
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import kscience.kmath.linear.MatrixContext
|
||||||
|
import kscience.kmath.linear.Point
|
||||||
|
import kscience.kmath.operations.Space
|
||||||
|
import kscience.kmath.operations.invoke
|
||||||
|
import kscience.kmath.structures.Matrix
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents context of basic operations operating with [EjmlMatrix].
|
||||||
|
*
|
||||||
|
* @author Iaroslav Postovalov
|
||||||
|
*/
|
||||||
|
public class EjmlMatrixContext(private val space: Space<Double>) : MatrixContext<Double> {
|
||||||
|
/**
|
||||||
|
* Converts this matrix to EJML one.
|
||||||
|
*/
|
||||||
|
public fun Matrix<Double>.toEjml(): EjmlMatrix =
|
||||||
|
if (this is EjmlMatrix) this else produce(rowNum, colNum) { i, j -> get(i, j) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts this vector to EJML one.
|
||||||
|
*/
|
||||||
|
public fun Point<Double>.toEjml(): EjmlVector =
|
||||||
|
if (this is EjmlVector) this else EjmlVector(SimpleMatrix(size, 1).also {
|
||||||
|
(0 until it.numRows()).forEach { row -> it[row, 0] = get(row) }
|
||||||
|
})
|
||||||
|
|
||||||
|
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): EjmlMatrix =
|
||||||
|
EjmlMatrix(SimpleMatrix(rows, columns).also {
|
||||||
|
(0 until it.numRows()).forEach { row ->
|
||||||
|
(0 until it.numCols()).forEach { col -> it[row, col] = initializer(row, col) }
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
public override fun Matrix<Double>.dot(other: Matrix<Double>): EjmlMatrix =
|
||||||
|
EjmlMatrix(toEjml().origin.mult(other.toEjml().origin))
|
||||||
|
|
||||||
|
public override fun Matrix<Double>.dot(vector: Point<Double>): EjmlVector =
|
||||||
|
EjmlVector(toEjml().origin.mult(vector.toEjml().origin))
|
||||||
|
|
||||||
|
public override fun add(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
|
||||||
|
EjmlMatrix(a.toEjml().origin + b.toEjml().origin)
|
||||||
|
|
||||||
|
public override operator fun Matrix<Double>.minus(b: Matrix<Double>): EjmlMatrix =
|
||||||
|
EjmlMatrix(toEjml().origin - b.toEjml().origin)
|
||||||
|
|
||||||
|
public override fun multiply(a: Matrix<Double>, k: Number): EjmlMatrix =
|
||||||
|
produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } }
|
||||||
|
|
||||||
|
public override operator fun Matrix<Double>.times(value: Double): EjmlMatrix = EjmlMatrix(toEjml().origin.scale(value))
|
||||||
|
|
||||||
|
public companion object
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Solves for X in the following equation: x = a^-1*b, where 'a' is base matrix and 'b' is an n by p matrix.
|
||||||
|
*
|
||||||
|
* @param a the base matrix.
|
||||||
|
* @param b n by p matrix.
|
||||||
|
* @return the solution for 'x' that is n by p.
|
||||||
|
* @author Iaroslav Postovalov
|
||||||
|
*/
|
||||||
|
public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
|
||||||
|
EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Solves for X in the following equation: x = a^(-1)*b, where 'a' is base matrix and 'b' is an n by p matrix.
|
||||||
|
*
|
||||||
|
* @param a the base matrix.
|
||||||
|
* @param b n by p vector.
|
||||||
|
* @return the solution for 'x' that is n by p.
|
||||||
|
* @author Iaroslav Postovalov
|
||||||
|
*/
|
||||||
|
public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Point<Double>): EjmlVector =
|
||||||
|
EjmlVector(a.toEjml().origin.solve(b.toEjml().origin))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the inverse of given matrix: b = a^(-1).
|
||||||
|
*
|
||||||
|
* @param a the matrix.
|
||||||
|
* @return the inverse of this matrix.
|
||||||
|
* @author Iaroslav Postovalov
|
||||||
|
*/
|
||||||
|
public fun EjmlMatrixContext.inverse(a: Matrix<Double>): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert())
|
40
kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlVector.kt
Normal file
40
kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlVector.kt
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package kscience.kmath.ejml
|
||||||
|
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import kscience.kmath.linear.Point
|
||||||
|
import kscience.kmath.structures.Buffer
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents point over EJML [SimpleMatrix].
|
||||||
|
*
|
||||||
|
* @property origin the underlying [SimpleMatrix].
|
||||||
|
* @author Iaroslav Postovalov
|
||||||
|
*/
|
||||||
|
public class EjmlVector internal constructor(public val origin: SimpleMatrix) : Point<Double> {
|
||||||
|
public override val size: Int
|
||||||
|
get() = origin.numRows()
|
||||||
|
|
||||||
|
init {
|
||||||
|
require(origin.numCols() == 1) { "Only single column matrices are allowed" }
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun get(index: Int): Double = origin[index]
|
||||||
|
|
||||||
|
public override operator fun iterator(): Iterator<Double> = object : Iterator<Double> {
|
||||||
|
private var cursor: Int = 0
|
||||||
|
|
||||||
|
override fun next(): Double {
|
||||||
|
cursor += 1
|
||||||
|
return origin[cursor - 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hasNext(): Boolean = cursor < origin.numCols() * origin.numRows()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun contentEquals(other: Buffer<*>): Boolean {
|
||||||
|
if (other is EjmlVector) return origin.isIdentical(other.origin, 0.0)
|
||||||
|
return super.contentEquals(other)
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun toString(): String = "EjmlVector(origin=$origin)"
|
||||||
|
}
|
@ -0,0 +1,75 @@
|
|||||||
|
package kscience.kmath.ejml
|
||||||
|
|
||||||
|
import kscience.kmath.linear.DeterminantFeature
|
||||||
|
import kscience.kmath.linear.LUPDecompositionFeature
|
||||||
|
import kscience.kmath.linear.MatrixFeature
|
||||||
|
import kscience.kmath.linear.getFeature
|
||||||
|
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import kotlin.random.Random
|
||||||
|
import kotlin.random.asJavaRandom
|
||||||
|
import kotlin.test.*
|
||||||
|
|
||||||
|
internal class EjmlMatrixTest {
|
||||||
|
private val random = Random(0)
|
||||||
|
|
||||||
|
private val randomMatrix: SimpleMatrix
|
||||||
|
get() {
|
||||||
|
val s = random.nextInt(2, 100)
|
||||||
|
return SimpleMatrix.random_DDRM(s, s, 0.0, 10.0, random.asJavaRandom())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun rowNum() {
|
||||||
|
val m = randomMatrix
|
||||||
|
assertEquals(m.numRows(), EjmlMatrix(m).rowNum)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun colNum() {
|
||||||
|
val m = randomMatrix
|
||||||
|
assertEquals(m.numCols(), EjmlMatrix(m).rowNum)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun shape() {
|
||||||
|
val m = randomMatrix
|
||||||
|
val w = EjmlMatrix(m)
|
||||||
|
assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun features() {
|
||||||
|
val m = randomMatrix
|
||||||
|
val w = EjmlMatrix(m)
|
||||||
|
val det = w.getFeature<DeterminantFeature<Double>>() ?: fail()
|
||||||
|
assertEquals(m.determinant(), det.determinant)
|
||||||
|
val lup = w.getFeature<LUPDecompositionFeature<Double>>() ?: fail()
|
||||||
|
|
||||||
|
val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows(), m.numCols())
|
||||||
|
.also { it.decompose(m.ddrm.copy()) }
|
||||||
|
|
||||||
|
assertEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))), lup.l)
|
||||||
|
assertEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))), lup.u)
|
||||||
|
assertEquals(EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))), lup.p)
|
||||||
|
}
|
||||||
|
|
||||||
|
private object SomeFeature : MatrixFeature {}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun suggestFeature() {
|
||||||
|
assertNotNull(EjmlMatrix(randomMatrix).suggestFeature(SomeFeature).getFeature<SomeFeature>())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun get() {
|
||||||
|
val m = randomMatrix
|
||||||
|
assertEquals(m[0, 0], EjmlMatrix(m)[0, 0])
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun origin() {
|
||||||
|
val m = randomMatrix
|
||||||
|
assertSame(m, EjmlMatrix(m).origin)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,47 @@
|
|||||||
|
package kscience.kmath.ejml
|
||||||
|
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import kotlin.random.Random
|
||||||
|
import kotlin.random.asJavaRandom
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.assertSame
|
||||||
|
|
||||||
|
internal class EjmlVectorTest {
|
||||||
|
private val random = Random(0)
|
||||||
|
|
||||||
|
private val randomMatrix: SimpleMatrix
|
||||||
|
get() = SimpleMatrix.random_DDRM(random.nextInt(2, 100), 1, 0.0, 10.0, random.asJavaRandom())
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun size() {
|
||||||
|
val m = randomMatrix
|
||||||
|
val w = EjmlVector(m)
|
||||||
|
assertEquals(m.numRows(), w.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun get() {
|
||||||
|
val m = randomMatrix
|
||||||
|
val w = EjmlVector(m)
|
||||||
|
assertEquals(m[0, 0], w[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun iterator() {
|
||||||
|
val m = randomMatrix
|
||||||
|
val w = EjmlVector(m)
|
||||||
|
|
||||||
|
assertEquals(
|
||||||
|
m.iterator(true, 0, 0, m.numRows() - 1, 0).asSequence().toList(),
|
||||||
|
w.iterator().asSequence().toList()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun origin() {
|
||||||
|
val m = randomMatrix
|
||||||
|
val w = EjmlVector(m)
|
||||||
|
assertSame(m, w.origin)
|
||||||
|
}
|
||||||
|
}
|
@ -39,5 +39,6 @@ include(
|
|||||||
":kmath-for-real",
|
":kmath-for-real",
|
||||||
":kmath-geometry",
|
":kmath-geometry",
|
||||||
":kmath-ast",
|
":kmath-ast",
|
||||||
":examples"
|
":examples",
|
||||||
|
":kmath-ejml"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user