Move features to scopes

This commit is contained in:
Alexander Nozik 2021-03-15 16:46:12 +03:00
parent 8c098b6033
commit 71ed56c2b6
6 changed files with 156 additions and 124 deletions

View File

@ -13,44 +13,6 @@ public class CMMatrix(public val origin: RealMatrix) : Matrix<Double> {
public override val rowNum: Int get() = origin.rowDimension
public override val colNum: Int get() = origin.columnDimension
@UnstableKMathAPI
override fun <T : Any> getFeature(type: KClass<T>): T? = when (type) {
DiagonalFeature::class -> if (origin is DiagonalMatrix) DiagonalFeature else null
DeterminantFeature::class, LupDecompositionFeature::class -> object :
DeterminantFeature<Double>,
LupDecompositionFeature<Double> {
private val lup by lazy { LUDecomposition(origin) }
override val determinant: Double by lazy { lup.determinant }
override val l: Matrix<Double> by lazy { CMMatrix(lup.l) + LFeature }
override val u: Matrix<Double> by lazy { CMMatrix(lup.u) + UFeature }
override val p: Matrix<Double> by lazy { CMMatrix(lup.p) }
}
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
override val l: Matrix<Double> by lazy {
val cholesky = CholeskyDecomposition(origin)
CMMatrix(cholesky.l) + LFeature
}
}
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
private val qr by lazy { QRDecomposition(origin) }
override val q: Matrix<Double> by lazy { CMMatrix(qr.q) + OrthogonalFeature }
override val r: Matrix<Double> by lazy { CMMatrix(qr.r) + UFeature }
}
SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature<Double> {
private val sv by lazy { SingularValueDecomposition(origin) }
override val u: Matrix<Double> by lazy { CMMatrix(sv.u) }
override val s: Matrix<Double> by lazy { CMMatrix(sv.s) }
override val v: Matrix<Double> by lazy { CMMatrix(sv.v) }
override val singularValues: Point<Double> by lazy { RealBuffer(sv.singularValues) }
}
else -> null
}?.let(type::cast)
public override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
override fun equals(other: Any?): Boolean {
@ -134,6 +96,50 @@ public object CMLinearSpace : LinearSpace<Double, RealField> {
override fun Double.times(v: Point<Double>): CMVector =
v * this
@UnstableKMathAPI
override fun <F : Any> getFeature(structure: Matrix<Double>, type: KClass<F>): F? {
//Return the feature if it is intrinsic to the structure
structure.getFeature(type)?.let { return it }
val origin = structure.toCM().origin
return when (type) {
DiagonalFeature::class -> if (origin is DiagonalMatrix) DiagonalFeature else null
DeterminantFeature::class, LupDecompositionFeature::class -> object :
DeterminantFeature<Double>,
LupDecompositionFeature<Double> {
private val lup by lazy { LUDecomposition(origin) }
override val determinant: Double by lazy { lup.determinant }
override val l: Matrix<Double> by lazy { CMMatrix(lup.l) + LFeature }
override val u: Matrix<Double> by lazy { CMMatrix(lup.u) + UFeature }
override val p: Matrix<Double> by lazy { CMMatrix(lup.p) }
}
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
override val l: Matrix<Double> by lazy {
val cholesky = CholeskyDecomposition(origin)
CMMatrix(cholesky.l) + LFeature
}
}
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
private val qr by lazy { QRDecomposition(origin) }
override val q: Matrix<Double> by lazy { CMMatrix(qr.q) + OrthogonalFeature }
override val r: Matrix<Double> by lazy { CMMatrix(qr.r) + UFeature }
}
SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature<Double> {
private val sv by lazy { SingularValueDecomposition(origin) }
override val u: Matrix<Double> by lazy { CMMatrix(sv.u) }
override val s: Matrix<Double> by lazy { CMMatrix(sv.s) }
override val v: Matrix<Double> by lazy { CMMatrix(sv.v) }
override val singularValues: Point<Double> by lazy { RealBuffer(sv.singularValues) }
}
else -> null
}?.let(type::cast)
}
}
public operator fun CMMatrix.plus(other: CMMatrix): CMMatrix =

View File

