Matrix refactoring

This commit is contained in:
Alexander Nozik 2019-04-07 12:27:49 +03:00
parent 4065fda99a
commit 14f05eb1e1
18 changed files with 196 additions and 210 deletions

View File

@ -1,6 +1,7 @@
package scientifik.kmath.linear package scientifik.kmath.linear
import koma.matrix.ejml.EJMLMatrixFactory import koma.matrix.ejml.EJMLMatrixFactory
import scientifik.kmath.structures.Matrix
import kotlin.random.Random import kotlin.random.Random
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
@ -30,7 +31,7 @@ fun main() {
//commons-math //commons-math
val cmContext = CMMatrixContext val cmContext = CMLUPSolver
val commonsTime = measureTimeMillis { val commonsTime = measureTimeMillis {
cmContext.run { cmContext.run {

View File

@ -1,6 +1,7 @@
package scientifik.kmath.linear package scientifik.kmath.linear
import koma.matrix.ejml.EJMLMatrixFactory import koma.matrix.ejml.EJMLMatrixFactory
import scientifik.kmath.structures.Matrix
import kotlin.random.Random import kotlin.random.Random
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis

View File

@ -193,6 +193,7 @@ subprojects {
targets.all { targets.all {
sourceSets.all { sourceSets.all {
languageSettings.progressiveMode = true languageSettings.progressiveMode = true
languageSettings.enableLanguageFeature("InlineClasses")
} }
} }

View File

@ -3,13 +3,14 @@ package scientifik.kmath.linear
import org.apache.commons.math3.linear.* import org.apache.commons.math3.linear.*
import org.apache.commons.math3.linear.RealMatrix import org.apache.commons.math3.linear.RealMatrix
import org.apache.commons.math3.linear.RealVector import org.apache.commons.math3.linear.RealVector
import scientifik.kmath.structures.Matrix
class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) : Matrix<Double> { class CMMatrix(val origin: RealMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
override val rowNum: Int get() = origin.rowDimension override val rowNum: Int get() = origin.rowDimension
override val colNum: Int get() = origin.columnDimension override val colNum: Int get() = origin.columnDimension
override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> { override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
if(origin is DiagonalMatrix) yield(DiagonalFeature) if (origin is DiagonalMatrix) yield(DiagonalFeature)
}.toSet() }.toSet()
override fun suggestFeature(vararg features: MatrixFeature) = override fun suggestFeature(vararg features: MatrixFeature) =
@ -45,28 +46,13 @@ fun Point<Double>.toCM(): CMVector = if (this is CMVector) {
fun RealVector.toPoint() = CMVector(this) fun RealVector.toPoint() = CMVector(this)
object CMMatrixContext : MatrixContext<Double>, LinearSolver<Double> { object CMMatrixContext : MatrixContext<Double> {
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix { override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } } val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } }
return CMMatrix(Array2DRowRealMatrix(array)) return CMMatrix(Array2DRowRealMatrix(array))
} }
override fun solve(a: Matrix<Double>, b: Matrix<Double>): CMMatrix {
val decomposition = LUDecomposition(a.toCM().origin)
return decomposition.solver.solve(b.toCM().origin).toMatrix()
}
override fun solve(a: Matrix<Double>, b: Point<Double>): CMVector {
val decomposition = LUDecomposition(a.toCM().origin)
return decomposition.solver.solve(b.toCM().origin).toPoint()
}
override fun inverse(a: Matrix<Double>): CMMatrix {
val decomposition = LUDecomposition(a.toCM().origin)
return decomposition.solver.inverse.toMatrix()
}
override fun Matrix<Double>.dot(other: Matrix<Double>) = override fun Matrix<Double>.dot(other: Matrix<Double>) =
CMMatrix(this.toCM().origin.multiply(other.toCM().origin)) CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
@ -87,6 +73,23 @@ object CMMatrixContext : MatrixContext<Double>, LinearSolver<Double> {
CMMatrix(this.toCM().origin.scalarMultiply(value.toDouble())) CMMatrix(this.toCM().origin.scalarMultiply(value.toDouble()))
} }
object CMLUPSolver: LinearSolver<Double>{
override fun solve(a: Matrix<Double>, b: Matrix<Double>): CMMatrix {
val decomposition = LUDecomposition(a.toCM().origin)
return decomposition.solver.solve(b.toCM().origin).toMatrix()
}
override fun solve(a: Matrix<Double>, b: Point<Double>): CMVector {
val decomposition = LUDecomposition(a.toCM().origin)
return decomposition.solver.solve(b.toCM().origin).toPoint()
}
override fun inverse(a: Matrix<Double>): CMMatrix {
val decomposition = LUDecomposition(a.toCM().origin)
return decomposition.solver.inverse.toMatrix()
}
}
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin)) operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin))
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.subtract(other.origin)) operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.subtract(other.origin))

