Drop koma module, implement kmath-ejml module copying it, but for EJML SimpleMatrix

This commit is contained in:
Iaroslav Postovalov 2020-09-09 23:42:43 +07:00
parent bea2f2f46d
commit d1184802bd
No known key found for this signature in database
GPG Key ID: 70D5F4DCB0972F1B
12 changed files with 203 additions and 185 deletions

View File

@ -53,9 +53,6 @@ can be used for a wide variety of purposes from high performance calculations to
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
to submit a feature request if you want something to be done first.
* **Koma wrapper** [Koma](https://github.com/kyonifer/koma) is a well established numerics library in Kotlin, specifically linear algebra.
The plan is to have wrappers for koma implementations for compatibility with kmath API.
## Planned features

View File

@ -12,6 +12,3 @@ api and multiple library back-ends.
* [Expressions](./expressions.md)
* Commons math integration
* Koma integration

View File

@ -29,10 +29,9 @@ dependencies {
implementation(project(":kmath-coroutines"))
implementation(project(":kmath-commons"))
implementation(project(":kmath-prob"))
implementation(project(":kmath-koma"))
implementation(project(":kmath-viktor"))
implementation(project(":kmath-dimensions"))
implementation("com.kyonifer:koma-core-ejml:0.12")
implementation(project(":kmath-ejml"))
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8")
"benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath

View File

@ -1,9 +1,10 @@
package scientifik.kmath.linear
import koma.matrix.ejml.EJMLMatrixFactory
import scientifik.kmath.commons.linear.CMMatrixContext
import scientifik.kmath.commons.linear.inverse
import scientifik.kmath.commons.linear.toCM
import scientifik.kmath.ejml.EjmlMatrixContext
import scientifik.kmath.ejml.inverse
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix
@ -23,8 +24,8 @@ fun main() {
val n = 5000 // iterations
MatrixContext.real {
repeat(50) { val res = inverse(matrix) }
val inverseTime = measureTimeMillis { repeat(n) { val res = inverse(matrix) } }
repeat(50) { inverse(matrix) }
val inverseTime = measureTimeMillis { repeat(n) { inverse(matrix) } }
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
}
@ -33,23 +34,19 @@ fun main() {
val commonsTime = measureTimeMillis {
CMMatrixContext {
val cm = matrix.toCM() //avoid overhead on conversion
repeat(n) { val res = inverse(cm) }
repeat(n) { inverse(cm) }
}
}
println("[commons-math] Inversion of $n matrices $dim x $dim finished in $commonsTime millis")
//koma-ejml
val komaTime = measureTimeMillis {
(KomaMatrixContext(EJMLMatrixFactory(), RealField)) {
val km = matrix.toKoma() //avoid overhead on conversion
repeat(n) {
val res = inverse(km)
}
val ejmlTime = measureTimeMillis {
(EjmlMatrixContext(RealField)) {
val km = matrix.toEjml() //avoid overhead on conversion
repeat(n) { inverse(km) }
}
}
println("[koma-ejml] Inversion of $n matrices $dim x $dim finished in $komaTime millis")
}
println("[ejml] Inversion of $n matrices $dim x $dim finished in $ejmlTime millis")
}

View File

@ -1,8 +1,8 @@
package scientifik.kmath.linear
import koma.matrix.ejml.EJMLMatrixFactory
import scientifik.kmath.commons.linear.CMMatrixContext
import scientifik.kmath.commons.linear.toCM
import scientifik.kmath.ejml.EjmlMatrixContext
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix
@ -22,28 +22,17 @@ fun main() {
CMMatrixContext {
val cmMatrix1 = matrix1.toCM()
val cmMatrix2 = matrix2.toCM()
val cmTime = measureTimeMillis {
cmMatrix1 dot cmMatrix2
}
val cmTime = measureTimeMillis { cmMatrix1 dot cmMatrix2 }
println("CM implementation time: $cmTime")
}
(KomaMatrixContext(EJMLMatrixFactory(), RealField)) {
val komaMatrix1 = matrix1.toKoma()
val komaMatrix2 = matrix2.toKoma()
val komaTime = measureTimeMillis {
komaMatrix1 dot komaMatrix2
}
println("Koma-ejml implementation time: $komaTime")
}
val genericTime = measureTimeMillis {
val res = matrix1 dot matrix2
(EjmlMatrixContext(RealField)) {
val ejmlMatrix1 = matrix1.toEjml()
val ejmlMatrix2 = matrix2.toEjml()
val ejmlTime = measureTimeMillis { ejmlMatrix1 dot ejmlMatrix2 }
println("EJML implementation time: $ejmlTime")
}
val genericTime = measureTimeMillis { val res = matrix1 dot matrix2 }
println("Generic implementation time: $genericTime")
}
}

View File

@ -0,0 +1,6 @@
plugins { id("scientifik.jvm") }
dependencies {
implementation("org.ejml:ejml-simple:0.39")
implementation(project(":kmath-core"))
}

View File

@ -0,0 +1,69 @@
package scientifik.kmath.ejml
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
import org.ejml.simple.SimpleMatrix
import scientifik.kmath.linear.DeterminantFeature
import scientifik.kmath.linear.FeaturedMatrix
import scientifik.kmath.linear.LUPDecompositionFeature
import scientifik.kmath.linear.MatrixFeature
import scientifik.kmath.structures.NDStructure
/**
* Represents featured matrix over EJML [SimpleMatrix].
*
* @property origin the underlying [SimpleMatrix].
*/
class EjmlMatrix(val origin: SimpleMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
override val rowNum: Int
get() = origin.numRows()
override val colNum: Int
get() = origin.numCols()
override val shape: IntArray
get() = intArrayOf(origin.numRows(), origin.numCols())
override val features: Set<MatrixFeature> = features ?: hashSetOf(
object : DeterminantFeature<Double> {
override val determinant: Double
get() = origin.determinant()
},
object : LUPDecompositionFeature<Double> {
private val lup by lazy {
val ludecompositionF64 = DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols())
.also { it.decompose(origin.ddrm.copy()) }
Triple(
EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))),
EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))),
EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null)))
)
}
override val l: FeaturedMatrix<Double>
get() = lup.second
override val u: FeaturedMatrix<Double>
get() = lup.third
override val p: FeaturedMatrix<Double>
get() = lup.first
}
)
override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<Double> =
EjmlMatrix(origin, this.features + features)
override operator fun get(i: Int, j: Int): Double = origin[i, j]
override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
}
override fun hashCode(): Int {
var result = origin.hashCode()
result = 31 * result + features.hashCode()
return result
}
}

