Features refactoring.

This commit is contained in:
Alexander Nozik 2021-01-19 19:32:13 +03:00
parent ab32cd9561
commit 4c256a9f14
19 changed files with 236 additions and 244 deletions

View File

@ -33,6 +33,7 @@
- EjmlMatrix context is an object - EjmlMatrix context is an object
- Matrix LUP `inverse` renamed to `inverseWithLUP` - Matrix LUP `inverse` renamed to `inverseWithLUP`
- `NumericAlgebra` moved outside of regular algebra chain (`Ring` no longer implements it). - `NumericAlgebra` moved outside of regular algebra chain (`Ring` no longer implements it).
- Features moved to NDStructure and became transparent.
### Deprecated ### Deprecated

View File

@ -1,42 +1,28 @@
package kscience.kmath.commons.linear package kscience.kmath.commons.linear
import kscience.kmath.linear.* import kscience.kmath.linear.DiagonalFeature
import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.MatrixWrapper
import kscience.kmath.linear.Point
import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import kscience.kmath.structures.NDStructure
import org.apache.commons.math3.linear.* import org.apache.commons.math3.linear.*
import kotlin.reflect.KClass
import kotlin.reflect.cast
public class CMMatrix(public val origin: RealMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> { public inline class CMMatrix(public val origin: RealMatrix) : Matrix<Double> {
public override val rowNum: Int get() = origin.rowDimension public override val rowNum: Int get() = origin.rowDimension
public override val colNum: Int get() = origin.columnDimension public override val colNum: Int get() = origin.columnDimension
public override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> { @UnstableKMathAPI
if (origin is DiagonalMatrix) yield(DiagonalFeature) override fun <T : Any> getFeature(type: KClass<T>): T? = when (type) {
}.toHashSet() DiagonalFeature::class -> if (origin is DiagonalMatrix) DiagonalFeature else null
else -> null
public override fun suggestFeature(vararg features: MatrixFeature): CMMatrix = }?.let { type.cast(it) }
CMMatrix(origin, this.features + features)
public override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j) public override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
public override fun equals(other: Any?): Boolean {
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
}
public override fun hashCode(): Int {
var result = origin.hashCode()
result = 31 * result + features.hashCode()
return result
}
} }
//TODO move inside context
public fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
this
} else {
//TODO add feature analysis
val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } }
CMMatrix(Array2DRowRealMatrix(array))
}
public fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this) public fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this)
@ -61,6 +47,16 @@ public object CMMatrixContext : MatrixContext<Double, CMMatrix> {
return CMMatrix(Array2DRowRealMatrix(array)) return CMMatrix(Array2DRowRealMatrix(array))
} }
public fun Matrix<Double>.toCM(): CMMatrix = when {
this is CMMatrix -> this
this is MatrixWrapper && matrix is CMMatrix -> matrix as CMMatrix
else -> {
//TODO add feature analysis
val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } }
CMMatrix(Array2DRowRealMatrix(array))
}
}
public override fun Matrix<Double>.dot(other: Matrix<Double>): CMMatrix = public override fun Matrix<Double>.dot(other: Matrix<Double>): CMMatrix =
CMMatrix(toCM().origin.multiply(other.toCM().origin)) CMMatrix(toCM().origin.multiply(other.toCM().origin))

View File