View File

@ -13,34 +13,6 @@ kotlin {
val commonMain by getting { val commonMain by getting {
dependencies { dependencies {
api(project(":kmath-memory")) api(project(":kmath-memory"))
api(kotlin("stdlib"))
}
}
val commonTest by getting {
dependencies {
implementation(kotlin("test-common"))
implementation(kotlin("test-annotations-common"))
}
}
val jvmMain by getting {
dependencies {
api(kotlin("stdlib-jdk8"))
}
}
val jvmTest by getting {
dependencies {
implementation(kotlin("test"))
implementation(kotlin("test-junit"))
}
}
val jsMain by getting {
dependencies {
api(kotlin("stdlib-js"))
}
}
val jsTest by getting {
dependencies {
implementation(kotlin("test-js"))
} }
} }
// mingwMain { // mingwMain {

View File

@ -2,7 +2,6 @@ package scientifik.kmath.linear
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.* import scientifik.kmath.structures.*
import kotlin.jvm.JvmSynthetic
/** /**
* Basic implementation of Matrix space based on [NDStructure] * Basic implementation of Matrix space based on [NDStructure]
@ -25,7 +24,7 @@ class BufferMatrix<T : Any>(
override val colNum: Int, override val colNum: Int,
val buffer: Buffer<out T>, val buffer: Buffer<out T>,
override val features: Set<MatrixFeature> = emptySet() override val features: Set<MatrixFeature> = emptySet()
) : Matrix<T> { ) : FeaturedMatrix<T> {
init { init {
if (buffer.size != rowNum * colNum) { if (buffer.size != rowNum * colNum) {

View File

@ -8,6 +8,9 @@ import scientifik.kmath.structures.Buffer.Companion.boxing
import kotlin.math.sqrt import kotlin.math.sqrt
/**
* Basic operations on matrices. Operates on [Matrix]
*/
interface MatrixContext<T : Any> { interface MatrixContext<T : Any> {
/** /**
* Produce a matrix with this context and given dimensions * Produce a matrix with this context and given dimensions
@ -101,18 +104,18 @@ interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
operator fun Matrix<T>.times(number: Number): Matrix<T> = operator fun Matrix<T>.times(number: Number): Matrix<T> =
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } } produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } }
operator fun Number.times(matrix: Matrix<T>): Matrix<T> = matrix * this operator fun Number.times(matrix: FeaturedMatrix<T>): Matrix<T> = matrix * this
override fun Matrix<T>.times(value: T): Matrix<T> = override fun Matrix<T>.times(value: T): Matrix<T> =
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } } produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } }
} }
/** /**
* Specialized 2-d structure * A 2d structure plus optional matrix-specific features
*/ */
interface Matrix<T : Any> : Structure2D<T> { interface FeaturedMatrix<T : Any> : Matrix<T> {
val rowNum: Int
val colNum: Int override val shape: IntArray get() = intArrayOf(rowNum, colNum)
val features: Set<MatrixFeature> val features: Set<MatrixFeature>
@ -122,70 +125,54 @@ interface Matrix<T : Any> : Structure2D<T> {
* The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to * The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to
* add only those features that are valid. * add only those features that are valid.
*/ */
fun suggestFeature(vararg features: MatrixFeature): Matrix<T> fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
override fun get(index: IntArray): T = get(index[0], index[1])
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
val rows: Point<Point<T>>
get() = VirtualBuffer(rowNum) { i ->
VirtualBuffer(colNum) { j -> get(i, j) }
}
val columns: Point<Point<T>>
get() = VirtualBuffer(colNum) { j ->
VirtualBuffer(rowNum) { i -> get(i, j) }
}
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
for (i in (0 until rowNum)) {
for (j in (0 until colNum)) {
yield(intArrayOf(i, j) to get(i, j))
}
}
}
companion object { companion object {
fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
}
}
fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
MatrixContext.real.produce(rows, columns, initializer) MatrixContext.real.produce(rows, columns, initializer)
/** /**
* Build a square matrix from given elements. * Build a square matrix from given elements.
*/ */
fun <T : Any> square(vararg elements: T): Matrix<T> { fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> {
val size: Int = sqrt(elements.size.toDouble()).toInt() val size: Int = sqrt(elements.size.toDouble()).toInt()
if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square") if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square")
val buffer = elements.asBuffer() val buffer = elements.asBuffer()
return BufferMatrix(size, size, buffer) return BufferMatrix(size, size, buffer)
}
fun <T : Any> build(rows: Int, columns: Int): MatrixBuilder<T> = MatrixBuilder(rows, columns)
}
} }
fun <T : Any> Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder<T> = MatrixBuilder(rows, columns)
class MatrixBuilder<T : Any>(val rows: Int, val columns: Int) { class MatrixBuilder<T : Any>(val rows: Int, val columns: Int) {
operator fun invoke(vararg elements: T): Matrix<T> { operator fun invoke(vararg elements: T): FeaturedMatrix<T> {
if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns") if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns")
val buffer = elements.asBuffer() val buffer = elements.asBuffer()
return BufferMatrix(rows, columns, buffer) return BufferMatrix(rows, columns, buffer)
} }
} }
val Matrix<*>.features get() = (this as? FeaturedMatrix)?.features?: emptySet()
/** /**
* Check if matrix has the given feature class * Check if matrix has the given feature class
*/ */
inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean = features.find { it is T } != null inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean =
features.find { it is T } != null
/** /**
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
*/ */
inline fun <reified T : Any> Matrix<*>.getFeature(): T? = features.filterIsInstance<T>().firstOrNull() inline fun <reified T : Any> Matrix<*>.getFeature(): T? =
features.filterIsInstance<T>().firstOrNull()
/** /**
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created * Diagonal matrix of ones. The matrix is virtual no actual matrix is created
*/ */
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> = fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): FeaturedMatrix<T> =
VirtualMatrix<T>(rows, columns) { i, j -> VirtualMatrix<T>(rows, columns) { i, j ->
if (i == j) elementContext.one else elementContext.zero if (i == j) elementContext.one else elementContext.zero
} }
@ -194,7 +181,7 @@ fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: In
/** /**
* A virtual matrix of zeroes * A virtual matrix of zeroes
*/ */
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> = fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
VirtualMatrix<T>(rows, columns) { _, _ -> elementContext.zero } VirtualMatrix<T>(rows, columns) { _, _ -> elementContext.zero }
class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature class TransposedFeature<T : Any>(val original: Matrix<T>) : MatrixFeature