View File

@ -0,0 +1,75 @@
package scientifik.kmath.ejml
import org.ejml.simple.SimpleMatrix
import scientifik.kmath.linear.MatrixContext
import scientifik.kmath.linear.Point
import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix
/**
* Represents context of basic operations operating with [EjmlMatrix].
*/
class EjmlMatrixContext(private val space: Space<Double>) : MatrixContext<Double> {
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): EjmlMatrix =
EjmlMatrix(SimpleMatrix(rows, columns).also {
(0 until it.numRows()).forEach { row ->
(0 until it.numCols()).forEach { col -> it[row, col] = initializer(row, col) }
}
})
fun Matrix<Double>.toEjml(): EjmlMatrix =
if (this is EjmlMatrix) this else produce(rowNum, colNum) { i, j -> get(i, j) }
fun Point<Double>.toEjml(): EjmlVector =
if (this is EjmlVector) this else EjmlVector(SimpleMatrix(size, 1).also {
(0 until it.numRows()).forEach { row -> it[row, 0] = get(row) }
})
override fun Matrix<Double>.dot(other: Matrix<Double>): EjmlMatrix =
EjmlMatrix(toEjml().origin.mult(other.toEjml().origin))
override fun Matrix<Double>.dot(vector: Point<Double>): EjmlVector =
EjmlVector(toEjml().origin.mult(vector.toEjml().origin))
override fun add(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
EjmlMatrix(a.toEjml().origin + b.toEjml().origin)
override operator fun Matrix<Double>.minus(b: Matrix<Double>): EjmlMatrix =
EjmlMatrix(toEjml().origin - b.toEjml().origin)
override fun multiply(a: Matrix<Double>, k: Number): Matrix<Double> =
produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } }
override operator fun Matrix<Double>.times(value: Double): EjmlMatrix = EjmlMatrix(toEjml().origin.scale(value))
companion object
}
/**
* Solves for X in the following equation: x = a^-1*b, where 'a' is base matrix and 'b' is an n by p matrix.
*
* @param a the base matrix.
* @param b n by p matrix.
* @return the solution for 'x' that is n by p.
*/
fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin))
/**
* Solves for X in the following equation: x = a^(-1)*b, where 'a' is base matrix and 'b' is an n by p matrix.
*
* @param a the base matrix.
* @param b n by p vector.
* @return the solution for 'x' that is n by p.
*/
fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Point<Double>): EjmlVector =
EjmlVector(a.toEjml().origin.solve(b.toEjml().origin))
/**
* Returns the inverse of given matrix: b = a^(-1).
*
* @param a the matrix.
* @return the inverse of this matrix.
*/
fun EjmlMatrixContext.inverse(a: Matrix<Double>): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert())

View File

