Drop koma module, implement kmath-ejml module copying it, but for EJML SimpleMatrix
This commit is contained in:
parent
bea2f2f46d
commit
d1184802bd
@ -53,9 +53,6 @@ can be used for a wide variety of purposes from high performance calculations to
|
|||||||
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
|
* **Commons-math wrapper** It is planned to gradually wrap most parts of [Apache commons-math](http://commons.apache.org/proper/commons-math/)
|
||||||
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
library in Kotlin code and maybe rewrite some parts to better suit the Kotlin programming paradigm, however there is no fixed roadmap for that. Feel free
|
||||||
to submit a feature request if you want something to be done first.
|
to submit a feature request if you want something to be done first.
|
||||||
|
|
||||||
* **Koma wrapper** [Koma](https://github.com/kyonifer/koma) is a well established numerics library in Kotlin, specifically linear algebra.
|
|
||||||
The plan is to have wrappers for koma implementations for compatibility with kmath API.
|
|
||||||
|
|
||||||
## Planned features
|
## Planned features
|
||||||
|
|
||||||
|
@ -12,6 +12,3 @@ api and multiple library back-ends.
|
|||||||
* [Expressions](./expressions.md)
|
* [Expressions](./expressions.md)
|
||||||
|
|
||||||
* Commons math integration
|
* Commons math integration
|
||||||
|
|
||||||
* Koma integration
|
|
||||||
|
|
||||||
|
@ -29,10 +29,9 @@ dependencies {
|
|||||||
implementation(project(":kmath-coroutines"))
|
implementation(project(":kmath-coroutines"))
|
||||||
implementation(project(":kmath-commons"))
|
implementation(project(":kmath-commons"))
|
||||||
implementation(project(":kmath-prob"))
|
implementation(project(":kmath-prob"))
|
||||||
implementation(project(":kmath-koma"))
|
|
||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
implementation(project(":kmath-dimensions"))
|
implementation(project(":kmath-dimensions"))
|
||||||
implementation("com.kyonifer:koma-core-ejml:0.12")
|
implementation(project(":kmath-ejml"))
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
|
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8")
|
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-8")
|
||||||
"benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath
|
"benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import koma.matrix.ejml.EJMLMatrixFactory
|
|
||||||
import scientifik.kmath.commons.linear.CMMatrixContext
|
import scientifik.kmath.commons.linear.CMMatrixContext
|
||||||
import scientifik.kmath.commons.linear.inverse
|
import scientifik.kmath.commons.linear.inverse
|
||||||
import scientifik.kmath.commons.linear.toCM
|
import scientifik.kmath.commons.linear.toCM
|
||||||
|
import scientifik.kmath.ejml.EjmlMatrixContext
|
||||||
|
import scientifik.kmath.ejml.inverse
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.invoke
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
@ -23,8 +24,8 @@ fun main() {
|
|||||||
val n = 5000 // iterations
|
val n = 5000 // iterations
|
||||||
|
|
||||||
MatrixContext.real {
|
MatrixContext.real {
|
||||||
repeat(50) { val res = inverse(matrix) }
|
repeat(50) { inverse(matrix) }
|
||||||
val inverseTime = measureTimeMillis { repeat(n) { val res = inverse(matrix) } }
|
val inverseTime = measureTimeMillis { repeat(n) { inverse(matrix) } }
|
||||||
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
|
println("[kmath] Inversion of $n matrices $dim x $dim finished in $inverseTime millis")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -33,23 +34,19 @@ fun main() {
|
|||||||
val commonsTime = measureTimeMillis {
|
val commonsTime = measureTimeMillis {
|
||||||
CMMatrixContext {
|
CMMatrixContext {
|
||||||
val cm = matrix.toCM() //avoid overhead on conversion
|
val cm = matrix.toCM() //avoid overhead on conversion
|
||||||
repeat(n) { val res = inverse(cm) }
|
repeat(n) { inverse(cm) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
println("[commons-math] Inversion of $n matrices $dim x $dim finished in $commonsTime millis")
|
println("[commons-math] Inversion of $n matrices $dim x $dim finished in $commonsTime millis")
|
||||||
|
|
||||||
//koma-ejml
|
val ejmlTime = measureTimeMillis {
|
||||||
|
(EjmlMatrixContext(RealField)) {
|
||||||
val komaTime = measureTimeMillis {
|
val km = matrix.toEjml() //avoid overhead on conversion
|
||||||
(KomaMatrixContext(EJMLMatrixFactory(), RealField)) {
|
repeat(n) { inverse(km) }
|
||||||
val km = matrix.toKoma() //avoid overhead on conversion
|
|
||||||
repeat(n) {
|
|
||||||
val res = inverse(km)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println("[koma-ejml] Inversion of $n matrices $dim x $dim finished in $komaTime millis")
|
println("[ejml] Inversion of $n matrices $dim x $dim finished in $ejmlTime millis")
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import koma.matrix.ejml.EJMLMatrixFactory
|
|
||||||
import scientifik.kmath.commons.linear.CMMatrixContext
|
import scientifik.kmath.commons.linear.CMMatrixContext
|
||||||
import scientifik.kmath.commons.linear.toCM
|
import scientifik.kmath.commons.linear.toCM
|
||||||
|
import scientifik.kmath.ejml.EjmlMatrixContext
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.invoke
|
import scientifik.kmath.operations.invoke
|
||||||
import scientifik.kmath.structures.Matrix
|
import scientifik.kmath.structures.Matrix
|
||||||
@ -22,28 +22,17 @@ fun main() {
|
|||||||
CMMatrixContext {
|
CMMatrixContext {
|
||||||
val cmMatrix1 = matrix1.toCM()
|
val cmMatrix1 = matrix1.toCM()
|
||||||
val cmMatrix2 = matrix2.toCM()
|
val cmMatrix2 = matrix2.toCM()
|
||||||
|
val cmTime = measureTimeMillis { cmMatrix1 dot cmMatrix2 }
|
||||||
val cmTime = measureTimeMillis {
|
|
||||||
cmMatrix1 dot cmMatrix2
|
|
||||||
}
|
|
||||||
|
|
||||||
println("CM implementation time: $cmTime")
|
println("CM implementation time: $cmTime")
|
||||||
}
|
}
|
||||||
|
|
||||||
(KomaMatrixContext(EJMLMatrixFactory(), RealField)) {
|
(EjmlMatrixContext(RealField)) {
|
||||||
val komaMatrix1 = matrix1.toKoma()
|
val ejmlMatrix1 = matrix1.toEjml()
|
||||||
val komaMatrix2 = matrix2.toKoma()
|
val ejmlMatrix2 = matrix2.toEjml()
|
||||||
|
val ejmlTime = measureTimeMillis { ejmlMatrix1 dot ejmlMatrix2 }
|
||||||
val komaTime = measureTimeMillis {
|
println("EJML implementation time: $ejmlTime")
|
||||||
komaMatrix1 dot komaMatrix2
|
|
||||||
}
|
|
||||||
|
|
||||||
println("Koma-ejml implementation time: $komaTime")
|
|
||||||
}
|
|
||||||
|
|
||||||
val genericTime = measureTimeMillis {
|
|
||||||
val res = matrix1 dot matrix2
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val genericTime = measureTimeMillis { val res = matrix1 dot matrix2 }
|
||||||
println("Generic implementation time: $genericTime")
|
println("Generic implementation time: $genericTime")
|
||||||
}
|
}
|
||||||
|
6
kmath-ejml/build.gradle.kts
Normal file
6
kmath-ejml/build.gradle.kts
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
plugins { id("scientifik.jvm") }
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation("org.ejml:ejml-simple:0.39")
|
||||||
|
implementation(project(":kmath-core"))
|
||||||
|
}
|
@ -0,0 +1,69 @@
|
|||||||
|
package scientifik.kmath.ejml
|
||||||
|
|
||||||
|
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import scientifik.kmath.linear.DeterminantFeature
|
||||||
|
import scientifik.kmath.linear.FeaturedMatrix
|
||||||
|
import scientifik.kmath.linear.LUPDecompositionFeature
|
||||||
|
import scientifik.kmath.linear.MatrixFeature
|
||||||
|
import scientifik.kmath.structures.NDStructure
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents featured matrix over EJML [SimpleMatrix].
|
||||||
|
*
|
||||||
|
* @property origin the underlying [SimpleMatrix].
|
||||||
|
*/
|
||||||
|
class EjmlMatrix(val origin: SimpleMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
|
||||||
|
override val rowNum: Int
|
||||||
|
get() = origin.numRows()
|
||||||
|
|
||||||
|
override val colNum: Int
|
||||||
|
get() = origin.numCols()
|
||||||
|
|
||||||
|
override val shape: IntArray
|
||||||
|
get() = intArrayOf(origin.numRows(), origin.numCols())
|
||||||
|
|
||||||
|
override val features: Set<MatrixFeature> = features ?: hashSetOf(
|
||||||
|
object : DeterminantFeature<Double> {
|
||||||
|
override val determinant: Double
|
||||||
|
get() = origin.determinant()
|
||||||
|
},
|
||||||
|
|
||||||
|
object : LUPDecompositionFeature<Double> {
|
||||||
|
private val lup by lazy {
|
||||||
|
val ludecompositionF64 = DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols())
|
||||||
|
.also { it.decompose(origin.ddrm.copy()) }
|
||||||
|
|
||||||
|
Triple(
|
||||||
|
EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))),
|
||||||
|
EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))),
|
||||||
|
EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null)))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override val l: FeaturedMatrix<Double>
|
||||||
|
get() = lup.second
|
||||||
|
|
||||||
|
override val u: FeaturedMatrix<Double>
|
||||||
|
get() = lup.third
|
||||||
|
|
||||||
|
override val p: FeaturedMatrix<Double>
|
||||||
|
get() = lup.first
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<Double> =
|
||||||
|
EjmlMatrix(origin, this.features + features)
|
||||||
|
|
||||||
|
override operator fun get(i: Int, j: Int): Double = origin[i, j]
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = origin.hashCode()
|
||||||
|
result = 31 * result + features.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,75 @@
|
|||||||
|
package scientifik.kmath.ejml
|
||||||
|
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import scientifik.kmath.linear.MatrixContext
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
import scientifik.kmath.operations.Space
|
||||||
|
import scientifik.kmath.operations.invoke
|
||||||
|
import scientifik.kmath.structures.Matrix
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents context of basic operations operating with [EjmlMatrix].
|
||||||
|
*/
|
||||||
|
class EjmlMatrixContext(private val space: Space<Double>) : MatrixContext<Double> {
|
||||||
|
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): EjmlMatrix =
|
||||||
|
EjmlMatrix(SimpleMatrix(rows, columns).also {
|
||||||
|
(0 until it.numRows()).forEach { row ->
|
||||||
|
(0 until it.numCols()).forEach { col -> it[row, col] = initializer(row, col) }
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
fun Matrix<Double>.toEjml(): EjmlMatrix =
|
||||||
|
if (this is EjmlMatrix) this else produce(rowNum, colNum) { i, j -> get(i, j) }
|
||||||
|
|
||||||
|
fun Point<Double>.toEjml(): EjmlVector =
|
||||||
|
if (this is EjmlVector) this else EjmlVector(SimpleMatrix(size, 1).also {
|
||||||
|
(0 until it.numRows()).forEach { row -> it[row, 0] = get(row) }
|
||||||
|
})
|
||||||
|
|
||||||
|
override fun Matrix<Double>.dot(other: Matrix<Double>): EjmlMatrix =
|
||||||
|
EjmlMatrix(toEjml().origin.mult(other.toEjml().origin))
|
||||||
|
|
||||||
|
override fun Matrix<Double>.dot(vector: Point<Double>): EjmlVector =
|
||||||
|
EjmlVector(toEjml().origin.mult(vector.toEjml().origin))
|
||||||
|
|
||||||
|
override fun add(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
|
||||||
|
EjmlMatrix(a.toEjml().origin + b.toEjml().origin)
|
||||||
|
|
||||||
|
override operator fun Matrix<Double>.minus(b: Matrix<Double>): EjmlMatrix =
|
||||||
|
EjmlMatrix(toEjml().origin - b.toEjml().origin)
|
||||||
|
|
||||||
|
override fun multiply(a: Matrix<Double>, k: Number): Matrix<Double> =
|
||||||
|
produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } }
|
||||||
|
|
||||||
|
override operator fun Matrix<Double>.times(value: Double): EjmlMatrix = EjmlMatrix(toEjml().origin.scale(value))
|
||||||
|
|
||||||
|
companion object
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Solves for X in the following equation: x = a^-1*b, where 'a' is base matrix and 'b' is an n by p matrix.
|
||||||
|
*
|
||||||
|
* @param a the base matrix.
|
||||||
|
* @param b n by p matrix.
|
||||||
|
* @return the solution for 'x' that is n by p.
|
||||||
|
*/
|
||||||
|
fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): EjmlMatrix =
|
||||||
|
EjmlMatrix(a.toEjml().origin.solve(b.toEjml().origin))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Solves for X in the following equation: x = a^(-1)*b, where 'a' is base matrix and 'b' is an n by p matrix.
|
||||||
|
*
|
||||||
|
* @param a the base matrix.
|
||||||
|
* @param b n by p vector.
|
||||||
|
* @return the solution for 'x' that is n by p.
|
||||||
|
*/
|
||||||
|
fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Point<Double>): EjmlVector =
|
||||||
|
EjmlVector(a.toEjml().origin.solve(b.toEjml().origin))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the inverse of given matrix: b = a^(-1).
|
||||||
|
*
|
||||||
|
* @param a the matrix.
|
||||||
|
* @return the inverse of this matrix.
|
||||||
|
*/
|
||||||
|
fun EjmlMatrixContext.inverse(a: Matrix<Double>): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert())
|
@ -0,0 +1,30 @@
|
|||||||
|
package scientifik.kmath.ejml
|
||||||
|
|
||||||
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents point over EJML [SimpleMatrix].
|
||||||
|
*
|
||||||
|
* @property origin the underlying [SimpleMatrix].
|
||||||
|
*/
|
||||||
|
class EjmlVector internal constructor(val origin: SimpleMatrix) : Point<Double> {
|
||||||
|
override val size: Int get() = origin.numRows()
|
||||||
|
|
||||||
|
init {
|
||||||
|
require(origin.numCols() == 1) { error("Only single column matrices are allowed") }
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun get(index: Int): Double = origin[index]
|
||||||
|
|
||||||
|
override operator fun iterator(): Iterator<Double> = object : Iterator<Double> {
|
||||||
|
private var cursor: Int = 0
|
||||||
|
|
||||||
|
override fun next(): Double {
|
||||||
|
cursor += 1
|
||||||
|
return origin[cursor - 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hasNext(): Boolean = cursor < origin.numCols() * origin.numRows()
|
||||||
|
}
|
||||||
|
}
|
@ -1,31 +0,0 @@
|
|||||||
plugins {
|
|
||||||
id("scientifik.mpp")
|
|
||||||
}
|
|
||||||
|
|
||||||
repositories {
|
|
||||||
maven("http://dl.bintray.com/kyonifer/maven")
|
|
||||||
}
|
|
||||||
|
|
||||||
kotlin.sourceSets {
|
|
||||||
commonMain {
|
|
||||||
dependencies {
|
|
||||||
api(project(":kmath-core"))
|
|
||||||
api("com.kyonifer:koma-core-api-common:0.12")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
jvmMain {
|
|
||||||
dependencies {
|
|
||||||
api("com.kyonifer:koma-core-api-jvm:0.12")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
jvmTest {
|
|
||||||
dependencies {
|
|
||||||
implementation("com.kyonifer:koma-core-ejml:0.12")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
jsMain {
|
|
||||||
dependencies {
|
|
||||||
api("com.kyonifer:koma-core-api-js:0.12")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,110 +0,0 @@
|
|||||||
package scientifik.kmath.linear
|
|
||||||
|
|
||||||
import koma.extensions.fill
|
|
||||||
import koma.matrix.MatrixFactory
|
|
||||||
import scientifik.kmath.operations.Space
|
|
||||||
import scientifik.kmath.operations.invoke
|
|
||||||
import scientifik.kmath.structures.Matrix
|
|
||||||
import scientifik.kmath.structures.NDStructure
|
|
||||||
|
|
||||||
class KomaMatrixContext<T : Any>(
|
|
||||||
private val factory: MatrixFactory<koma.matrix.Matrix<T>>,
|
|
||||||
private val space: Space<T>
|
|
||||||
) : MatrixContext<T> {
|
|
||||||
|
|
||||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): KomaMatrix<T> =
|
|
||||||
KomaMatrix(factory.zeros(rows, columns).fill(initializer))
|
|
||||||
|
|
||||||
fun Matrix<T>.toKoma(): KomaMatrix<T> = if (this is KomaMatrix) {
|
|
||||||
this
|
|
||||||
} else {
|
|
||||||
produce(rowNum, colNum) { i, j -> get(i, j) }
|
|
||||||
}
|
|
||||||
|
|
||||||
fun Point<T>.toKoma(): KomaVector<T> = if (this is KomaVector) {
|
|
||||||
this
|
|
||||||
} else {
|
|
||||||
KomaVector(factory.zeros(size, 1).fill { i, _ -> get(i) })
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
override fun Matrix<T>.dot(other: Matrix<T>): KomaMatrix<T> =
|
|
||||||
KomaMatrix(toKoma().origin * other.toKoma().origin)
|
|
||||||
|
|
||||||
override fun Matrix<T>.dot(vector: Point<T>): KomaVector<T> =
|
|
||||||
KomaVector(toKoma().origin * vector.toKoma().origin)
|
|
||||||
|
|
||||||
override operator fun Matrix<T>.unaryMinus(): KomaMatrix<T> =
|
|
||||||
KomaMatrix(toKoma().origin.unaryMinus())
|
|
||||||
|
|
||||||
override fun add(a: Matrix<T>, b: Matrix<T>): KomaMatrix<T> =
|
|
||||||
KomaMatrix(a.toKoma().origin + b.toKoma().origin)
|
|
||||||
|
|
||||||
override operator fun Matrix<T>.minus(b: Matrix<T>): KomaMatrix<T> =
|
|
||||||
KomaMatrix(toKoma().origin - b.toKoma().origin)
|
|
||||||
|
|
||||||
override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
|
|
||||||
produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } }
|
|
||||||
|
|
||||||
override operator fun Matrix<T>.times(value: T): KomaMatrix<T> =
|
|
||||||
KomaMatrix(toKoma().origin * value)
|
|
||||||
|
|
||||||
companion object
|
|
||||||
}
|
|
||||||
|
|
||||||
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Matrix<T>) =
|
|
||||||
KomaMatrix(a.toKoma().origin.solve(b.toKoma().origin))
|
|
||||||
|
|
||||||
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Point<T>) =
|
|
||||||
KomaVector(a.toKoma().origin.solve(b.toKoma().origin))
|
|
||||||
|
|
||||||
fun <T : Any> KomaMatrixContext<T>.inverse(a: Matrix<T>) =
|
|
||||||
KomaMatrix(a.toKoma().origin.inv())
|
|
||||||
|
|
||||||
class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<MatrixFeature>? = null) : FeaturedMatrix<T> {
|
|
||||||
override val rowNum: Int get() = origin.numRows()
|
|
||||||
override val colNum: Int get() = origin.numCols()
|
|
||||||
|
|
||||||
override val shape: IntArray get() = intArrayOf(origin.numRows(), origin.numCols())
|
|
||||||
|
|
||||||
override val features: Set<MatrixFeature> = features ?: hashSetOf(
|
|
||||||
object : DeterminantFeature<T> {
|
|
||||||
override val determinant: T get() = origin.det()
|
|
||||||
},
|
|
||||||
|
|
||||||
object : LUPDecompositionFeature<T> {
|
|
||||||
private val lup by lazy { origin.LU() }
|
|
||||||
override val l: FeaturedMatrix<T> get() = KomaMatrix(lup.second)
|
|
||||||
override val u: FeaturedMatrix<T> get() = KomaMatrix(lup.third)
|
|
||||||
override val p: FeaturedMatrix<T> get() = KomaMatrix(lup.first)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> =
|
|
||||||
KomaMatrix(this.origin, this.features + features)
|
|
||||||
|
|
||||||
override operator fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
|
||||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
|
||||||
var result = origin.hashCode()
|
|
||||||
result = 31 * result + features.hashCode()
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {
|
|
||||||
override val size: Int get() = origin.numRows()
|
|
||||||
|
|
||||||
init {
|
|
||||||
require(origin.numCols() == 1) { error("Only single column matrices are allowed") }
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun get(index: Int): T = origin.getGeneric(index)
|
|
||||||
override operator fun iterator(): Iterator<T> = origin.toIterable().iterator()
|
|
||||||
}
|
|
@ -40,12 +40,12 @@ include(
|
|||||||
":kmath-histograms",
|
":kmath-histograms",
|
||||||
":kmath-commons",
|
":kmath-commons",
|
||||||
":kmath-viktor",
|
":kmath-viktor",
|
||||||
":kmath-koma",
|
|
||||||
":kmath-prob",
|
":kmath-prob",
|
||||||
":kmath-io",
|
":kmath-io",
|
||||||
":kmath-dimensions",
|
":kmath-dimensions",
|
||||||
":kmath-for-real",
|
":kmath-for-real",
|
||||||
":kmath-geometry",
|
":kmath-geometry",
|
||||||
":kmath-ast",
|
":kmath-ast",
|
||||||
":examples"
|
":examples",
|
||||||
|
":kmath-ejml"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user