View File

@ -2,12 +2,12 @@ package scientifik.kmath.linear
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.MutableBuffer import scientifik.kmath.structures.*
import scientifik.kmath.structures.MutableBuffer.Companion.boxing import scientifik.kmath.structures.MutableBuffer.Companion.boxing
import scientifik.kmath.structures.MutableBufferFactory
import scientifik.kmath.structures.NDStructure
import scientifik.kmath.structures.get
/**
* Common implementation of [LUPDecompositionFeature]
*/
class LUPDecomposition<T : Comparable<T>>( class LUPDecomposition<T : Comparable<T>>(
private val elementContext: Ring<T>, private val elementContext: Ring<T>,
internal val lu: NDStructure<T>, internal val lu: NDStructure<T>,
@ -20,7 +20,7 @@ class LUPDecomposition<T : Comparable<T>>(
* *
* L is a lower-triangular matrix with [Ring.one] in diagonal * L is a lower-triangular matrix with [Ring.one] in diagonal
*/ */
override val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j -> override val l: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j ->
when { when {
j < i -> lu[i, j] j < i -> lu[i, j]
j == i -> elementContext.one j == i -> elementContext.one
@ -34,7 +34,7 @@ class LUPDecomposition<T : Comparable<T>>(
* *
* U is an upper-triangular matrix including the diagonal * U is an upper-triangular matrix including the diagonal
*/ */
override val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j -> override val u: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j ->
if (j >= i) lu[i, j] else elementContext.zero if (j >= i) lu[i, j] else elementContext.zero
} }
@ -45,7 +45,7 @@ class LUPDecomposition<T : Comparable<T>>(
* P is a sparse matrix with exactly one element set to [Ring.one] in * P is a sparse matrix with exactly one element set to [Ring.one] in
* each row and each column, all other elements being set to [Ring.zero]. * each row and each column, all other elements being set to [Ring.zero].
*/ */
override val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> override val p: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
if (j == pivot[i]) elementContext.one else elementContext.zero if (j == pivot[i]) elementContext.one else elementContext.zero
} }
@ -62,7 +62,9 @@ class LUPDecomposition<T : Comparable<T>>(
} }
/**
* Common implementation of LUP [LinearSolver] based on commons-math code
*/
class LUSolver<T : Comparable<T>, F : Field<T>>( class LUSolver<T : Comparable<T>, F : Field<T>>(
val context: GenericMatrixContext<T, F>, val context: GenericMatrixContext<T, F>,
val bufferFactory: MutableBufferFactory<T> = ::boxing, val bufferFactory: MutableBufferFactory<T> = ::boxing,

View File

@ -3,6 +3,7 @@ package scientifik.kmath.linear
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Norm import scientifik.kmath.operations.Norm
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.VirtualBuffer import scientifik.kmath.structures.VirtualBuffer
import scientifik.kmath.structures.asSequence import scientifik.kmath.structures.asSequence
@ -12,7 +13,7 @@ import scientifik.kmath.structures.asSequence
*/ */
interface LinearSolver<T : Any> { interface LinearSolver<T : Any> {
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).toVector() fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).asPoint()
fun inverse(a: Matrix<T>): Matrix<T> fun inverse(a: Matrix<T>): Matrix<T>
} }
@ -32,14 +33,14 @@ object VectorL2Norm : Norm<Point<out Number>, Double> {
typealias RealVector = Vector<Double, RealField> typealias RealVector = Vector<Double, RealField>
typealias RealMatrix = Matrix<Double> typealias RealMatrix = Matrix<Double>
/** /**
* Convert matrix to vector if it is possible * Convert matrix to vector if it is possible
*/ */
fun <T: Any> Matrix<T>.toVector(): Point<T> = fun <T : Any> Matrix<T>.asPoint(): Point<T> =
if (this.colNum == 1) { if (this.colNum == 1) {
VirtualBuffer(rowNum){ get(it, 0) } VirtualBuffer(rowNum) { get(it, 0) }
} else error("Can't convert matrix with more than one column to vector") } else {
error("Can't convert matrix with more than one column to vector")
}
fun <T: Any> Point<T>.toMatrix(): Matrix<T> = VirtualMatrix(size, 1) { i, _ -> get(i) } fun <T : Any> Point<T>.toMatrix() = VirtualMatrix(size, 1) { i, _ -> get(i) }

View File

@ -25,7 +25,7 @@ object UnitFeature : MatrixFeature
* Inverted matrix feature * Inverted matrix feature
*/ */
interface InverseMatrixFeature<T : Any> : MatrixFeature { interface InverseMatrixFeature<T : Any> : MatrixFeature {
val inverse: Matrix<T> val inverse: FeaturedMatrix<T>
} }
/** /**
@ -54,9 +54,9 @@ object UFeature: MatrixFeature
* TODO add documentation * TODO add documentation
*/ */
interface LUPDecompositionFeature<T : Any> : MatrixFeature { interface LUPDecompositionFeature<T : Any> : MatrixFeature {
val l: Matrix<T> val l: FeaturedMatrix<T>
val u: Matrix<T> val u: FeaturedMatrix<T>
val p: Matrix<T> val p: FeaturedMatrix<T>
} }
//TODO add sparse matrix feature //TODO add sparse matrix feature

View File

@ -1,11 +1,16 @@
package scientifik.kmath.linear package scientifik.kmath.linear
import scientifik.kmath.structures.Matrix
class VirtualMatrix<T : Any>( class VirtualMatrix<T : Any>(
override val rowNum: Int, override val rowNum: Int,
override val colNum: Int, override val colNum: Int,
override val features: Set<MatrixFeature> = emptySet(), override val features: Set<MatrixFeature> = emptySet(),
val generator: (i: Int, j: Int) -> T val generator: (i: Int, j: Int) -> T
) : Matrix<T> { ) : FeaturedMatrix<T> {
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
override fun get(i: Int, j: Int): T = generator(i, j) override fun get(i: Int, j: Int): T = generator(i, j)
override fun suggestFeature(vararg features: MatrixFeature) = override fun suggestFeature(vararg features: MatrixFeature) =
@ -13,7 +18,7 @@ class VirtualMatrix<T : Any>(
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
if (other !is Matrix<*>) return false if (other !is FeaturedMatrix<*>) return false
if (rowNum != other.rowNum) return false if (rowNum != other.rowNum) return false
if (colNum != other.colNum) return false if (colNum != other.colNum) return false
@ -34,7 +39,7 @@ class VirtualMatrix<T : Any>(
/** /**
* Wrap a matrix adding additional features to it * Wrap a matrix adding additional features to it
*/ */
fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): Matrix<T> { fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): FeaturedMatrix<T> {
return if (matrix is VirtualMatrix) { return if (matrix is VirtualMatrix) {
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator) VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator)
} else { } else {

View File

@ -5,8 +5,7 @@ interface NDStructure<T> {
val shape: IntArray val shape: IntArray
val dimension val dimension get() = shape.size
get() = shape.size
operator fun get(index: IntArray): T operator fun get(index: IntArray): T

View File

@ -27,20 +27,6 @@ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Stru
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements() override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
} }
/**
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
*/
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
Structure1DWrapper(this)
} else {
error("Can't create 1d-structure from ${shape.size}d-structure")
}
fun <T> NDBuffer<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
Buffer1DWrapper(this.buffer)
} else {
error("Can't create 1d-structure from ${shape.size}d-structure")
}
/** /**
* A structure wrapper for buffer * A structure wrapper for buffer
@ -56,39 +42,21 @@ private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T>
override fun get(index: Int): T = buffer.get(index) override fun get(index: Int): T = buffer.get(index)
} }
/**
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
*/
fun <T> NDStructure<T>.as1D(): Structure1D<T> = if (shape.size == 1) {
if( this is NDBuffer){
Buffer1DWrapper(this.buffer)
} else {
Structure1DWrapper(this)
}
} else {
error("Can't create 1d-structure from ${shape.size}d-structure")
}
/** /**
* Represent this buffer as 1D structure * Represent this buffer as 1D structure
*/ */
fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this) fun <T> Buffer<T>.asND(): Structure1D<T> = Buffer1DWrapper(this)
/**
* A structure that is guaranteed to be two-dimensional
*/
interface Structure2D<T> : NDStructure<T> {
operator fun get(i: Int, j: Int): T
override fun get(index: IntArray): T {
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}")
return get(index[0], index[1])
}
}
/**
* A 2D wrapper for nd-structure
*/
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
override fun get(i: Int, j: Int): T = structure[i, j]
override val shape: IntArray get() = structure.shape
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
}
/**
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
*/
fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
Structure2DWrapper(this)
} else {
error("Can't create 2d-structure from ${shape.size}d-structure")
}