@ -1,10 +1,7 @@
package kscience.kmath.linear package kscience.kmath.linear
import kscience.kmath.operations.Ring import kscience.kmath.operations.Ring
import kscience.kmath.structures.Buffer import kscience.kmath.structures.*
import kscience.kmath.structures.BufferFactory
import kscience.kmath.structures.NDStructure
import kscience.kmath.structures.asSequence
/** /**
* Basic implementation of Matrix space based on [NDStructure] * Basic implementation of Matrix space based on [NDStructure]
@ -27,8 +24,7 @@ public class BufferMatrix<T : Any>(
public override val rowNum: Int, public override val rowNum: Int,
public override val colNum: Int, public override val colNum: Int,
public val buffer: Buffer<out T>, public val buffer: Buffer<out T>,
public override val features: Set<MatrixFeature> = emptySet(), ) : Matrix<T> {
) : FeaturedMatrix<T> {
init { init {
require(buffer.size == rowNum * colNum) { "Dimension mismatch for matrix structure" } require(buffer.size == rowNum * colNum) { "Dimension mismatch for matrix structure" }
@ -36,9 +32,6 @@ public class BufferMatrix<T : Any>(
override val shape: IntArray get() = intArrayOf(rowNum, colNum) override val shape: IntArray get() = intArrayOf(rowNum, colNum)
public override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
BufferMatrix(rowNum, colNum, buffer, this.features + features)
public override operator fun get(index: IntArray): T = get(index[0], index[1]) public override operator fun get(index: IntArray): T = get(index[0], index[1])
public override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j] public override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
@ -50,23 +43,26 @@ public class BufferMatrix<T : Any>(
if (this === other) return true if (this === other) return true
return when (other) { return when (other) {
is NDStructure<*> -> return NDStructure.equals(this, other) is NDStructure<*> -> NDStructure.equals(this, other)
else -> false else -> false
} }
} }
public override fun hashCode(): Int { override fun hashCode(): Int {
var result = buffer.hashCode() var result = rowNum
result = 31 * result + features.hashCode() result = 31 * result + colNum
result = 31 * result + buffer.hashCode()
return result return result
} }
public override fun toString(): String { public override fun toString(): String {
return if (rowNum <= 5 && colNum <= 5) return if (rowNum <= 5 && colNum <= 5)
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" + "Matrix(rowsNum = $rowNum, colNum = $colNum)\n" +
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer -> rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer ->
buffer.asSequence().joinToString(separator = "\t") { it.toString() } buffer.asSequence().joinToString(separator = "\t") { it.toString() }
} }
else "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)" else "Matrix(rowsNum = $rowNum, colNum = $colNum)"
} }
} }

View File

@ -1,13 +1,14 @@
package kscience.kmath.linear package kscience.kmath.linear
import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.operations.* import kscience.kmath.operations.*
import kscience.kmath.structures.* import kscience.kmath.structures.*
/** /**
* Common implementation of [LupDecompositionFeature]. * Common implementation of [LupDecompositionFeature].
*/ */
public class LUPDecomposition<T : Any>( public class LupDecomposition<T : Any>(
public val context: MatrixContext<T, FeaturedMatrix<T>>, public val context: MatrixContext<T, Matrix<T>>,
public val elementContext: Field<T>, public val elementContext: Field<T>,
public val lu: Matrix<T>, public val lu: Matrix<T>,
public val pivot: IntArray, public val pivot: IntArray,
@ -18,13 +19,13 @@ public class LUPDecomposition<T : Any>(
* *
* L is a lower-triangular matrix with [Ring.one] in diagonal * L is a lower-triangular matrix with [Ring.one] in diagonal
*/ */
override val l: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j -> override val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
when { when {
j < i -> lu[i, j] j < i -> lu[i, j]
j == i -> elementContext.one j == i -> elementContext.one
else -> elementContext.zero else -> elementContext.zero
} }
} } + LFeature
/** /**
@ -32,9 +33,9 @@ public class LUPDecomposition<T : Any>(
* *
* U is an upper-triangular matrix including the diagonal * U is an upper-triangular matrix including the diagonal
*/ */
override val u: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j -> override val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
if (j >= i) lu[i, j] else elementContext.zero if (j >= i) lu[i, j] else elementContext.zero
} } + UFeature
/** /**
* Returns the P rows permutation matrix. * Returns the P rows permutation matrix.
@ -42,7 +43,7 @@ public class LUPDecomposition<T : Any>(
* P is a sparse matrix with exactly one element set to [Ring.one] in * P is a sparse matrix with exactly one element set to [Ring.one] in
* each row and each column, all other elements being set to [Ring.zero]. * each row and each column, all other elements being set to [Ring.zero].
*/ */
override val p: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> override val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
if (j == pivot[i]) elementContext.one else elementContext.zero if (j == pivot[i]) elementContext.one else elementContext.zero
} }
@ -63,12 +64,12 @@ internal fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, *>.abs
/** /**
* Create a lup decomposition of generic matrix. * Create a lup decomposition of generic matrix.
*/ */
public fun <T : Comparable<T>> MatrixContext<T, FeaturedMatrix<T>>.lup( public fun <T : Comparable<T>> MatrixContext<T, Matrix<T>>.lup(
factory: MutableBufferFactory<T>, factory: MutableBufferFactory<T>,
elementContext: Field<T>, elementContext: Field<T>,
matrix: Matrix<T>, matrix: Matrix<T>,
checkSingular: (T) -> Boolean, checkSingular: (T) -> Boolean,
): LUPDecomposition<T> { ): LupDecomposition<T> {
require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" } require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
val m = matrix.colNum val m = matrix.colNum
val pivot = IntArray(matrix.rowNum) val pivot = IntArray(matrix.rowNum)
@ -137,23 +138,23 @@ public fun <T : Comparable<T>> MatrixContext<T, FeaturedMatrix<T>>.lup(
for (row in col + 1 until m) lu[row, col] /= luDiag for (row in col + 1 until m) lu[row, col] /= luDiag
} }
return LUPDecomposition(this@lup, elementContext, lu.collect(), pivot, even) return LupDecomposition(this@lup, elementContext, lu.collect(), pivot, even)
} }
} }
} }
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.lup( public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, Matrix<T>>.lup(
matrix: Matrix<T>, matrix: Matrix<T>,
noinline checkSingular: (T) -> Boolean, noinline checkSingular: (T) -> Boolean,
): LUPDecomposition<T> = lup(MutableBuffer.Companion::auto, elementContext, matrix, checkSingular) ): LupDecomposition<T> = lup(MutableBuffer.Companion::auto, elementContext, matrix, checkSingular)
public fun MatrixContext<Double, FeaturedMatrix<Double>>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> = public fun MatrixContext<Double, Matrix<Double>>.lup(matrix: Matrix<Double>): LupDecomposition<Double> =
lup(Buffer.Companion::real, RealField, matrix) { it < 1e-11 } lup(Buffer.Companion::real, RealField, matrix) { it < 1e-11 }
public fun <T : Any> LUPDecomposition<T>.solveWithLUP( public fun <T : Any> LupDecomposition<T>.solveWithLUP(
factory: MutableBufferFactory<T>, factory: MutableBufferFactory<T>,
matrix: Matrix<T> matrix: Matrix<T>,
): FeaturedMatrix<T> { ): Matrix<T> {
require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" } require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" }
BufferAccessor2D(matrix.rowNum, matrix.colNum, factory).run { BufferAccessor2D(matrix.rowNum, matrix.colNum, factory).run {
@ -198,25 +199,40 @@ public fun <T : Any> LUPDecomposition<T>.solveWithLUP(
} }
} }
public inline fun <reified T : Any> LUPDecomposition<T>.solveWithLUP(matrix: Matrix<T>): Matrix<T> = public inline fun <reified T : Any> LupDecomposition<T>.solveWithLUP(matrix: Matrix<T>): Matrix<T> =
solveWithLUP(MutableBuffer.Companion::auto, matrix) solveWithLUP(MutableBuffer.Companion::auto, matrix)
/** /**
* Solve a linear equation **a*x = b** using LUP decomposition * Solve a linear equation **a*x = b** using LUP decomposition
*/ */
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.solveWithLUP( @OptIn(UnstableKMathAPI::class)
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, Matrix<T>>.solveWithLUP(
a: Matrix<T>, a: Matrix<T>,
b: Matrix<T>, b: Matrix<T>,
noinline bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::auto, noinline bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::auto,
noinline checkSingular: (T) -> Boolean, noinline checkSingular: (T) -> Boolean,
): FeaturedMatrix<T> { ): Matrix<T> {
// Use existing decomposition if it is provided by matrix // Use existing decomposition if it is provided by matrix
val decomposition = a.getFeature() ?: lup(bufferFactory, elementContext, a, checkSingular) val decomposition = a.getFeature() ?: lup(bufferFactory, elementContext, a, checkSingular)
return decomposition.solveWithLUP(bufferFactory, b) return decomposition.solveWithLUP(bufferFactory, b)
} }
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.inverseWithLUP( public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, Matrix<T>>.inverseWithLUP(
matrix: Matrix<T>, matrix: Matrix<T>,
noinline bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::auto, noinline bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::auto,
noinline checkSingular: (T) -> Boolean, noinline checkSingular: (T) -> Boolean,
): FeaturedMatrix<T> = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) ): Matrix<T> = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular)
public fun RealMatrixContext.solveWithLUP(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> {
// Use existing decomposition if it is provided by matrix
val bufferFactory: MutableBufferFactory<Double> = MutableBuffer.Companion::real
val decomposition: LupDecomposition<Double> = a.getFeature() ?: lup(bufferFactory, RealField, a) { it < 1e-11 }
return decomposition.solveWithLUP(bufferFactory, b)
}
/**
* Inverses a square matrix using LUP decomposition. Non square matrix will throw a error.
*/
public fun RealMatrixContext.inverseWithLUP(matrix: Matrix<Double>): Matrix<Double> =
solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum))

