Matrix refactoring
This commit is contained in:
parent
4065fda99a
commit
14f05eb1e1
@ -1,6 +1,7 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import koma.matrix.ejml.EJMLMatrixFactory
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import kotlin.random.Random
|
||||
import kotlin.system.measureTimeMillis
|
||||
|
||||
@ -30,7 +31,7 @@ fun main() {
|
||||
|
||||
//commons-math
|
||||
|
||||
val cmContext = CMMatrixContext
|
||||
val cmContext = CMLUPSolver
|
||||
|
||||
val commonsTime = measureTimeMillis {
|
||||
cmContext.run {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import koma.matrix.ejml.EJMLMatrixFactory
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import kotlin.random.Random
|
||||
import kotlin.system.measureTimeMillis
|
||||
|
||||
|
@ -193,6 +193,7 @@ subprojects {
|
||||
targets.all {
|
||||
sourceSets.all {
|
||||
languageSettings.progressiveMode = true
|
||||
languageSettings.enableLanguageFeature("InlineClasses")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,13 +3,14 @@ package scientifik.kmath.linear
|
||||
import org.apache.commons.math3.linear.*
|
||||
import org.apache.commons.math3.linear.RealMatrix
|
||||
import org.apache.commons.math3.linear.RealVector
|
||||
import scientifik.kmath.structures.Matrix
|
||||
|
||||
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) : Matrix<Double> {
|
||||
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
|
||||
override val rowNum: Int get() = origin.rowDimension
|
||||
override val colNum: Int get() = origin.columnDimension
|
||||
|
||||
override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
|
||||
if(origin is DiagonalMatrix) yield(DiagonalFeature)
|
||||
if (origin is DiagonalMatrix) yield(DiagonalFeature)
|
||||
}.toSet()
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature) =
|
||||
@ -45,28 +46,13 @@ fun Point<Double>.toCM(): CMVector = if (this is CMVector) {
|
||||
|
||||
fun RealVector.toPoint() = CMVector(this)
|
||||
|
||||
object CMMatrixContext : MatrixContext<Double>, LinearSolver<Double> {
|
||||
object CMMatrixContext : MatrixContext<Double> {
|
||||
|
||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
|
||||
val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } }
|
||||
return CMMatrix(Array2DRowRealMatrix(array))
|
||||
}
|
||||
|
||||
override fun solve(a: Matrix<Double>, b: Matrix<Double>): CMMatrix {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.solve(b.toCM().origin).toMatrix()
|
||||
}
|
||||
|
||||
override fun solve(a: Matrix<Double>, b: Point<Double>): CMVector {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.solve(b.toCM().origin).toPoint()
|
||||
}
|
||||
|
||||
override fun inverse(a: Matrix<Double>): CMMatrix {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.inverse.toMatrix()
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.dot(other: Matrix<Double>) =
|
||||
CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
|
||||
|
||||
@ -87,6 +73,23 @@ object CMMatrixContext : MatrixContext<Double>, LinearSolver<Double> {
|
||||
CMMatrix(this.toCM().origin.scalarMultiply(value.toDouble()))
|
||||
}
|
||||
|
||||
object CMLUPSolver: LinearSolver<Double>{
|
||||
override fun solve(a: Matrix<Double>, b: Matrix<Double>): CMMatrix {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.solve(b.toCM().origin).toMatrix()
|
||||
}
|
||||
|
||||
override fun solve(a: Matrix<Double>, b: Point<Double>): CMVector {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.solve(b.toCM().origin).toPoint()
|
||||
}
|
||||
|
||||
override fun inverse(a: Matrix<Double>): CMMatrix {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.inverse.toMatrix()
|
||||
}
|
||||
}
|
||||
|
||||
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin))
|
||||
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.subtract(other.origin))
|
||||
|
||||
|
@ -13,34 +13,6 @@ kotlin {
|
||||
val commonMain by getting {
|
||||
dependencies {
|
||||
api(project(":kmath-memory"))
|
||||
api(kotlin("stdlib"))
|
||||
}
|
||||
}
|
||||
val commonTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-common"))
|
||||
implementation(kotlin("test-annotations-common"))
|
||||
}
|
||||
}
|
||||
val jvmMain by getting {
|
||||
dependencies {
|
||||
api(kotlin("stdlib-jdk8"))
|
||||
}
|
||||
}
|
||||
val jvmTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test"))
|
||||
implementation(kotlin("test-junit"))
|
||||
}
|
||||
}
|
||||
val jsMain by getting {
|
||||
dependencies {
|
||||
api(kotlin("stdlib-js"))
|
||||
}
|
||||
}
|
||||
val jsTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-js"))
|
||||
}
|
||||
}
|
||||
// mingwMain {
|
||||
|
@ -2,7 +2,6 @@ package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.structures.*
|
||||
import kotlin.jvm.JvmSynthetic
|
||||
|
||||
/**
|
||||
* Basic implementation of Matrix space based on [NDStructure]
|
||||
@ -25,7 +24,7 @@ class BufferMatrix<T : Any>(
|
||||
override val colNum: Int,
|
||||
val buffer: Buffer<out T>,
|
||||
override val features: Set<MatrixFeature> = emptySet()
|
||||
) : Matrix<T> {
|
||||
) : FeaturedMatrix<T> {
|
||||
|
||||
init {
|
||||
if (buffer.size != rowNum * colNum) {
|
||||
|
@ -8,6 +8,9 @@ import scientifik.kmath.structures.Buffer.Companion.boxing
|
||||
import kotlin.math.sqrt
|
||||
|
||||
|
||||
/**
|
||||
* Basic operations on matrices. Operates on [Matrix]
|
||||
*/
|
||||
interface MatrixContext<T : Any> {
|
||||
/**
|
||||
* Produce a matrix with this context and given dimensions
|
||||
@ -101,18 +104,18 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
||||
operator fun Matrix<T>.times(number: Number): Matrix<T> =
|
||||
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } }
|
||||
|
||||
operator fun Number.times(matrix: Matrix<T>): Matrix<T> = matrix * this
|
||||
operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
|
||||
|
||||
override fun Matrix<T>.times(value: T): Matrix<T> =
|
||||
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } }
|
||||
}
|
||||
|
||||
/**
|
||||
* Specialized 2-d structure
|
||||
* A 2d structure plus optional matrix-specific features
|
||||
*/
|
||||
interface Matrix<T : Any> : Structure2D<T> {
|
||||
val rowNum: Int
|
||||
val colNum: Int
|
||||
interface FeaturedMatrix<T : Any> : Matrix<T> {
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
val features: Set<MatrixFeature>
|
||||
|
||||
@ -122,70 +125,54 @@ interface Matrix<T : Any> : Structure2D<T> {
|
||||
* The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to
|
||||
* add only those features that are valid.
|
||||
*/
|
||||
fun suggestFeature(vararg features: MatrixFeature): Matrix<T>
|
||||
|
||||
override fun get(index: IntArray): T = get(index[0], index[1])
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
val rows: Point<Point<T>>
|
||||
get() = VirtualBuffer(rowNum) { i ->
|
||||
VirtualBuffer(colNum) { j -> get(i, j) }
|
||||
}
|
||||
|
||||
val columns: Point<Point<T>>
|
||||
get() = VirtualBuffer(colNum) { j ->
|
||||
VirtualBuffer(rowNum) { i -> get(i, j) }
|
||||
}
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
||||
for (i in (0 until rowNum)) {
|
||||
for (j in (0 until colNum)) {
|
||||
yield(intArrayOf(i, j) to get(i, j))
|
||||
}
|
||||
}
|
||||
}
|
||||
fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
|
||||
|
||||
companion object {
|
||||
fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
||||
MatrixContext.real.produce(rows, columns, initializer)
|
||||
|
||||
/**
|
||||
* Build a square matrix from given elements.
|
||||
*/
|
||||
fun <T : Any> square(vararg elements: T): Matrix<T> {
|
||||
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
||||
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(size, size, buffer)
|
||||
}
|
||||
|
||||
fun <T : Any> build(rows: Int, columns: Int): MatrixBuilder<T> = MatrixBuilder(rows, columns)
|
||||
}
|
||||
}
|
||||
|
||||
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
||||
MatrixContext.real.produce(rows, columns, initializer)
|
||||
|
||||
/**
|
||||
* Build a square matrix from given elements.
|
||||
*/
|
||||
fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> {
|
||||
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
||||
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(size, size, buffer)
|
||||
}
|
||||
|
||||
fun <T : Any> Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder<T> = MatrixBuilder(rows, columns)
|
||||
|
||||
class MatrixBuilder<T : Any>(val rows: Int, val columns: Int) {
|
||||
operator fun invoke(vararg elements: T): Matrix<T> {
|
||||
operator fun invoke(vararg elements: T): FeaturedMatrix<T> {
|
||||
if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns")
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(rows, columns, buffer)
|
||||
}
|
||||
}
|
||||
|
||||
val Matrix<*>.features get() = (this as? FeaturedMatrix)?.features?: emptySet()
|
||||
|
||||
/**
|
||||
* Check if matrix has the given feature class
|
||||
*/
|
||||
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean = features.find { it is T } != null
|
||||
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean =
|
||||
features.find { it is T } != null
|
||||
|
||||
/**
|
||||
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
|
||||
*/
|
||||
inline fun <reified T : Any> Matrix<*>.getFeature(): T? = features.filterIsInstance<T>().firstOrNull()
|
||||
inline fun <reified T : Any> Matrix<*>.getFeature(): T? =
|
||||
features.filterIsInstance<T>().firstOrNull()
|
||||
|
||||
/**
|
||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> =
|
||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||
VirtualMatrix<T>(rows, columns) { i, j ->
|
||||
if (i == j) elementContext.one else elementContext.zero
|
||||
}
|
||||
@ -194,7 +181,7 @@ fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: In
|
||||
/**
|
||||
* A virtual matrix of zeroes
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> =
|
||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||
VirtualMatrix<T>(rows, columns) { _, _ -> elementContext.zero }
|
||||
|
||||
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature
|
@ -2,12 +2,12 @@ package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.structures.MutableBuffer
|
||||
import scientifik.kmath.structures.*
|
||||
import scientifik.kmath.structures.MutableBuffer.Companion.boxing
|
||||
import scientifik.kmath.structures.MutableBufferFactory
|
||||
import scientifik.kmath.structures.NDStructure
|
||||
import scientifik.kmath.structures.get
|
||||
|
||||
/**
|
||||
* Common implementation of [LUPDecompositionFeature]
|
||||
*/
|
||||
class LUPDecomposition<T : Comparable<T>>(
|
||||
private val elementContext: Ring<T>,
|
||||
internal val lu: NDStructure<T>,
|
||||
@ -20,7 +20,7 @@ class LUPDecomposition<T : Comparable<T>>(
|
||||
*
|
||||
* L is a lower-triangular matrix with [Ring.one] in diagonal
|
||||
*/
|
||||
override val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j ->
|
||||
override val l: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j ->
|
||||
when {
|
||||
j < i -> lu[i, j]
|
||||
j == i -> elementContext.one
|
||||
@ -34,7 +34,7 @@ class LUPDecomposition<T : Comparable<T>>(
|
||||
*
|
||||
* U is an upper-triangular matrix including the diagonal
|
||||
*/
|
||||
override val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j ->
|
||||
override val u: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j ->
|
||||
if (j >= i) lu[i, j] else elementContext.zero
|
||||
}
|
||||
|
||||
@ -45,7 +45,7 @@ class LUPDecomposition<T : Comparable<T>>(
|
||||
* P is a sparse matrix with exactly one element set to [Ring.one] in
|
||||
* each row and each column, all other elements being set to [Ring.zero].
|
||||
*/
|
||||
override val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
override val p: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
if (j == pivot[i]) elementContext.one else elementContext.zero
|
||||
}
|
||||
|
||||
@ -62,7 +62,9 @@ class LUPDecomposition<T : Comparable<T>>(
|
||||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Common implementation of LUP [LinearSolver] based on commons-math code
|
||||
*/
|
||||
class LUSolver<T : Comparable<T>, F : Field<T>>(
|
||||
val context: GenericMatrixContext<T, F>,
|
||||
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
||||
|
@ -3,6 +3,7 @@ package scientifik.kmath.linear
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Norm
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.structures.Matrix
|
||||
import scientifik.kmath.structures.VirtualBuffer
|
||||
import scientifik.kmath.structures.asSequence
|
||||
|
||||
@ -10,9 +11,9 @@ import scientifik.kmath.structures.asSequence
|
||||
/**
|
||||
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors
|
||||
*/
|
||||
interface LinearSolver<T : Any> {
|
||||
interface LinearSolver<T : Any> {
|
||||
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
|
||||
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).toVector()
|
||||
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).asPoint()
|
||||
fun inverse(a: Matrix<T>): Matrix<T>
|
||||
}
|
||||
|
||||
@ -32,14 +33,14 @@ object VectorL2Norm : Norm<Point<out Number>, Double> {
|
||||
typealias RealVector = Vector<Double, RealField>
|
||||
typealias RealMatrix = Matrix<Double>
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Convert matrix to vector if it is possible
|
||||
*/
|
||||
fun <T: Any> Matrix<T>.toVector(): Point<T> =
|
||||
fun <T : Any> Matrix<T>.asPoint(): Point<T> =
|
||||
if (this.colNum == 1) {
|
||||
VirtualBuffer(rowNum){ get(it, 0) }
|
||||
} else error("Can't convert matrix with more than one column to vector")
|
||||
VirtualBuffer(rowNum) { get(it, 0) }
|
||||
} else {
|
||||
error("Can't convert matrix with more than one column to vector")
|
||||
}
|
||||
|
||||
fun <T: Any> Point<T>.toMatrix(): Matrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
||||
fun <T : Any> Point<T>.toMatrix() = VirtualMatrix(size, 1) { i, _ -> get(i) }
|
@ -25,7 +25,7 @@ object UnitFeature : MatrixFeature
|
||||
* Inverted matrix feature
|
||||
*/
|
||||
interface InverseMatrixFeature<T : Any> : MatrixFeature {
|
||||
val inverse: Matrix<T>
|
||||
val inverse: FeaturedMatrix<T>
|
||||
}
|
||||
|
||||
/**
|
||||
@ -54,9 +54,9 @@ object UFeature: MatrixFeature
|
||||
* TODO add documentation
|
||||
*/
|
||||
interface LUPDecompositionFeature<T : Any> : MatrixFeature {
|
||||
val l: Matrix<T>
|
||||
val u: Matrix<T>
|
||||
val p: Matrix<T>
|
||||
val l: FeaturedMatrix<T>
|
||||
val u: FeaturedMatrix<T>
|
||||
val p: FeaturedMatrix<T>
|
||||
}
|
||||
|
||||
//TODO add sparse matrix feature
|
@ -1,11 +1,16 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.structures.Matrix
|
||||
|
||||
class VirtualMatrix<T : Any>(
|
||||
override val rowNum: Int,
|
||||
override val colNum: Int,
|
||||
override val features: Set<MatrixFeature> = emptySet(),
|
||||
val generator: (i: Int, j: Int) -> T
|
||||
) : Matrix<T> {
|
||||
) : FeaturedMatrix<T> {
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
override fun get(i: Int, j: Int): T = generator(i, j)
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature) =
|
||||
@ -13,7 +18,7 @@ class VirtualMatrix<T : Any>(
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is Matrix<*>) return false
|
||||
if (other !is FeaturedMatrix<*>) return false
|
||||
|
||||
if (rowNum != other.rowNum) return false
|
||||
if (colNum != other.colNum) return false
|
||||
@ -34,7 +39,7 @@ class VirtualMatrix<T : Any>(
|
||||
/**
|
||||
* Wrap a matrix adding additional features to it
|
||||
*/
|
||||
fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): Matrix<T> {
|
||||
fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): FeaturedMatrix<T> {
|
||||
return if (matrix is VirtualMatrix) {
|
||||
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator)
|
||||
} else {
|
||||
|
@ -5,8 +5,7 @@ interface NDStructure<T> {
|
||||
|
||||
val shape: IntArray
|
||||
|
||||
val dimension
|
||||
get() = shape.size
|
||||
val dimension get() = shape.size
|
||||
|
||||
operator fun get(index: IntArray): T
|
||||
|
||||
|
@ -27,20 +27,6 @@ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Stru
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||
*/
|
||||
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
||||
Structure1DWrapper(this)
|
||||
} else {
|
||||
error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
fun <T> NDBuffer<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
||||
Buffer1DWrapper(this.buffer)
|
||||
} else {
|
||||
error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
/**
|
||||
* A structure wrapper for buffer
|
||||
@ -56,39 +42,21 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
|
||||
override fun get(index: Int): T = buffer.get(index)
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||
*/
|
||||
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
|
||||
if( this is NDBuffer){
|
||||
Buffer1DWrapper(this.buffer)
|
||||
} else {
|
||||
Structure1DWrapper(this)
|
||||
}
|
||||
} else {
|
||||
error("Can't create 1d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Represent this buffer as 1D structure
|
||||
*/
|
||||
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
|
||||
|
||||
/**
|
||||
* A structure that is guaranteed to be two-dimensional
|
||||
*/
|
||||
interface Structure2D<T> : NDStructure<T> {
|
||||
operator fun get(i: Int, j: Int): T
|
||||
|
||||
override fun get(index: IntArray): T {
|
||||
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}")
|
||||
return get(index[0], index[1])
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A 2D wrapper for nd-structure
|
||||
*/
|
||||
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
|
||||
override fun get(i: Int, j: Int): T = structure[i, j]
|
||||
|
||||
override val shape: IntArray get() = structure.shape
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||
*/
|
||||
fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
|
||||
Structure2DWrapper(this)
|
||||
} else {
|
||||
error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
}
|
@ -0,0 +1,79 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
/**
|
||||
* A structure that is guaranteed to be two-dimensional
|
||||
*/
|
||||
interface Structure2D<T> : NDStructure<T> {
|
||||
val rowNum: Int get() = shape[0]
|
||||
val colNum: Int get() = shape[1]
|
||||
|
||||
operator fun get(i: Int, j: Int): T
|
||||
|
||||
override fun get(index: IntArray): T {
|
||||
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}")
|
||||
return get(index[0], index[1])
|
||||
}
|
||||
|
||||
|
||||
val rows: Buffer<Buffer<T>>
|
||||
get() = VirtualBuffer(rowNum) { i ->
|
||||
VirtualBuffer(colNum) { j -> get(i, j) }
|
||||
}
|
||||
|
||||
val columns: Buffer<Buffer<T>>
|
||||
get() = VirtualBuffer(colNum) { j ->
|
||||
VirtualBuffer(rowNum) { i -> get(i, j) }
|
||||
}
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
||||
for (i in (0 until rowNum)) {
|
||||
for (j in (0 until colNum)) {
|
||||
yield(intArrayOf(i, j) to get(i, j))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A 2D wrapper for nd-structure
|
||||
*/
|
||||
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
|
||||
override fun get(i: Int, j: Int): T = structure[i, j]
|
||||
|
||||
override val shape: IntArray get() = structure.shape
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
|
||||
*/
|
||||
fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
|
||||
Structure2DWrapper(this)
|
||||
} else {
|
||||
error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent this 2D structure as 1D if it has exactly one column. Throw error otherwise.
|
||||
*/
|
||||
fun <T> Structure2D<T>.as1D() = if (colNum == 1) {
|
||||
object : Structure1D<T> {
|
||||
override fun get(index: Int): T = get(index, 0)
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum)
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = elements()
|
||||
|
||||
override val size: Int get() = rowNum
|
||||
}
|
||||
} else {
|
||||
error("Can't convert matrix with more than one column to vector")
|
||||
}
|
||||
|
||||
|
||||
typealias Matrix<T> = Structure2D<T>
|
@ -44,7 +44,7 @@ class MatrixTest {
|
||||
|
||||
@Test
|
||||
fun testBuilder() {
|
||||
val matrix = Matrix.build<Double>(2, 3)(
|
||||
val matrix = FeaturedMatrix.build<Double>(2, 3)(
|
||||
1.0, 0.0, 0.0,
|
||||
0.0, 1.0, 2.0
|
||||
)
|
||||
|
@ -13,7 +13,7 @@ class RealLUSolverTest {
|
||||
|
||||
@Test
|
||||
fun testInvert() {
|
||||
val matrix = Matrix.square(
|
||||
val matrix = FeaturedMatrix.square(
|
||||
3.0, 1.0,
|
||||
1.0, 3.0
|
||||
)
|
||||
@ -31,7 +31,7 @@ class RealLUSolverTest {
|
||||
|
||||
val inverted = LUSolver.real.inverse(decomposed)
|
||||
|
||||
val expected = Matrix.square(
|
||||
val expected = FeaturedMatrix.square(
|
||||
0.375, -0.125,
|
||||
-0.125, 0.375
|
||||
)
|
||||
|
@ -2,6 +2,7 @@ package scientifik.kmath.linear
|
||||
|
||||
import koma.extensions.fill
|
||||
import koma.matrix.MatrixFactory
|
||||
import scientifik.kmath.structures.Matrix
|
||||
|
||||
class KomaMatrixContext<T : Any>(val factory: MatrixFactory<koma.matrix.Matrix<T>>) : MatrixContext<T>,
|
||||
LinearSolver<T> {
|
||||
@ -48,24 +49,25 @@ class KomaMatrixContext<T : Any>(val factory: MatrixFactory<koma.matrix.Matrix<T
|
||||
KomaMatrix(a.toKoma().origin.inv())
|
||||
}
|
||||
|
||||
class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<MatrixFeature>? = null) :
|
||||
Matrix<T> {
|
||||
class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<MatrixFeature>? = null) : FeaturedMatrix<T> {
|
||||
override val rowNum: Int get() = origin.numRows()
|
||||
override val colNum: Int get() = origin.numCols()
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(origin.numRows(),origin.numCols())
|
||||
|
||||
override val features: Set<MatrixFeature> = features ?: setOf(
|
||||
object : DeterminantFeature<T> {
|
||||
override val determinant: T get() = origin.det()
|
||||
},
|
||||
object : LUPDecompositionFeature<T> {
|
||||
private val lup by lazy { origin.LU() }
|
||||
override val l: Matrix<T> get() = KomaMatrix(lup.second)
|
||||
override val u: Matrix<T> get() = KomaMatrix(lup.third)
|
||||
override val p: Matrix<T> get() = KomaMatrix(lup.first)
|
||||
override val l: FeaturedMatrix<T> get() = KomaMatrix(lup.second)
|
||||
override val u: FeaturedMatrix<T> get() = KomaMatrix(lup.third)
|
||||
override val p: FeaturedMatrix<T> get() = KomaMatrix(lup.first)
|
||||
}
|
||||
)
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature): Matrix<T> =
|
||||
override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> =
|
||||
KomaMatrix(this.origin, this.features + features)
|
||||
|
||||
override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
|
||||
|
@ -8,43 +8,9 @@ val ioVersion: String by rootProject.extra
|
||||
kotlin {
|
||||
jvm()
|
||||
js()
|
||||
|
||||
sourceSets {
|
||||
val commonMain by getting {
|
||||
dependencies {
|
||||
api(kotlin("stdlib"))
|
||||
}
|
||||
}
|
||||
val commonTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-common"))
|
||||
implementation(kotlin("test-annotations-common"))
|
||||
}
|
||||
}
|
||||
val jvmMain by getting {
|
||||
dependencies {
|
||||
api(kotlin("stdlib-jdk8"))
|
||||
}
|
||||
}
|
||||
val jvmTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test"))
|
||||
implementation(kotlin("test-junit"))
|
||||
}
|
||||
}
|
||||
val jsMain by getting {
|
||||
dependencies {
|
||||
api(kotlin("stdlib-js"))
|
||||
}
|
||||
}
|
||||
val jsTest by getting {
|
||||
dependencies {
|
||||
implementation(kotlin("test-js"))
|
||||
}
|
||||
}
|
||||
// mingwMain {
|
||||
// }
|
||||
// mingwTest {
|
||||
// }
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user