@ -156,16 +156,15 @@ public interface LinearSpace<T : Any, out A : Ring<T>> {
public operator fun T.times(v: Point<T>): Point<T> = v * this
/**
* Gets a feature from the matrix. This function may return some additional features to
* [space.kscience.kmath.nd.NDStructure.getFeature].
* Get a feature of the structure in this scope. Structure features take precedence other context features
*
* @param F the type of feature.
* @param m the matrix.
* @param structure the structure.
* @param type the [KClass] instance of [F].
* @return a feature object or `null` if it isn't present.
*/
@UnstableKMathAPI
public fun <F : Any> getFeature(m: Matrix<T>, type: KClass<F>): F? = m.getFeature(type)
public fun <F : Any> getFeature(structure: Matrix<T>, type: KClass<F>): F? = structure.getFeature(type)
public companion object {
@ -187,19 +186,17 @@ public interface LinearSpace<T : Any, out A : Ring<T>> {
}
}
public operator fun <LS : LinearSpace<*, *>, R> LS.invoke(block: LS.() -> R): R = run(block)
/**
* Gets a feature from the matrix. This function may return some additional features to
* [space.kscience.kmath.nd.NDStructure.getFeature].
* Get a feature of the structure in this scope. Structure features take precedence other context features
*
* @param T the type of items in the matrices.
* @param M the type of operated matrices.
* @param F the type of feature.
* @receiver the [LinearSpace] of [T].
* @param m the matrix.
* @return a feature object or `null` if it isn't present.
*/
@UnstableKMathAPI
public inline fun <T : Any, reified F : Any> LinearSpace<T, *>.getFeature(m: Matrix<T>): F? = getFeature(m, F::class)
public inline fun <T : Any, reified F : Any> LinearSpace<T, *>.getFeature(structure: Matrix<T>): F? =
getFeature(structure, F::class)
public operator fun <LS : LinearSpace<*, *>, R> LS.invoke(block: LS.() -> R): R = run(block)

View File

@ -1,7 +1,9 @@
package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.*
import kotlin.reflect.KClass
/**
* An exception is thrown when the expected ans actual shape of NDArray differs.
@ -56,9 +58,32 @@ public interface NDAlgebra<T, C : Algebra<T>> {
public operator fun Function1<T, T>.invoke(structure: NDStructure<T>): NDStructure<T> =
structure.map { value -> this@invoke(value) }
/**
* Get a feature of the structure in this scope. Structure features take precedence other context features
*
* @param F the type of feature.
* @param structure the structure.
* @param type the [KClass] instance of [F].
* @return a feature object or `null` if it isn't present.
*/
@UnstableKMathAPI
public fun <F : Any> getFeature(structure: NDStructure<T>, type: KClass<F>): F? = structure.getFeature(type)
public companion object
}
/**
* Get a feature of the structure in this scope. Structure features take precedence other context features
*
* @param T the type of items in the matrices.
* @param F the type of feature.
* @return a feature object or `null` if it isn't present.
*/
@UnstableKMathAPI
public inline fun <T : Any, reified F : Any> NDAlgebra<T, *>.getFeature(structure: NDStructure<T>): F? =
getFeature(structure, F::class)
/**
* Checks if given elements are consistent with this context.
*

View File

@ -1,10 +1,14 @@
package space.kscience.kmath.ejml
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
import org.ejml.simple.SimpleMatrix
import space.kscience.kmath.linear.*
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.getFeature
import space.kscience.kmath.operations.RealField
import space.kscience.kmath.structures.RealBuffer
import kotlin.reflect.KClass
import kotlin.reflect.cast
/**
* Represents context of basic operations operating with [EjmlMatrix].
@ -83,6 +87,76 @@ public object EjmlLinearSpace : LinearSpace<Double, RealField> {
override fun Double.times(v: Point<Double>): EjmlVector =
v.toEjml().origin.scale(this).wrapVector()
@UnstableKMathAPI
override fun <F : Any> getFeature(structure: Matrix<Double>, type: KClass<F>): F? {
//Return the feature if it is intrinsic to the structure
structure.getFeature(type)?.let { return it }
val origin = structure.toEjml().origin
return when (type) {
InverseMatrixFeature::class -> object : InverseMatrixFeature<Double> {
override val inverse: Matrix<Double> by lazy { EjmlMatrix(origin.invert()) }
}
DeterminantFeature::class -> object : DeterminantFeature<Double> {
override val determinant: Double by lazy(origin::determinant)
}
SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature<Double> {
private val svd by lazy {
DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false)
.apply { decompose(origin.ddrm.copy()) }
}
override val u: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) }
override val s: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) }
override val v: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) }
override val singularValues: Point<Double> by lazy { RealBuffer(svd.singularValues) }
}
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
private val qr by lazy {
DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) }
}
override val q: Matrix<Double> by lazy {
EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) + OrthogonalFeature
}
override val r: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) + UFeature }
}
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
override val l: Matrix<Double> by lazy {
val cholesky =
DecompositionFactory_DDRM.chol(structure.rowNum, true).apply { decompose(origin.ddrm.copy()) }
EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature
}
}
LupDecompositionFeature::class -> object : LupDecompositionFeature<Double> {
private val lup by lazy {
DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols())
.apply { decompose(origin.ddrm.copy()) }
}
override val l: Matrix<Double> by lazy {
EjmlMatrix(SimpleMatrix(lup.getLower(null))) + LFeature
}
override val u: Matrix<Double> by lazy {
EjmlMatrix(SimpleMatrix(lup.getUpper(null))) + UFeature
}
override val p: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) }
}
else -> null
}?.let(type::cast)
}
}
/**

View File

@ -1,13 +1,8 @@
package space.kscience.kmath.ejml
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
import org.ejml.simple.SimpleMatrix
import space.kscience.kmath.linear.*
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.linear.Matrix
import space.kscience.kmath.nd.NDStructure
import space.kscience.kmath.structures.RealBuffer
import kotlin.reflect.KClass
import kotlin.reflect.cast
/**
* Represents featured matrix over EJML [SimpleMatrix].
@ -19,68 +14,6 @@ public class EjmlMatrix(public val origin: SimpleMatrix) : Matrix<Double> {
public override val rowNum: Int get() = origin.numRows()
public override val colNum: Int get() = origin.numCols()
@UnstableKMathAPI
public override fun <T : Any> getFeature(type: KClass<T>): T? = when (type) {
InverseMatrixFeature::class -> object : InverseMatrixFeature<Double> {
override val inverse: Matrix<Double> by lazy { EjmlMatrix(origin.invert()) }
}
DeterminantFeature::class -> object : DeterminantFeature<Double> {
override val determinant: Double by lazy(origin::determinant)
}
SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature<Double> {
private val svd by lazy {
DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false)
.apply { decompose(origin.ddrm.copy()) }
}
override val u: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) }
override val s: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) }
override val v: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) }
override val singularValues: Point<Double> by lazy { RealBuffer(svd.singularValues) }
}
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
private val qr by lazy {
DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) }
}
override val q: Matrix<Double> by lazy {
EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) + OrthogonalFeature
}
override val r: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) + UFeature }
}
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
override val l: Matrix<Double> by lazy {
val cholesky =
DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) }
EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature
}
}
LupDecompositionFeature::class -> object : LupDecompositionFeature<Double> {
private val lup by lazy {
DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) }
}
override val l: Matrix<Double> by lazy {
EjmlMatrix(SimpleMatrix(lup.getLower(null))) + LFeature
}
override val u: Matrix<Double> by lazy {
EjmlMatrix(SimpleMatrix(lup.getUpper(null))) + UFeature
}
override val p: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) }
}
else -> null
}?.let(type::cast)
public override operator fun get(i: Int, j: Int): Double = origin[i, j]
override fun equals(other: Any?): Boolean {

View File

@ -2,10 +2,7 @@ package space.kscience.kmath.ejml
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
import org.ejml.simple.SimpleMatrix
import space.kscience.kmath.linear.DeterminantFeature
import space.kscience.kmath.linear.LupDecompositionFeature
import space.kscience.kmath.linear.MatrixFeature
import space.kscience.kmath.linear.plus
import space.kscience.kmath.linear.*
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.getFeature
import kotlin.random.Random
@ -45,9 +42,9 @@ internal class EjmlMatrixTest {
fun features() {
val m = randomMatrix
val w = EjmlMatrix(m)
val det = w.getFeature<DeterminantFeature<Double>>() ?: fail()
val det: DeterminantFeature<Double> = EjmlLinearSpace.getFeature(w) ?: fail()
assertEquals(m.determinant(), det.determinant)
val lup = w.getFeature<LupDecompositionFeature<Double>>() ?: fail()
val lup: LupDecompositionFeature<Double> = EjmlLinearSpace.getFeature(w) ?: fail()
val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows(), m.numCols())
.also { it.decompose(m.ddrm.copy()) }