View File

@ -1,12 +1,9 @@
package kscience.kmath.linear package kscience.kmath.linear
import kscience.kmath.structures.Buffer import kscience.kmath.structures.*
import kscience.kmath.structures.BufferFactory
import kscience.kmath.structures.Structure2D
import kscience.kmath.structures.asBuffer
public class MatrixBuilder(public val rows: Int, public val columns: Int) { public class MatrixBuilder(public val rows: Int, public val columns: Int) {
public operator fun <T : Any> invoke(vararg elements: T): FeaturedMatrix<T> { public operator fun <T : Any> invoke(vararg elements: T): Matrix<T> {
require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" } require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" }
val buffer = elements.asBuffer() val buffer = elements.asBuffer()
return BufferMatrix(rows, columns, buffer) return BufferMatrix(rows, columns, buffer)
@ -17,7 +14,7 @@ public class MatrixBuilder(public val rows: Int, public val columns: Int) {
public fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns) public fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns)
public fun <T : Any> Structure2D.Companion.row(vararg values: T): FeaturedMatrix<T> { public fun <T : Any> Structure2D.Companion.row(vararg values: T): Matrix<T> {
val buffer = values.asBuffer() val buffer = values.asBuffer()
return BufferMatrix(1, values.size, buffer) return BufferMatrix(1, values.size, buffer)
} }
@ -26,12 +23,12 @@ public inline fun <reified T : Any> Structure2D.Companion.row(
size: Int, size: Int,
factory: BufferFactory<T> = Buffer.Companion::auto, factory: BufferFactory<T> = Buffer.Companion::auto,
noinline builder: (Int) -> T noinline builder: (Int) -> T
): FeaturedMatrix<T> { ): Matrix<T> {
val buffer = factory(size, builder) val buffer = factory(size, builder)
return BufferMatrix(1, size, buffer) return BufferMatrix(1, size, buffer)
} }
public fun <T : Any> Structure2D.Companion.column(vararg values: T): FeaturedMatrix<T> { public fun <T : Any> Structure2D.Companion.column(vararg values: T): Matrix<T> {
val buffer = values.asBuffer() val buffer = values.asBuffer()
return BufferMatrix(values.size, 1, buffer) return BufferMatrix(values.size, 1, buffer)
} }
@ -40,7 +37,7 @@ public inline fun <reified T : Any> Structure2D.Companion.column(
size: Int, size: Int,
factory: BufferFactory<T> = Buffer.Companion::auto, factory: BufferFactory<T> = Buffer.Companion::auto,
noinline builder: (Int) -> T noinline builder: (Int) -> T
): FeaturedMatrix<T> { ): Matrix<T> {
val buffer = factory(size, builder) val buffer = factory(size, builder)
return BufferMatrix(size, 1, buffer) return BufferMatrix(size, 1, buffer)
} }

View File

@ -133,8 +133,6 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>, out M : Matrix<T>> :
public override fun multiply(a: Matrix<T>, k: Number): M = public override fun multiply(a: Matrix<T>, k: Number): M =
produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } } produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }
public operator fun Number.times(matrix: FeaturedMatrix<T>): M = multiply(matrix, this)
public override operator fun Matrix<T>.times(value: T): M = public override operator fun Matrix<T>.times(value: T): M =
produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } } produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }
} }

View File