View File

@ -0,0 +1,79 @@
package scientifik.kmath.structures
/**
* A structure that is guaranteed to be two-dimensional
*/
interface Structure2D<T> : NDStructure<T> {
val rowNum: Int get() = shape[0]
val colNum: Int get() = shape[1]
operator fun get(i: Int, j: Int): T
override fun get(index: IntArray): T {
if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}")
return get(index[0], index[1])
}
val rows: Buffer<Buffer<T>>
get() = VirtualBuffer(rowNum) { i ->
VirtualBuffer(colNum) { j -> get(i, j) }
}
val columns: Buffer<Buffer<T>>
get() = VirtualBuffer(colNum) { j ->
VirtualBuffer(rowNum) { i -> get(i, j) }
}
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
for (i in (0 until rowNum)) {
for (j in (0 until colNum)) {
yield(intArrayOf(i, j) to get(i, j))
}
}
}
companion object {
}
}
/**
* A 2D wrapper for nd-structure
*/
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
override fun get(i: Int, j: Int): T = structure[i, j]
override val shape: IntArray get() = structure.shape
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
}
/**
* Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch
*/
fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2) {
Structure2DWrapper(this)
} else {
error("Can't create 2d-structure from ${shape.size}d-structure")
}
/**
* Represent this 2D structure as 1D if it has exactly one column. Throw error otherwise.
*/
fun <T> Structure2D<T>.as1D() = if (colNum == 1) {
object : Structure1D<T> {
override fun get(index: Int): T = get(index, 0)
override val shape: IntArray get() = intArrayOf(rowNum)
override fun elements(): Sequence<Pair<IntArray, T>> = elements()
override val size: Int get() = rowNum
}
} else {
error("Can't convert matrix with more than one column to vector")
}
typealias Matrix<T> = Structure2D<T>