@ -0,0 +1,30 @@
package scientifik.kmath.ejml
import org.ejml.simple.SimpleMatrix
import scientifik.kmath.linear.Point
/**
* Represents point over EJML [SimpleMatrix].
*
* @property origin the underlying [SimpleMatrix].
*/
class EjmlVector internal constructor(val origin: SimpleMatrix) : Point<Double> {
override val size: Int get() = origin.numRows()
init {
require(origin.numCols() == 1) { error("Only single column matrices are allowed") }
}
override operator fun get(index: Int): Double = origin[index]
override operator fun iterator(): Iterator<Double> = object : Iterator<Double> {
private var cursor: Int = 0
override fun next(): Double {
cursor += 1
return origin[cursor - 1]
}
override fun hasNext(): Boolean = cursor < origin.numCols() * origin.numRows()
}
}

View File

@ -1,31 +0,0 @@
plugins {
id("scientifik.mpp")
}
repositories {
maven("http://dl.bintray.com/kyonifer/maven")
}
kotlin.sourceSets {
commonMain {
dependencies {
api(project(":kmath-core"))
api("com.kyonifer:koma-core-api-common:0.12")
}
}
jvmMain {
dependencies {
api("com.kyonifer:koma-core-api-jvm:0.12")
}
}
jvmTest {
dependencies {
implementation("com.kyonifer:koma-core-ejml:0.12")
}
}
jsMain {
dependencies {
api("com.kyonifer:koma-core-api-js:0.12")
}
}
}

View File

@ -1,110 +0,0 @@
package scientifik.kmath.linear
import koma.extensions.fill
import koma.matrix.MatrixFactory
import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.NDStructure
class KomaMatrixContext<T : Any>(
private val factory: MatrixFactory<koma.matrix.Matrix<T>>,
private val space: Space<T>
) : MatrixContext<T> {
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): KomaMatrix<T> =
KomaMatrix(factory.zeros(rows, columns).fill(initializer))
fun Matrix<T>.toKoma(): KomaMatrix<T> = if (this is KomaMatrix) {
this
} else {
produce(rowNum, colNum) { i, j -> get(i, j) }
}
fun Point<T>.toKoma(): KomaVector<T> = if (this is KomaVector) {
this
} else {
KomaVector(factory.zeros(size, 1).fill { i, _ -> get(i) })
}
override fun Matrix<T>.dot(other: Matrix<T>): KomaMatrix<T> =
KomaMatrix(toKoma().origin * other.toKoma().origin)
override fun Matrix<T>.dot(vector: Point<T>): KomaVector<T> =
KomaVector(toKoma().origin * vector.toKoma().origin)
override operator fun Matrix<T>.unaryMinus(): KomaMatrix<T> =
KomaMatrix(toKoma().origin.unaryMinus())
override fun add(a: Matrix<T>, b: Matrix<T>): KomaMatrix<T> =
KomaMatrix(a.toKoma().origin + b.toKoma().origin)
override operator fun Matrix<T>.minus(b: Matrix<T>): KomaMatrix<T> =
KomaMatrix(toKoma().origin - b.toKoma().origin)
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } }
override operator fun Matrix<T>.times(value: T): KomaMatrix<T> =
KomaMatrix(toKoma().origin * value)
companion object
}
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Matrix<T>) =
KomaMatrix(a.toKoma().origin.solve(b.toKoma().origin))
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Point<T>) =
KomaVector(a.toKoma().origin.solve(b.toKoma().origin))
fun <T : Any> KomaMatrixContext<T>.inverse(a: Matrix<T>) =
KomaMatrix(a.toKoma().origin.inv())
class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<MatrixFeature>? = null) : FeaturedMatrix<T> {
override val rowNum: Int get() = origin.numRows()
override val colNum: Int get() = origin.numCols()
override val shape: IntArray get() = intArrayOf(origin.numRows(), origin.numCols())
override val features: Set<MatrixFeature> = features ?: hashSetOf(
object : DeterminantFeature<T> {
override val determinant: T get() = origin.det()
},
object : LUPDecompositionFeature<T> {
private val lup by lazy { origin.LU() }
override val l: FeaturedMatrix<T> get() = KomaMatrix(lup.second)
override val u: FeaturedMatrix<T> get() = KomaMatrix(lup.third)
override val p: FeaturedMatrix<T> get() = KomaMatrix(lup.first)
}
)
override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> =
KomaMatrix(this.origin, this.features + features)
override operator fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
}
override fun hashCode(): Int {
var result = origin.hashCode()
result = 31 * result + features.hashCode()
return result
}
}
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {
override val size: Int get() = origin.numRows()
init {
require(origin.numCols() == 1) { error("Only single column matrices are allowed") }
}
override operator fun get(index: Int): T = origin.getGeneric(index)
override operator fun iterator(): Iterator<T> = origin.toIterable().iterator()
}

View File

@ -40,12 +40,12 @@ include(
":kmath-histograms",
":kmath-commons",
":kmath-viktor",
":kmath-koma",
":kmath-prob",
":kmath-io",
":kmath-dimensions",
":kmath-for-real",
":kmath-geometry",
":kmath-ast",
":examples"
":examples",
":kmath-ejml"
)