@ -11,17 +11,19 @@ public interface MatrixFeature
/** /**
* Matrices with this feature are considered to have only diagonal non-null elements. * Matrices with this feature are considered to have only diagonal non-null elements.
*/ */
public object DiagonalFeature : MatrixFeature public interface DiagonalFeature : MatrixFeature{
public companion object: DiagonalFeature
}
/** /**
* Matrices with this feature have all zero elements. * Matrices with this feature have all zero elements.
*/ */
public object ZeroFeature : MatrixFeature public object ZeroFeature : DiagonalFeature
/** /**
* Matrices with this feature have unit elements on diagonal and zero elements in all other places. * Matrices with this feature have unit elements on diagonal and zero elements in all other places.
*/ */
public object UnitFeature : MatrixFeature public object UnitFeature : DiagonalFeature
/** /**
* Matrices with this feature can be inverted: [inverse] = `a`<sup>-1</sup> where `a` is the owning matrix. * Matrices with this feature can be inverted: [inverse] = `a`<sup>-1</sup> where `a` is the owning matrix.
@ -76,17 +78,17 @@ public interface LupDecompositionFeature<T : Any> : MatrixFeature {
/** /**
* The lower triangular matrix in this decomposition. It may have [LFeature]. * The lower triangular matrix in this decomposition. It may have [LFeature].
*/ */
public val l: FeaturedMatrix<T> public val l: Matrix<T>
/** /**
* The upper triangular matrix in this decomposition. It may have [UFeature]. * The upper triangular matrix in this decomposition. It may have [UFeature].
*/ */
public val u: FeaturedMatrix<T> public val u: Matrix<T>
/** /**
* The permutation matrix in this decomposition. * The permutation matrix in this decomposition.
*/ */
public val p: FeaturedMatrix<T> public val p: Matrix<T>
} }
/** /**
@ -104,12 +106,12 @@ public interface QRDecompositionFeature<T : Any> : MatrixFeature {
/** /**
* The orthogonal matrix in this decomposition. It may have [OrthogonalFeature]. * The orthogonal matrix in this decomposition. It may have [OrthogonalFeature].
*/ */
public val q: FeaturedMatrix<T> public val q: Matrix<T>
/** /**
* The upper triangular matrix in this decomposition. It may have [UFeature]. * The upper triangular matrix in this decomposition. It may have [UFeature].
*/ */
public val r: FeaturedMatrix<T> public val r: Matrix<T>
} }
/** /**
@ -122,7 +124,7 @@ public interface CholeskyDecompositionFeature<T : Any> : MatrixFeature {
/** /**
* The triangular matrix in this decomposition. It may have either [UFeature] or [LFeature]. * The triangular matrix in this decomposition. It may have either [UFeature] or [LFeature].
*/ */
public val l: FeaturedMatrix<T> public val l: Matrix<T>
} }
/** /**
@ -135,17 +137,17 @@ public interface SingularValueDecompositionFeature<T : Any> : MatrixFeature {
/** /**
* The matrix in this decomposition. It is unitary, and it consists from left singular vectors. * The matrix in this decomposition. It is unitary, and it consists from left singular vectors.
*/ */
public val u: FeaturedMatrix<T> public val u: Matrix<T>
/** /**
* The matrix in this decomposition. Its main diagonal elements are singular values. * The matrix in this decomposition. Its main diagonal elements are singular values.
*/ */
public val s: FeaturedMatrix<T> public val s: Matrix<T>
/** /**
* The matrix in this decomposition. It is unitary, and it consists from right singular vectors. * The matrix in this decomposition. It is unitary, and it consists from right singular vectors.
*/ */
public val v: FeaturedMatrix<T> public val v: Matrix<T>
/** /**
* The buffer of singular values of this SVD. * The buffer of singular values of this SVD.

View File

@ -1,35 +1,57 @@
package kscience.kmath.linear package kscience.kmath.linear
import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.operations.Ring import kscience.kmath.operations.Ring
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import kscience.kmath.structures.Structure2D import kscience.kmath.structures.Structure2D
import kscience.kmath.structures.asBuffer import kscience.kmath.structures.asBuffer
import kscience.kmath.structures.getFeature
import kotlin.math.sqrt import kotlin.math.sqrt
import kotlin.reflect.KClass
import kotlin.reflect.safeCast
/** /**
* A [Matrix] that holds [MatrixFeature] objects. * A [Matrix] that holds [MatrixFeature] objects.
* *
* @param T the type of items. * @param T the type of items.
*/ */
public interface FeaturedMatrix<T : Any> : Matrix<T> { public class MatrixWrapper<T : Any>(
public override val shape: IntArray get() = intArrayOf(rowNum, colNum) public val matrix: Matrix<T>,
public val features: Set<MatrixFeature>,
) : Matrix<T> by matrix {
/** /**
* The set of features this matrix possesses. * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
*/ */
public val features: Set<MatrixFeature> @UnstableKMathAPI
override fun <T : Any> getFeature(type: KClass<T>): T? = type.safeCast(features.find { type.isInstance(it) })
/** override fun equals(other: Any?): Boolean = matrix == other
* Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure. override fun hashCode(): Int = matrix.hashCode()
* override fun toString(): String {
* The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to return "MatrixWrapper(matrix=$matrix, features=$features)"
* add only those features that are valid. }
*/
public fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
public companion object
} }
/**
* Add a single feature to a [Matrix]
*/
public operator fun <T : Any> Matrix<T>.plus(newFeature: MatrixFeature): MatrixWrapper<T> = if (this is MatrixWrapper) {
MatrixWrapper(matrix, features + newFeature)
} else {
MatrixWrapper(this, setOf(newFeature))
}
/**
* Add a collection of features to a [Matrix]
*/
public operator fun <T : Any> Matrix<T>.plus(newFeatures: Collection<MatrixFeature>): MatrixWrapper<T> =
if (this is MatrixWrapper) {
MatrixWrapper(matrix, features + newFeatures)
} else {
MatrixWrapper(this, newFeatures.toSet())
}
public inline fun Structure2D.Companion.real( public inline fun Structure2D.Companion.real(
rows: Int, rows: Int,
columns: Int, columns: Int,
@ -39,51 +61,37 @@ public inline fun Structure2D.Companion.real(
/** /**
* Build a square matrix from given elements. * Build a square matrix from given elements.
*/ */
public fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> { public fun <T : Any> Structure2D.Companion.square(vararg elements: T): Matrix<T> {
val size: Int = sqrt(elements.size.toDouble()).toInt() val size: Int = sqrt(elements.size.toDouble()).toInt()
require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" } require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" }
val buffer = elements.asBuffer() val buffer = elements.asBuffer()
return BufferMatrix(size, size, buffer) return BufferMatrix(size, size, buffer)
} }
public val Matrix<*>.features: Set<MatrixFeature> get() = (this as? FeaturedMatrix)?.features ?: emptySet()
/**
* Check if matrix has the given feature class
*/
public inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean =
features.find { it is T } != null
/**
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
*/
public inline fun <reified T : Any> Matrix<*>.getFeature(): T? =
features.filterIsInstance<T>().firstOrNull()
/** /**
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created * Diagonal matrix of ones. The matrix is virtual no actual matrix is created
*/ */
public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.one(rows: Int, columns: Int): FeaturedMatrix<T> = public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.one(rows: Int, columns: Int): Matrix<T> =
VirtualMatrix(rows, columns, DiagonalFeature) { i, j -> VirtualMatrix(rows, columns) { i, j ->
if (i == j) elementContext.one else elementContext.zero if (i == j) elementContext.one else elementContext.zero
} } + UnitFeature
/** /**
* A virtual matrix of zeroes * A virtual matrix of zeroes
*/ */
public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.zero(rows: Int, columns: Int): FeaturedMatrix<T> = public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.zero(rows: Int, columns: Int): Matrix<T> =
VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } + ZeroFeature
public class TransposedFeature<T : Any>(public val original: Matrix<T>) : MatrixFeature public class TransposedFeature<T : Any>(public val original: Matrix<T>) : MatrixFeature
/** /**
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A` * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
*/ */
@OptIn(UnstableKMathAPI::class)
public fun <T : Any> Matrix<T>.transpose(): Matrix<T> { public fun <T : Any> Matrix<T>.transpose(): Matrix<T> {
return getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix( return getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
colNum, colNum,
rowNum, rowNum,
setOf(TransposedFeature(this)) ) { i, j -> get(j, i) } + TransposedFeature(this)
) { i, j -> get(j, i) }
} }