View File

@ -44,7 +44,7 @@ class MatrixTest {
@Test @Test
fun testBuilder() { fun testBuilder() {
val matrix = Matrix.build<Double>(2, 3)( val matrix = FeaturedMatrix.build<Double>(2, 3)(
1.0, 0.0, 0.0, 1.0, 0.0, 0.0,
0.0, 1.0, 2.0 0.0, 1.0, 2.0
) )

View File

@ -13,7 +13,7 @@ class RealLUSolverTest {
@Test @Test
fun testInvert() { fun testInvert() {
val matrix = Matrix.square( val matrix = FeaturedMatrix.square(
3.0, 1.0, 3.0, 1.0,
1.0, 3.0 1.0, 3.0
) )
@ -31,7 +31,7 @@ class RealLUSolverTest {
val inverted = LUSolver.real.inverse(decomposed) val inverted = LUSolver.real.inverse(decomposed)
val expected = Matrix.square( val expected = FeaturedMatrix.square(
0.375, -0.125, 0.375, -0.125,
-0.125, 0.375 -0.125, 0.375
) )

View File

@ -2,6 +2,7 @@ package scientifik.kmath.linear
import koma.extensions.fill import koma.extensions.fill
import koma.matrix.MatrixFactory import koma.matrix.MatrixFactory
import scientifik.kmath.structures.Matrix
class KomaMatrixContext<T : Any>(val factory: MatrixFactory<koma.matrix.Matrix<T>>) : MatrixContext<T>, class KomaMatrixContext<T : Any>(val factory: MatrixFactory<koma.matrix.Matrix<T>>) : MatrixContext<T>,
LinearSolver<T> { LinearSolver<T> {
@ -48,24 +49,25 @@ class KomaMatrixContext<T : Any>(val factory: MatrixFactory<koma.matrix.Matrix<T
KomaMatrix(a.toKoma().origin.inv()) KomaMatrix(a.toKoma().origin.inv())
} }
class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<MatrixFeature>? = null) : class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<MatrixFeature>? = null) : FeaturedMatrix<T> {
Matrix<T> {
override val rowNum: Int get() = origin.numRows() override val rowNum: Int get() = origin.numRows()
override val colNum: Int get() = origin.numCols() override val colNum: Int get() = origin.numCols()
override val shape: IntArray get() = intArrayOf(origin.numRows(),origin.numCols())
override val features: Set<MatrixFeature> = features ?: setOf( override val features: Set<MatrixFeature> = features ?: setOf(
object : DeterminantFeature<T> { object : DeterminantFeature<T> {
override val determinant: T get() = origin.det() override val determinant: T get() = origin.det()
}, },
object : LUPDecompositionFeature<T> { object : LUPDecompositionFeature<T> {
private val lup by lazy { origin.LU() } private val lup by lazy { origin.LU() }
override val l: Matrix<T> get() = KomaMatrix(lup.second) override val l: FeaturedMatrix<T> get() = KomaMatrix(lup.second)
override val u: Matrix<T> get() = KomaMatrix(lup.third) override val u: FeaturedMatrix<T> get() = KomaMatrix(lup.third)
override val p: Matrix<T> get() = KomaMatrix(lup.first) override val p: FeaturedMatrix<T> get() = KomaMatrix(lup.first)
} }
) )
override fun suggestFeature(vararg features: MatrixFeature): Matrix<T> = override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> =
KomaMatrix(this.origin, this.features + features) KomaMatrix(this.origin, this.features + features)
override fun get(i: Int, j: Int): T = origin.getGeneric(i, j) override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)

View File

@ -8,43 +8,9 @@ val ioVersion: String by rootProject.extra
kotlin { kotlin {
jvm() jvm()
js() js()
sourceSets {
val commonMain by getting {
dependencies {
api(kotlin("stdlib"))
}
}
val commonTest by getting {
dependencies {
implementation(kotlin("test-common"))
implementation(kotlin("test-annotations-common"))
}
}
val jvmMain by getting {
dependencies {
api(kotlin("stdlib-jdk8"))
}
}
val jvmTest by getting {
dependencies {
implementation(kotlin("test"))
implementation(kotlin("test-junit"))
}
}
val jsMain by getting {
dependencies {
api(kotlin("stdlib-js"))
}
}
val jsTest by getting {
dependencies {
implementation(kotlin("test-js"))
}
}
// mingwMain { // mingwMain {
// } // }
// mingwTest { // mingwTest {
// } // }
}
} }