Reformat, bring back the features of CMMatrix with the new API, add missing features in QRDecomposition in EjmlMatrix

This commit is contained in:
Iaroslav Postovalov 2021-01-20 00:28:39 +07:00
parent 97ec575142
commit 57b1157650
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
4 changed files with 54 additions and 18 deletions

View File

@ -1,11 +1,9 @@
package kscience.kmath.commons.linear package kscience.kmath.commons.linear
import kscience.kmath.linear.DiagonalFeature import kscience.kmath.linear.*
import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.MatrixWrapper
import kscience.kmath.linear.Point
import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import kscience.kmath.structures.RealBuffer
import org.apache.commons.math3.linear.* import org.apache.commons.math3.linear.*
import kotlin.reflect.KClass import kotlin.reflect.KClass
import kotlin.reflect.cast import kotlin.reflect.cast
@ -17,8 +15,40 @@ public inline class CMMatrix(public val origin: RealMatrix) : Matrix<Double> {
@UnstableKMathAPI @UnstableKMathAPI
override fun <T : Any> getFeature(type: KClass<T>): T? = when (type) { override fun <T : Any> getFeature(type: KClass<T>): T? = when (type) {
DiagonalFeature::class -> if (origin is DiagonalMatrix) DiagonalFeature else null DiagonalFeature::class -> if (origin is DiagonalMatrix) DiagonalFeature else null
DeterminantFeature::class, LupDecompositionFeature::class -> object :
DeterminantFeature<Double>,
LupDecompositionFeature<Double> {
private val lup by lazy { LUDecomposition(origin) }
override val determinant: Double by lazy { lup.determinant }
override val l: Matrix<Double> by lazy { CMMatrix(lup.l) + LFeature }
override val u: Matrix<Double> by lazy { CMMatrix(lup.u) + UFeature }
override val p: Matrix<Double> by lazy { CMMatrix(lup.p) }
}
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
override val l: Matrix<Double> by lazy {
val cholesky = CholeskyDecomposition(origin)
CMMatrix(cholesky.l) + LFeature
}
}
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
private val qr by lazy { QRDecomposition(origin) }
override val q: Matrix<Double> by lazy { CMMatrix(qr.q) + OrthogonalFeature }
override val r: Matrix<Double> by lazy { CMMatrix(qr.r) + UFeature }
}
SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature<Double> {
private val sv by lazy { SingularValueDecomposition(origin) }
override val u: Matrix<Double> by lazy { CMMatrix(sv.u) }
override val s: Matrix<Double> by lazy { CMMatrix(sv.s) }
override val v: Matrix<Double> by lazy { CMMatrix(sv.v) }
override val singularValues: Point<Double> by lazy { RealBuffer(sv.singularValues) }
}
else -> null else -> null
}?.let { type.cast(it) } }?.let(type::cast)
public override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j) public override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
} }

View File

@ -11,8 +11,8 @@ public interface MatrixFeature
/** /**
* Matrices with this feature are considered to have only diagonal non-null elements. * Matrices with this feature are considered to have only diagonal non-null elements.
*/ */
public interface DiagonalFeature : MatrixFeature{ public interface DiagonalFeature : MatrixFeature {
public companion object: DiagonalFeature public companion object : DiagonalFeature
} }
/** /**

View File

@ -1,4 +1,4 @@
package kscience.kmath.misc package kscience.kmath.misc
@RequiresOptIn("This API is unstable and could change in future", RequiresOptIn.Level.WARNING) @RequiresOptIn("This API is unstable and could change in future", RequiresOptIn.Level.WARNING)
public annotation class UnstableKMathAPI public annotation class UnstableKMathAPI

View File

@ -15,21 +15,20 @@ import kotlin.reflect.cast
* @property origin the underlying [SimpleMatrix]. * @property origin the underlying [SimpleMatrix].
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
public inline class EjmlMatrix( public inline class EjmlMatrix(public val origin: SimpleMatrix) : Matrix<Double> {
public val origin: SimpleMatrix,
) : Matrix<Double> {
public override val rowNum: Int get() = origin.numRows() public override val rowNum: Int get() = origin.numRows()
public override val colNum: Int get() = origin.numCols() public override val colNum: Int get() = origin.numCols()
@UnstableKMathAPI @UnstableKMathAPI
override fun <T : Any> getFeature(type: KClass<T>): T? = when (type) { public override fun <T : Any> getFeature(type: KClass<T>): T? = when (type) {
InverseMatrixFeature::class -> object : InverseMatrixFeature<Double> { InverseMatrixFeature::class -> object : InverseMatrixFeature<Double> {
override val inverse: Matrix<Double> by lazy { EjmlMatrix(origin.invert()) } override val inverse: Matrix<Double> by lazy { EjmlMatrix(origin.invert()) }
} }
DeterminantFeature::class -> object : DeterminantFeature<Double> { DeterminantFeature::class -> object : DeterminantFeature<Double> {
override val determinant: Double by lazy(origin::determinant) override val determinant: Double by lazy(origin::determinant)
} }
SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature<Double> { SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature<Double> {
private val svd by lazy { private val svd by lazy {
DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false) DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false)
@ -41,15 +40,20 @@ public inline class EjmlMatrix(
override val v: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) } override val v: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) }
override val singularValues: Point<Double> by lazy { RealBuffer(svd.singularValues) } override val singularValues: Point<Double> by lazy { RealBuffer(svd.singularValues) }
} }
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> { QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
private val qr by lazy { private val qr by lazy {
DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) } DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) }
} }
override val q: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) } override val q: Matrix<Double> by lazy {
override val r: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) } EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) + OrthogonalFeature
}
override val r: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) + UFeature }
} }
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
override val l: Matrix<Double> by lazy { override val l: Matrix<Double> by lazy {
val cholesky = val cholesky =
DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) } DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) }
@ -57,7 +61,8 @@ public inline class EjmlMatrix(
EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature
} }
} }
LupDecompositionFeature::class -> object : LupDecompositionFeature<Double> {
LupDecompositionFeature::class -> object : LupDecompositionFeature<Double> {
private val lup by lazy { private val lup by lazy {
DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) } DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) }
} }
@ -72,8 +77,9 @@ public inline class EjmlMatrix(
override val p: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) } override val p: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) }
} }
else -> null else -> null
}?.let{type.cast(it)} }?.let(type::cast)
public override operator fun get(i: Int, j: Int): Double = origin[i, j] public override operator fun get(i: Int, j: Int): Double = origin[i, j]
} }