View File

@ -1,9 +1,6 @@
package kscience.kmath.linear package kscience.kmath.linear
import kscience.kmath.operations.RealField
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import kscience.kmath.structures.MutableBuffer
import kscience.kmath.structures.MutableBufferFactory
import kscience.kmath.structures.RealBuffer import kscience.kmath.structures.RealBuffer
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
@ -22,9 +19,9 @@ public object RealMatrixContext : MatrixContext<Double, BufferMatrix<Double>> {
produce(rowNum, colNum) { i, j -> get(i, j) } produce(rowNum, colNum) { i, j -> get(i, j) }
} }
public fun one(rows: Int, columns: Int): FeaturedMatrix<Double> = VirtualMatrix(rows, columns, DiagonalFeature) { i, j -> public fun one(rows: Int, columns: Int): Matrix<Double> = VirtualMatrix(rows, columns) { i, j ->
if (i == j) 1.0 else 0.0 if (i == j) 1.0 else 0.0
} } + DiagonalFeature
public override infix fun Matrix<Double>.dot(other: Matrix<Double>): BufferMatrix<Double> { public override infix fun Matrix<Double>.dot(other: Matrix<Double>): BufferMatrix<Double> {
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
@ -69,16 +66,3 @@ public object RealMatrixContext : MatrixContext<Double, BufferMatrix<Double>> {
* Partially optimized real-valued matrix * Partially optimized real-valued matrix
*/ */
public val MatrixContext.Companion.real: RealMatrixContext get() = RealMatrixContext public val MatrixContext.Companion.real: RealMatrixContext get() = RealMatrixContext
public fun RealMatrixContext.solveWithLUP(a: Matrix<Double>, b: Matrix<Double>): FeaturedMatrix<Double> {
// Use existing decomposition if it is provided by matrix
val bufferFactory: MutableBufferFactory<Double> = MutableBuffer.Companion::real
val decomposition = a.getFeature() ?: lup(bufferFactory, RealField, a) { it < 1e-11 }
return decomposition.solveWithLUP(bufferFactory, b)
}
/**
* Inverses a square matrix using LUP decomposition. Non square matrix will throw a error.
*/
public fun RealMatrixContext.inverseWithLUP(matrix: Matrix<Double>): FeaturedMatrix<Double> =
solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum))

View File

@ -5,31 +5,16 @@ import kscience.kmath.structures.Matrix
public class VirtualMatrix<T : Any>( public class VirtualMatrix<T : Any>(
override val rowNum: Int, override val rowNum: Int,
override val colNum: Int, override val colNum: Int,
override val features: Set<MatrixFeature> = emptySet(),
public val generator: (i: Int, j: Int) -> T public val generator: (i: Int, j: Int) -> T
) : FeaturedMatrix<T> { ) : Matrix<T> {
public constructor(
rowNum: Int,
colNum: Int,
vararg features: MatrixFeature,
generator: (i: Int, j: Int) -> T
) : this(
rowNum,
colNum,
setOf(*features),
generator
)
override val shape: IntArray get() = intArrayOf(rowNum, colNum) override val shape: IntArray get() = intArrayOf(rowNum, colNum)
override operator fun get(i: Int, j: Int): T = generator(i, j) override operator fun get(i: Int, j: Int): T = generator(i, j)
override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix<T> =
VirtualMatrix(rowNum, colNum, this.features + features, generator)
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
if (other !is FeaturedMatrix<*>) return false if (other !is Matrix<*>) return false
if (rowNum != other.rowNum) return false if (rowNum != other.rowNum) return false
if (colNum != other.colNum) return false if (colNum != other.colNum) return false
@ -40,21 +25,9 @@ public class VirtualMatrix<T : Any>(
override fun hashCode(): Int { override fun hashCode(): Int {
var result = rowNum var result = rowNum
result = 31 * result + colNum result = 31 * result + colNum
result = 31 * result + features.hashCode()
result = 31 * result + generator.hashCode() result = 31 * result + generator.hashCode()
return result return result
} }
public companion object {
/**
* Wrap a matrix adding additional features to it
*/
public fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): FeaturedMatrix<T> {
return if (matrix is VirtualMatrix)
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator)
else
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features) { i, j -> matrix[i, j] }
}
}
} }

