From 14f05eb1e1d57e7e555a3a7e0e624785249e307c Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sun, 7 Apr 2019 12:27:49 +0300 Subject: [PATCH] Matrix refactoring --- .../kmath/linear/LinearAlgebraBenchmark.kt | 3 +- .../kmath/linear/MultiplicationBenchmark.kt | 1 + build.gradle.kts | 1 + .../scientifik/kmath/linear/CMMatrix.kt | 39 ++++----- kmath-core/build.gradle.kts | 28 ------- .../scientifik/kmath/linear/BufferMatrix.kt | 3 +- .../linear/{Matrix.kt => FeaturedMatrix.kt} | 79 ++++++++----------- .../kmath/linear/LUPDecomposition.kt | 18 +++-- .../scientifik/kmath/linear/LinearAlgrebra.kt | 17 ++-- .../scientifik/kmath/linear/MatrixFeatures.kt | 8 +- .../scientifik/kmath/linear/VirtualMatrix.kt | 11 ++- .../kmath/structures/NDStructure.kt | 3 +- ...pecializedStructures.kt => Structure1D.kt} | 60 ++++---------- .../kmath/structures/Structure2D.kt | 79 +++++++++++++++++++ .../scientifik/kmath/linear/MatrixTest.kt | 2 +- .../kmath/linear/RealLUSolverTest.kt | 4 +- .../scientifik.kmath.linear/KomaMatrix.kt | 14 ++-- kmath-memory/build.gradle.kts | 36 +-------- 18 files changed, 196 insertions(+), 210 deletions(-) rename kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/{Matrix.kt => FeaturedMatrix.kt} (75%) rename kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/{SpecializedStructures.kt => Structure1D.kt} (59%) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt diff --git a/benchmarks/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt b/benchmarks/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt index 45a6a774e..fda31b01e 100644 --- a/benchmarks/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt +++ b/benchmarks/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt @@ -1,6 +1,7 @@ package scientifik.kmath.linear import koma.matrix.ejml.EJMLMatrixFactory +import scientifik.kmath.structures.Matrix import kotlin.random.Random import kotlin.system.measureTimeMillis @@ -30,7 +31,7 @@ fun main() { //commons-math - val cmContext = CMMatrixContext + val cmContext = CMLUPSolver val commonsTime = measureTimeMillis { cmContext.run { diff --git a/benchmarks/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt b/benchmarks/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt index 0b3b7c275..7b1bd9e7e 100644 --- a/benchmarks/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt +++ b/benchmarks/src/main/kotlin/scientifik/kmath/linear/MultiplicationBenchmark.kt @@ -1,6 +1,7 @@ package scientifik.kmath.linear import koma.matrix.ejml.EJMLMatrixFactory +import scientifik.kmath.structures.Matrix import kotlin.random.Random import kotlin.system.measureTimeMillis diff --git a/build.gradle.kts b/build.gradle.kts index 5ad331c48..964eb4998 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -193,6 +193,7 @@ subprojects { targets.all { sourceSets.all { languageSettings.progressiveMode = true + languageSettings.enableLanguageFeature("InlineClasses") } } diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt index 808b2768b..513dfe517 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt @@ -3,13 +3,14 @@ package scientifik.kmath.linear import org.apache.commons.math3.linear.* import org.apache.commons.math3.linear.RealMatrix import org.apache.commons.math3.linear.RealVector +import scientifik.kmath.structures.Matrix -class CMMatrix(val origin: RealMatrix, features: Set? = null) : Matrix { +class CMMatrix(val origin: RealMatrix, features: Set? = null) : FeaturedMatrix { override val rowNum: Int get() = origin.rowDimension override val colNum: Int get() = origin.columnDimension override val features: Set = features ?: sequence { - if(origin is DiagonalMatrix) yield(DiagonalFeature) + if (origin is DiagonalMatrix) yield(DiagonalFeature) }.toSet() override fun suggestFeature(vararg features: MatrixFeature) = @@ -45,28 +46,13 @@ fun Point.toCM(): CMVector = if (this is CMVector) { fun RealVector.toPoint() = CMVector(this) -object CMMatrixContext : MatrixContext, LinearSolver { +object CMMatrixContext : MatrixContext { override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix { val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } } return CMMatrix(Array2DRowRealMatrix(array)) } - override fun solve(a: Matrix, b: Matrix): CMMatrix { - val decomposition = LUDecomposition(a.toCM().origin) - return decomposition.solver.solve(b.toCM().origin).toMatrix() - } - - override fun solve(a: Matrix, b: Point): CMVector { - val decomposition = LUDecomposition(a.toCM().origin) - return decomposition.solver.solve(b.toCM().origin).toPoint() - } - - override fun inverse(a: Matrix): CMMatrix { - val decomposition = LUDecomposition(a.toCM().origin) - return decomposition.solver.inverse.toMatrix() - } - override fun Matrix.dot(other: Matrix) = CMMatrix(this.toCM().origin.multiply(other.toCM().origin)) @@ -87,6 +73,23 @@ object CMMatrixContext : MatrixContext, LinearSolver { CMMatrix(this.toCM().origin.scalarMultiply(value.toDouble())) } +object CMLUPSolver: LinearSolver{ + override fun solve(a: Matrix, b: Matrix): CMMatrix { + val decomposition = LUDecomposition(a.toCM().origin) + return decomposition.solver.solve(b.toCM().origin).toMatrix() + } + + override fun solve(a: Matrix, b: Point): CMVector { + val decomposition = LUDecomposition(a.toCM().origin) + return decomposition.solver.solve(b.toCM().origin).toPoint() + } + + override fun inverse(a: Matrix): CMMatrix { + val decomposition = LUDecomposition(a.toCM().origin) + return decomposition.solver.inverse.toMatrix() + } +} + operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin)) operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.subtract(other.origin)) diff --git a/kmath-core/build.gradle.kts b/kmath-core/build.gradle.kts index bd125b373..ab6ae7822 100644 --- a/kmath-core/build.gradle.kts +++ b/kmath-core/build.gradle.kts @@ -13,34 +13,6 @@ kotlin { val commonMain by getting { dependencies { api(project(":kmath-memory")) - api(kotlin("stdlib")) - } - } - val commonTest by getting { - dependencies { - implementation(kotlin("test-common")) - implementation(kotlin("test-annotations-common")) - } - } - val jvmMain by getting { - dependencies { - api(kotlin("stdlib-jdk8")) - } - } - val jvmTest by getting { - dependencies { - implementation(kotlin("test")) - implementation(kotlin("test-junit")) - } - } - val jsMain by getting { - dependencies { - api(kotlin("stdlib-js")) - } - } - val jsTest by getting { - dependencies { - implementation(kotlin("test-js")) } } // mingwMain { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt index dbffea8f3..5d116358d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt @@ -2,7 +2,6 @@ package scientifik.kmath.linear import scientifik.kmath.operations.Ring import scientifik.kmath.structures.* -import kotlin.jvm.JvmSynthetic /** * Basic implementation of Matrix space based on [NDStructure] @@ -25,7 +24,7 @@ class BufferMatrix( override val colNum: Int, val buffer: Buffer, override val features: Set = emptySet() -) : Matrix { +) : FeaturedMatrix { init { if (buffer.size != rowNum * colNum) { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt similarity index 75% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt index 329933019..62c360f26 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/FeaturedMatrix.kt @@ -8,6 +8,9 @@ import scientifik.kmath.structures.Buffer.Companion.boxing import kotlin.math.sqrt +/** + * Basic operations on matrices. Operates on [Matrix] + */ interface MatrixContext { /** * Produce a matrix with this context and given dimensions @@ -101,18 +104,18 @@ interface GenericMatrixContext> : MatrixContext { operator fun Matrix.times(number: Number): Matrix = produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } } - operator fun Number.times(matrix: Matrix): Matrix = matrix * this + operator fun Number.times(matrix: FeaturedMatrix): Matrix = matrix * this override fun Matrix.times(value: T): Matrix = produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } } } /** - * Specialized 2-d structure + * A 2d structure plus optional matrix-specific features */ -interface Matrix : Structure2D { - val rowNum: Int - val colNum: Int +interface FeaturedMatrix : Matrix { + + override val shape: IntArray get() = intArrayOf(rowNum, colNum) val features: Set @@ -122,70 +125,54 @@ interface Matrix : Structure2D { * The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to * add only those features that are valid. */ - fun suggestFeature(vararg features: MatrixFeature): Matrix - - override fun get(index: IntArray): T = get(index[0], index[1]) - - override val shape: IntArray get() = intArrayOf(rowNum, colNum) - - val rows: Point> - get() = VirtualBuffer(rowNum) { i -> - VirtualBuffer(colNum) { j -> get(i, j) } - } - - val columns: Point> - get() = VirtualBuffer(colNum) { j -> - VirtualBuffer(rowNum) { i -> get(i, j) } - } - - override fun elements(): Sequence> = sequence { - for (i in (0 until rowNum)) { - for (j in (0 until colNum)) { - yield(intArrayOf(i, j) to get(i, j)) - } - } - } + fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix companion object { - fun real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = - MatrixContext.real.produce(rows, columns, initializer) - /** - * Build a square matrix from given elements. - */ - fun square(vararg elements: T): Matrix { - val size: Int = sqrt(elements.size.toDouble()).toInt() - if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square") - val buffer = elements.asBuffer() - return BufferMatrix(size, size, buffer) - } - - fun build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns) } } +fun Structure2D.Companion.real(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = + MatrixContext.real.produce(rows, columns, initializer) + +/** + * Build a square matrix from given elements. + */ +fun Structure2D.Companion.square(vararg elements: T): FeaturedMatrix { + val size: Int = sqrt(elements.size.toDouble()).toInt() + if (size * size != elements.size) error("The number of elements ${elements.size} is not a full square") + val buffer = elements.asBuffer() + return BufferMatrix(size, size, buffer) +} + +fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns) + class MatrixBuilder(val rows: Int, val columns: Int) { - operator fun invoke(vararg elements: T): Matrix { + operator fun invoke(vararg elements: T): FeaturedMatrix { if (rows * columns != elements.size) error("The number of elements ${elements.size} is not equal $rows * $columns") val buffer = elements.asBuffer() return BufferMatrix(rows, columns, buffer) } } +val Matrix<*>.features get() = (this as? FeaturedMatrix)?.features?: emptySet() + /** * Check if matrix has the given feature class */ -inline fun Matrix<*>.hasFeature(): Boolean = features.find { it is T } != null +inline fun Matrix<*>.hasFeature(): Boolean = + features.find { it is T } != null /** * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria */ -inline fun Matrix<*>.getFeature(): T? = features.filterIsInstance().firstOrNull() +inline fun Matrix<*>.getFeature(): T? = + features.filterIsInstance().firstOrNull() /** * Diagonal matrix of ones. The matrix is virtual no actual matrix is created */ -fun > GenericMatrixContext.one(rows: Int, columns: Int): Matrix = +fun > GenericMatrixContext.one(rows: Int, columns: Int): FeaturedMatrix = VirtualMatrix(rows, columns) { i, j -> if (i == j) elementContext.one else elementContext.zero } @@ -194,7 +181,7 @@ fun > GenericMatrixContext.one(rows: Int, columns: In /** * A virtual matrix of zeroes */ -fun > GenericMatrixContext.zero(rows: Int, columns: Int): Matrix = +fun > GenericMatrixContext.zero(rows: Int, columns: Int): FeaturedMatrix = VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } class TransposedFeature(val original: Matrix) : MatrixFeature diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt index 85ddb8786..94ce8ed97 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt @@ -2,12 +2,12 @@ package scientifik.kmath.linear import scientifik.kmath.operations.Field import scientifik.kmath.operations.Ring -import scientifik.kmath.structures.MutableBuffer +import scientifik.kmath.structures.* import scientifik.kmath.structures.MutableBuffer.Companion.boxing -import scientifik.kmath.structures.MutableBufferFactory -import scientifik.kmath.structures.NDStructure -import scientifik.kmath.structures.get +/** + * Common implementation of [LUPDecompositionFeature] + */ class LUPDecomposition>( private val elementContext: Ring, internal val lu: NDStructure, @@ -20,7 +20,7 @@ class LUPDecomposition>( * * L is a lower-triangular matrix with [Ring.one] in diagonal */ - override val l: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j -> + override val l: FeaturedMatrix = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j -> when { j < i -> lu[i, j] j == i -> elementContext.one @@ -34,7 +34,7 @@ class LUPDecomposition>( * * U is an upper-triangular matrix including the diagonal */ - override val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j -> + override val u: FeaturedMatrix = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j -> if (j >= i) lu[i, j] else elementContext.zero } @@ -45,7 +45,7 @@ class LUPDecomposition>( * P is a sparse matrix with exactly one element set to [Ring.one] in * each row and each column, all other elements being set to [Ring.zero]. */ - override val p: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + override val p: FeaturedMatrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> if (j == pivot[i]) elementContext.one else elementContext.zero } @@ -62,7 +62,9 @@ class LUPDecomposition>( } - +/** + * Common implementation of LUP [LinearSolver] based on commons-math code + */ class LUSolver, F : Field>( val context: GenericMatrixContext, val bufferFactory: MutableBufferFactory = ::boxing, diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt index 51f61e030..4254fd7ea 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt @@ -3,6 +3,7 @@ package scientifik.kmath.linear import scientifik.kmath.operations.Field import scientifik.kmath.operations.Norm import scientifik.kmath.operations.RealField +import scientifik.kmath.structures.Matrix import scientifik.kmath.structures.VirtualBuffer import scientifik.kmath.structures.asSequence @@ -10,9 +11,9 @@ import scientifik.kmath.structures.asSequence /** * A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors */ -interface LinearSolver { +interface LinearSolver { fun solve(a: Matrix, b: Matrix): Matrix - fun solve(a: Matrix, b: Point): Point = solve(a, b.toMatrix()).toVector() + fun solve(a: Matrix, b: Point): Point = solve(a, b.toMatrix()).asPoint() fun inverse(a: Matrix): Matrix } @@ -32,14 +33,14 @@ object VectorL2Norm : Norm, Double> { typealias RealVector = Vector typealias RealMatrix = Matrix - - /** * Convert matrix to vector if it is possible */ -fun Matrix.toVector(): Point = +fun Matrix.asPoint(): Point = if (this.colNum == 1) { - VirtualBuffer(rowNum){ get(it, 0) } - } else error("Can't convert matrix with more than one column to vector") + VirtualBuffer(rowNum) { get(it, 0) } + } else { + error("Can't convert matrix with more than one column to vector") + } -fun Point.toMatrix(): Matrix = VirtualMatrix(size, 1) { i, _ -> get(i) } \ No newline at end of file +fun Point.toMatrix() = VirtualMatrix(size, 1) { i, _ -> get(i) } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt index 6b45a14b1..de315071f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/MatrixFeatures.kt @@ -25,7 +25,7 @@ object UnitFeature : MatrixFeature * Inverted matrix feature */ interface InverseMatrixFeature : MatrixFeature { - val inverse: Matrix + val inverse: FeaturedMatrix } /** @@ -54,9 +54,9 @@ object UFeature: MatrixFeature * TODO add documentation */ interface LUPDecompositionFeature : MatrixFeature { - val l: Matrix - val u: Matrix - val p: Matrix + val l: FeaturedMatrix + val u: FeaturedMatrix + val p: FeaturedMatrix } //TODO add sparse matrix feature \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt index 1bab52902..951471a07 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/VirtualMatrix.kt @@ -1,11 +1,16 @@ package scientifik.kmath.linear +import scientifik.kmath.structures.Matrix + class VirtualMatrix( override val rowNum: Int, override val colNum: Int, override val features: Set = emptySet(), val generator: (i: Int, j: Int) -> T -) : Matrix { +) : FeaturedMatrix { + + override val shape: IntArray get() = intArrayOf(rowNum, colNum) + override fun get(i: Int, j: Int): T = generator(i, j) override fun suggestFeature(vararg features: MatrixFeature) = @@ -13,7 +18,7 @@ class VirtualMatrix( override fun equals(other: Any?): Boolean { if (this === other) return true - if (other !is Matrix<*>) return false + if (other !is FeaturedMatrix<*>) return false if (rowNum != other.rowNum) return false if (colNum != other.colNum) return false @@ -34,7 +39,7 @@ class VirtualMatrix( /** * Wrap a matrix adding additional features to it */ - fun wrap(matrix: Matrix, vararg features: MatrixFeature): Matrix { + fun wrap(matrix: Matrix, vararg features: MatrixFeature): FeaturedMatrix { return if (matrix is VirtualMatrix) { VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator) } else { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt index 2cf9af6fb..5895251e8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt @@ -5,8 +5,7 @@ interface NDStructure { val shape: IntArray - val dimension - get() = shape.size + val dimension get() = shape.size operator fun get(index: IntArray): T diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/SpecializedStructures.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt similarity index 59% rename from kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/SpecializedStructures.kt rename to kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt index 734aba5ac..df56017a3 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/SpecializedStructures.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure1D.kt @@ -27,20 +27,6 @@ private inline class Structure1DWrapper(val structure: NDStructure) : Stru override fun elements(): Sequence> = structure.elements() } -/** - * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch - */ -fun NDStructure.as1D(): Structure1D = if (shape.size == 1) { - Structure1DWrapper(this) -} else { - error("Can't create 1d-structure from ${shape.size}d-structure") -} - -fun NDBuffer.as1D(): Structure1D = if (shape.size == 1) { - Buffer1DWrapper(this.buffer) -} else { - error("Can't create 1d-structure from ${shape.size}d-structure") -} /** * A structure wrapper for buffer @@ -56,39 +42,21 @@ private inline class Buffer1DWrapper(val buffer: Buffer) : Structure1D override fun get(index: Int): T = buffer.get(index) } -/** - * Represent this buffer as 1D structure - */ -fun Buffer.asND(): Structure1D = Buffer1DWrapper(this) - -/** - * A structure that is guaranteed to be two-dimensional - */ -interface Structure2D : NDStructure { - operator fun get(i: Int, j: Int): T - - override fun get(index: IntArray): T { - if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}") - return get(index[0], index[1]) - } -} - -/** - * A 2D wrapper for nd-structure - */ -private inline class Structure2DWrapper(val structure: NDStructure) : Structure2D { - override fun get(i: Int, j: Int): T = structure[i, j] - - override val shape: IntArray get() = structure.shape - - override fun elements(): Sequence> = structure.elements() -} - /** * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch */ -fun NDStructure.as2D(): Structure2D = if (shape.size == 2) { - Structure2DWrapper(this) +fun NDStructure.as1D(): Structure1D = if (shape.size == 1) { + if( this is NDBuffer){ + Buffer1DWrapper(this.buffer) + } else { + Structure1DWrapper(this) + } } else { - error("Can't create 2d-structure from ${shape.size}d-structure") -} \ No newline at end of file + error("Can't create 1d-structure from ${shape.size}d-structure") +} + + +/** + * Represent this buffer as 1D structure + */ +fun Buffer.asND(): Structure1D = Buffer1DWrapper(this) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt new file mode 100644 index 000000000..e736f84a0 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Structure2D.kt @@ -0,0 +1,79 @@ +package scientifik.kmath.structures + +/** + * A structure that is guaranteed to be two-dimensional + */ +interface Structure2D : NDStructure { + val rowNum: Int get() = shape[0] + val colNum: Int get() = shape[1] + + operator fun get(i: Int, j: Int): T + + override fun get(index: IntArray): T { + if (index.size != 2) error("Index dimension mismatch. Expected 2 but found ${index.size}") + return get(index[0], index[1]) + } + + + val rows: Buffer> + get() = VirtualBuffer(rowNum) { i -> + VirtualBuffer(colNum) { j -> get(i, j) } + } + + val columns: Buffer> + get() = VirtualBuffer(colNum) { j -> + VirtualBuffer(rowNum) { i -> get(i, j) } + } + + override fun elements(): Sequence> = sequence { + for (i in (0 until rowNum)) { + for (j in (0 until colNum)) { + yield(intArrayOf(i, j) to get(i, j)) + } + } + } + + companion object { + + } +} + +/** + * A 2D wrapper for nd-structure + */ +private inline class Structure2DWrapper(val structure: NDStructure) : Structure2D { + override fun get(i: Int, j: Int): T = structure[i, j] + + override val shape: IntArray get() = structure.shape + + override fun elements(): Sequence> = structure.elements() +} + +/** + * Represent a [NDStructure] as [Structure1D]. Throw error in case of dimension mismatch + */ +fun NDStructure.as2D(): Structure2D = if (shape.size == 2) { + Structure2DWrapper(this) +} else { + error("Can't create 2d-structure from ${shape.size}d-structure") +} + +/** + * Represent this 2D structure as 1D if it has exactly one column. Throw error otherwise. + */ +fun Structure2D.as1D() = if (colNum == 1) { + object : Structure1D { + override fun get(index: Int): T = get(index, 0) + + override val shape: IntArray get() = intArrayOf(rowNum) + + override fun elements(): Sequence> = elements() + + override val size: Int get() = rowNum + } +} else { + error("Can't convert matrix with more than one column to vector") +} + + +typealias Matrix = Structure2D \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt index 61aa506c4..8d09b6d7b 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt @@ -44,7 +44,7 @@ class MatrixTest { @Test fun testBuilder() { - val matrix = Matrix.build(2, 3)( + val matrix = FeaturedMatrix.build(2, 3)( 1.0, 0.0, 0.0, 0.0, 1.0, 2.0 ) diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt index bfa720369..8887d3a32 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt @@ -13,7 +13,7 @@ class RealLUSolverTest { @Test fun testInvert() { - val matrix = Matrix.square( + val matrix = FeaturedMatrix.square( 3.0, 1.0, 1.0, 3.0 ) @@ -31,7 +31,7 @@ class RealLUSolverTest { val inverted = LUSolver.real.inverse(decomposed) - val expected = Matrix.square( + val expected = FeaturedMatrix.square( 0.375, -0.125, -0.125, 0.375 ) diff --git a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt index 343d5b8b9..b95f04887 100644 --- a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt +++ b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt @@ -2,6 +2,7 @@ package scientifik.kmath.linear import koma.extensions.fill import koma.matrix.MatrixFactory +import scientifik.kmath.structures.Matrix class KomaMatrixContext(val factory: MatrixFactory>) : MatrixContext, LinearSolver { @@ -48,24 +49,25 @@ class KomaMatrixContext(val factory: MatrixFactory(val origin: koma.matrix.Matrix, features: Set? = null) : - Matrix { +class KomaMatrix(val origin: koma.matrix.Matrix, features: Set? = null) : FeaturedMatrix { override val rowNum: Int get() = origin.numRows() override val colNum: Int get() = origin.numCols() + override val shape: IntArray get() = intArrayOf(origin.numRows(),origin.numCols()) + override val features: Set = features ?: setOf( object : DeterminantFeature { override val determinant: T get() = origin.det() }, object : LUPDecompositionFeature { private val lup by lazy { origin.LU() } - override val l: Matrix get() = KomaMatrix(lup.second) - override val u: Matrix get() = KomaMatrix(lup.third) - override val p: Matrix get() = KomaMatrix(lup.first) + override val l: FeaturedMatrix get() = KomaMatrix(lup.second) + override val u: FeaturedMatrix get() = KomaMatrix(lup.third) + override val p: FeaturedMatrix get() = KomaMatrix(lup.first) } ) - override fun suggestFeature(vararg features: MatrixFeature): Matrix = + override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix = KomaMatrix(this.origin, this.features + features) override fun get(i: Int, j: Int): T = origin.getGeneric(i, j) diff --git a/kmath-memory/build.gradle.kts b/kmath-memory/build.gradle.kts index f6d10875c..03a05c9fe 100644 --- a/kmath-memory/build.gradle.kts +++ b/kmath-memory/build.gradle.kts @@ -8,43 +8,9 @@ val ioVersion: String by rootProject.extra kotlin { jvm() js() - - sourceSets { - val commonMain by getting { - dependencies { - api(kotlin("stdlib")) - } - } - val commonTest by getting { - dependencies { - implementation(kotlin("test-common")) - implementation(kotlin("test-annotations-common")) - } - } - val jvmMain by getting { - dependencies { - api(kotlin("stdlib-jdk8")) - } - } - val jvmTest by getting { - dependencies { - implementation(kotlin("test")) - implementation(kotlin("test-junit")) - } - } - val jsMain by getting { - dependencies { - api(kotlin("stdlib-js")) - } - } - val jsTest by getting { - dependencies { - implementation(kotlin("test-js")) - } - } // mingwMain { // } // mingwTest { // } - } + } \ No newline at end of file