View File

@ -179,13 +179,15 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
* A field for [Int] without boxing. Does not produce corresponding ring element. * A field for [Int] without boxing. Does not produce corresponding ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object IntRing : Ring<Int>, Norm<Int, Int> { public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> {
public override val zero: Int public override val zero: Int
get() = 0 get() = 0
public override val one: Int public override val one: Int
get() = 1 get() = 1
override fun number(value: Number): Int = value.toInt()
public override inline fun add(a: Int, b: Int): Int = a + b public override inline fun add(a: Int, b: Int): Int = a + b
public override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a public override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a
@ -203,13 +205,15 @@ public object IntRing : Ring<Int>, Norm<Int, Int> {
* A field for [Short] without boxing. Does not produce appropriate ring element. * A field for [Short] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object ShortRing : Ring<Short>, Norm<Short, Short> { public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short> {
public override val zero: Short public override val zero: Short
get() = 0 get() = 0
public override val one: Short public override val one: Short
get() = 1 get() = 1
override fun number(value: Number): Short = value.toShort()
public override inline fun add(a: Short, b: Short): Short = (a + b).toShort() public override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
public override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort() public override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort()
@ -227,13 +231,15 @@ public object ShortRing : Ring<Short>, Norm<Short, Short> {
* A field for [Byte] without boxing. Does not produce appropriate ring element. * A field for [Byte] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object ByteRing : Ring<Byte>, Norm<Byte, Byte> { public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> {
public override val zero: Byte public override val zero: Byte
get() = 0 get() = 0
public override val one: Byte public override val one: Byte
get() = 1 get() = 1
override fun number(value: Number): Byte = value.toByte()
public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
public override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte() public override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte()
@ -251,13 +257,15 @@ public object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
* A field for [Double] without boxing. Does not produce appropriate ring element. * A field for [Double] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object LongRing : Ring<Long>, Norm<Long, Long> { public object LongRing : Ring<Long>, Norm<Long, Long>, NumericAlgebra<Long> {
public override val zero: Long public override val zero: Long
get() = 0L get() = 0L
public override val one: Long public override val one: Long
get() = 1L get() = 1L
override fun number(value: Number): Long = value.toLong()
public override inline fun add(a: Long, b: Long): Long = a + b public override inline fun add(a: Long, b: Long): Long = a + b
public override inline fun multiply(a: Long, k: Number): Long = a * k.toLong() public override inline fun multiply(a: Long, k: Number): Long = a * k.toLong()

View File

@ -1,5 +1,6 @@
package kscience.kmath.structures package kscience.kmath.structures
import kscience.kmath.misc.UnstableKMathAPI
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
import kotlin.native.concurrent.ThreadLocal import kotlin.native.concurrent.ThreadLocal
import kotlin.reflect.KClass import kotlin.reflect.KClass
@ -42,6 +43,13 @@ public interface NDStructure<T> {
public override fun equals(other: Any?): Boolean public override fun equals(other: Any?): Boolean
public override fun hashCode(): Int public override fun hashCode(): Int
/**
* Feature is additional property or hint that does not directly affect the structure, but could in some cases help
* optimize operations and performance. If the feature is not present, null is defined.
*/
@UnstableKMathAPI
public fun <T : Any> getFeature(type: KClass<T>): T? = null
public companion object { public companion object {
/** /**
* Indicates whether some [NDStructure] is equal to another one. * Indicates whether some [NDStructure] is equal to another one.
@ -121,6 +129,9 @@ public interface NDStructure<T> {
*/ */
public operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index) public operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
@UnstableKMathAPI
public inline fun <reified T : Any> NDStructure<*>.getFeature(): T? = getFeature(T::class)
/** /**
* Represents mutable [NDStructure]. * Represents mutable [NDStructure].
*/ */

View File

@ -9,12 +9,14 @@ public interface Structure2D<T> : NDStructure<T> {
/** /**
* The number of rows in this structure. * The number of rows in this structure.
*/ */
public val rowNum: Int get() = shape[0] public val rowNum: Int
/** /**
* The number of columns in this structure. * The number of columns in this structure.
*/ */
public val colNum: Int get() = shape[1] public val colNum: Int
public override val shape: IntArray get() = intArrayOf(rowNum, colNum)
/** /**
* The buffer of rows of this structure. It gets elements from the structure dynamically. * The buffer of rows of this structure. It gets elements from the structure dynamically.
@ -56,6 +58,9 @@ public interface Structure2D<T> : NDStructure<T> {
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> { private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override val rowNum: Int get() = shape[0]
override val colNum: Int get() = shape[1]
override operator fun get(i: Int, j: Int): T = structure[i, j] override operator fun get(i: Int, j: Int): T = structure[i, j]
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements() override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()

View File

@ -40,6 +40,8 @@ public inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
private val structure: Structure2D<T>, private val structure: Structure2D<T>,
) : DMatrix<T, R, C> { ) : DMatrix<T, R, C> {
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override val rowNum: Int get() = shape[0]
override val colNum: Int get() = shape[1]
override operator fun get(i: Int, j: Int): T = structure[i, j] override operator fun get(i: Int, j: Int): T = structure[i, j]
} }
@ -147,6 +149,7 @@ public inline fun <reified D : Dimension> DMatrixContext<Double>.one(): DMatrix<
if (i == j) 1.0 else 0.0 if (i == j) 1.0 else 0.0
} }
public inline fun <reified R : Dimension, reified C : Dimension> DMatrixContext<Double>.zero(): DMatrix<Double, R, C> = produce { _, _ -> public inline fun <reified R : Dimension, reified C : Dimension> DMatrixContext<Double>.zero(): DMatrix<Double, R, C> =
0.0 produce { _, _ ->
} 0.0
}

View File

@ -1,10 +1,13 @@
package kscience.kmath.ejml package kscience.kmath.ejml
import kscience.kmath.linear.* import kscience.kmath.linear.*
import kscience.kmath.structures.NDStructure import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.Matrix
import kscience.kmath.structures.RealBuffer import kscience.kmath.structures.RealBuffer
import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.dense.row.factory.DecompositionFactory_DDRM
import org.ejml.simple.SimpleMatrix import org.ejml.simple.SimpleMatrix
import kotlin.reflect.KClass
import kotlin.reflect.cast
/** /**
* Represents featured matrix over EJML [SimpleMatrix]. * Represents featured matrix over EJML [SimpleMatrix].
@ -12,85 +15,65 @@ import org.ejml.simple.SimpleMatrix
* @property origin the underlying [SimpleMatrix]. * @property origin the underlying [SimpleMatrix].
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
public class EjmlMatrix( public inline class EjmlMatrix(
public val origin: SimpleMatrix, public val origin: SimpleMatrix,
features: Set<MatrixFeature> = emptySet() ) : Matrix<Double> {
) : FeaturedMatrix<Double> { public override val rowNum: Int get() = origin.numRows()
public override val rowNum: Int
get() = origin.numRows()
public override val colNum: Int public override val colNum: Int get() = origin.numCols()
get() = origin.numCols()
public override val shape: IntArray by lazy { intArrayOf(rowNum, colNum) } @UnstableKMathAPI
override fun <T : Any> getFeature(type: KClass<T>): T? = when (type) {
public override val features: Set<MatrixFeature> = hashSetOf( InverseMatrixFeature::class -> object : InverseMatrixFeature<Double> {
object : InverseMatrixFeature<Double> { override val inverse: Matrix<Double> by lazy { EjmlMatrix(origin.invert()) }
override val inverse: FeaturedMatrix<Double> by lazy { EjmlMatrix(origin.invert()) } }
}, DeterminantFeature::class -> object : DeterminantFeature<Double> {
object : DeterminantFeature<Double> {
override val determinant: Double by lazy(origin::determinant) override val determinant: Double by lazy(origin::determinant)
}, }
SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature<Double> {
object : SingularValueDecompositionFeature<Double> {
private val svd by lazy { private val svd by lazy {
DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false) DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false)
.apply { decompose(origin.ddrm.copy()) } .apply { decompose(origin.ddrm.copy()) }
} }
override val u: FeaturedMatrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) } override val u: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) }
override val s: FeaturedMatrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) } override val s: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) }
override val v: FeaturedMatrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) } override val v: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) }
override val singularValues: Point<Double> by lazy { RealBuffer(svd.singularValues) } override val singularValues: Point<Double> by lazy { RealBuffer(svd.singularValues) }
}, }
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
object : QRDecompositionFeature<Double> {
private val qr by lazy { private val qr by lazy {
DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) } DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) }
} }
override val q: FeaturedMatrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) } override val q: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) }
override val r: FeaturedMatrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) } override val r: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) }
}, }
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
object : CholeskyDecompositionFeature<Double> { override val l: Matrix<Double> by lazy {
override val l: FeaturedMatrix<Double> by lazy {
val cholesky = val cholesky =
DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) } DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) }
EjmlMatrix(SimpleMatrix(cholesky.getT(null)), setOf(LFeature)) EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature
} }
}, }
LupDecompositionFeature::class -> object : LupDecompositionFeature<Double> {
object : LupDecompositionFeature<Double> {
private val lup by lazy { private val lup by lazy {
DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) } DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) }
} }
override val l: FeaturedMatrix<Double> by lazy { override val l: Matrix<Double> by lazy {
EjmlMatrix(SimpleMatrix(lup.getLower(null)), setOf(LFeature)) EjmlMatrix(SimpleMatrix(lup.getLower(null))) + LFeature
} }
override val u: FeaturedMatrix<Double> by lazy { override val u: Matrix<Double> by lazy {
EjmlMatrix(SimpleMatrix(lup.getUpper(null)), setOf(UFeature)) EjmlMatrix(SimpleMatrix(lup.getUpper(null))) + UFeature
} }
override val p: FeaturedMatrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) } override val p: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) }
}, }
) union features else -> null
}?.let{type.cast(it)}
public override fun suggestFeature(vararg features: MatrixFeature): EjmlMatrix =
EjmlMatrix(origin, this.features + features)
public override operator fun get(i: Int, j: Int): Double = origin[i, j] public override operator fun get(i: Int, j: Int): Double = origin[i, j]
public override fun equals(other: Any?): Boolean {
if (other is EjmlMatrix) return origin.isIdentical(other.origin, 0.0)
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
}
public override fun hashCode(): Int = origin.hashCode()
public override fun toString(): String = "EjmlMatrix($origin)"
} }

View File

@ -2,23 +2,29 @@ package kscience.kmath.ejml
import kscience.kmath.linear.InverseMatrixFeature import kscience.kmath.linear.InverseMatrixFeature
import kscience.kmath.linear.MatrixContext import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.MatrixWrapper
import kscience.kmath.linear.Point import kscience.kmath.linear.Point
import kscience.kmath.linear.getFeature import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import kscience.kmath.structures.getFeature
import org.ejml.simple.SimpleMatrix import org.ejml.simple.SimpleMatrix
/**
* Converts this matrix to EJML one.
*/
public fun Matrix<Double>.toEjml(): EjmlMatrix =
if (this is EjmlMatrix) this else EjmlMatrixContext.produce(rowNum, colNum) { i, j -> get(i, j) }
/** /**
* Represents context of basic operations operating with [EjmlMatrix]. * Represents context of basic operations operating with [EjmlMatrix].
* *
* @author Iaroslav Postovalov * @author Iaroslav Postovalov
*/ */
public object EjmlMatrixContext : MatrixContext<Double, EjmlMatrix> { public object EjmlMatrixContext : MatrixContext<Double, EjmlMatrix> {
/**
* Converts this matrix to EJML one.
*/
public fun Matrix<Double>.toEjml(): EjmlMatrix = when {
this is EjmlMatrix -> this
this is MatrixWrapper && matrix is EjmlMatrix -> matrix as EjmlMatrix
else -> produce(rowNum, colNum) { i, j -> get(i, j) }
}
/** /**
* Converts this vector to EJML one. * Converts this vector to EJML one.
*/ */
@ -80,6 +86,7 @@ public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): EjmlMa
public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Point<Double>): EjmlVector = public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Point<Double>): EjmlVector =
EjmlVector(a.toEjml().origin.solve(b.toEjml().origin)) EjmlVector(a.toEjml().origin.solve(b.toEjml().origin))
@OptIn(UnstableKMathAPI::class)
public fun EjmlMatrix.inverted(): EjmlMatrix = getFeature<InverseMatrixFeature<Double>>()!!.inverse as EjmlMatrix public fun EjmlMatrix.inverted(): EjmlMatrix = getFeature<InverseMatrixFeature<Double>>()!!.inverse as EjmlMatrix
public fun EjmlMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> = matrix.toEjml().inverted() public fun EjmlMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> = matrix.toEjml().inverted()

View File

@ -3,7 +3,8 @@ package kscience.kmath.ejml
import kscience.kmath.linear.DeterminantFeature import kscience.kmath.linear.DeterminantFeature
import kscience.kmath.linear.LupDecompositionFeature import kscience.kmath.linear.LupDecompositionFeature
import kscience.kmath.linear.MatrixFeature import kscience.kmath.linear.MatrixFeature
import kscience.kmath.linear.getFeature import kscience.kmath.linear.plus
import kscience.kmath.structures.getFeature
import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.dense.row.factory.DecompositionFactory_DDRM
import org.ejml.simple.SimpleMatrix import org.ejml.simple.SimpleMatrix
import kotlin.random.Random import kotlin.random.Random
@ -58,7 +59,7 @@ internal class EjmlMatrixTest {
@Test @Test
fun suggestFeature() { fun suggestFeature() {
assertNotNull(EjmlMatrix(randomMatrix).suggestFeature(SomeFeature).getFeature<SomeFeature>()) assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature<SomeFeature>())
} }
@Test @Test

View File

@ -1,8 +1,12 @@
package kscience.kmath.real package kscience.kmath.real
import kscience.kmath.linear.* import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.VirtualMatrix
import kscience.kmath.linear.inverseWithLUP
import kscience.kmath.linear.real
import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.Buffer import kscience.kmath.structures.Buffer
import kscience.kmath.structures.Matrix
import kscience.kmath.structures.RealBuffer import kscience.kmath.structures.RealBuffer
import kscience.kmath.structures.asIterable import kscience.kmath.structures.asIterable
import kotlin.math.pow import kotlin.math.pow
@ -19,7 +23,7 @@ import kotlin.math.pow
* Functions that help create a real (Double) matrix * Functions that help create a real (Double) matrix
*/ */
public typealias RealMatrix = FeaturedMatrix<Double> public typealias RealMatrix = Matrix<Double>
public fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = public fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix =
MatrixContext.real.produce(rowNum, colNum, initializer) MatrixContext.real.produce(rowNum, colNum, initializer)

View File

@ -1,6 +1,5 @@
package kaceince.kmath.real package kaceince.kmath.real
import kscience.kmath.linear.VirtualMatrix
import kscience.kmath.linear.build import kscience.kmath.linear.build
import kscience.kmath.real.* import kscience.kmath.real.*
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
@ -42,7 +41,7 @@ internal class RealMatrixTest {
1.0, 0.0, 0.0, 1.0, 0.0, 0.0,
0.0, 1.0, 2.0 0.0, 1.0, 2.0
) )
assertEquals(VirtualMatrix.wrap(matrix2), matrix1.repeatStackVertical(3)) assertEquals(matrix2, matrix1.repeatStackVertical(3))
} }
@Test @Test