From baae18b22304d6484924aba165eff39b2658ae63 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Mon, 21 Jan 2019 15:05:44 +0300 Subject: [PATCH] Koma integration + multiple fixes to Matrix API --- benchmarks/build.gradle | 7 +- .../kmath/linear/LinearAlgebraBenchmark.kt | 2 +- .../scientifik/kmath/linear/CMMatrix.kt | 65 +++++++------ .../scientifik/kmath/linear/BufferMatrix.kt | 2 +- .../kmath/linear/LUPDecomposition.kt | 4 +- .../scientifik/kmath/linear/LinearAlgrebra.kt | 3 +- .../kotlin/scientifik/kmath/linear/Matrix.kt | 92 +++++++++++-------- .../scientifik/kmath/linear/MatrixTest.kt | 4 +- .../kmath/linear/RealLUSolverTest.kt | 4 +- .../scientifik.kmath.linear/KomaMatrix.kt | 71 +++++++++++--- 10 files changed, 159 insertions(+), 95 deletions(-) diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle index ef9264d32..98a4a64c5 100644 --- a/benchmarks/build.gradle +++ b/benchmarks/build.gradle @@ -5,9 +5,10 @@ plugins { } dependencies { - api project(":kmath-core") - api project(":kmath-coroutines") - api project(":kmath-commons") + implementation project(":kmath-core") + implementation project(":kmath-coroutines") + implementation project(":kmath-commons") + implementation project(":kmath-koma") implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8" //jmh project(':kmath-core') } diff --git a/benchmarks/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt b/benchmarks/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt index 0cb7f15df..16ffc4639 100644 --- a/benchmarks/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt +++ b/benchmarks/src/main/kotlin/scientifik/kmath/linear/LinearAlgebraBenchmark.kt @@ -29,7 +29,7 @@ fun main() { //commons-math - val cmSolver = CMSolver + val cmSolver = CMMatrixContext val commonsTime = measureTimeMillis { val cm = matrix.toCM() //avoid overhead on conversion diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt index 15c7489ab..aed4f4bb1 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/linear/CMMatrix.kt @@ -3,8 +3,6 @@ package scientifik.kmath.linear import org.apache.commons.math3.linear.* import org.apache.commons.math3.linear.RealMatrix import org.apache.commons.math3.linear.RealVector -import scientifik.kmath.operations.RealField -import scientifik.kmath.structures.DoubleBuffer inline class CMMatrix(val origin: RealMatrix) : Matrix { override val rowNum: Int get() = origin.rowDimension @@ -40,44 +38,51 @@ fun Point.toCM(): CMVector = if (this is CMVector) { CMVector(ArrayRealVector(array)) } -fun RealVector.toPoint() = DoubleBuffer(this.toArray()) +fun RealVector.toPoint() = CMVector(this) -object CMMatrixContext : MatrixContext { - override val elementContext: RealField get() = RealField +object CMMatrixContext : MatrixContext, LinearSolver { - override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix { + override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix { val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } } return CMMatrix(Array2DRowRealMatrix(array)) } - override fun point(size: Int, initializer: (Int) -> Double): Point { - val array = DoubleArray(size, initializer) - return CMVector(ArrayRealVector(array)) + override fun solve(a: Matrix, b: Matrix): CMMatrix { + val decomposition = LUDecomposition(a.toCM().origin) + return decomposition.solver.solve(b.toCM().origin).toMatrix() } - override fun Matrix.dot(other: Matrix): Matrix = CMMatrix(this.toCM().origin.multiply(other.toCM().origin)) + override fun solve(a: Matrix, b: Point): CMVector { + val decomposition = LUDecomposition(a.toCM().origin) + return decomposition.solver.solve(b.toCM().origin).toPoint() + } + + override fun inverse(a: Matrix): CMMatrix { + val decomposition = LUDecomposition(a.toCM().origin) + return decomposition.solver.inverse.toMatrix() + } + + override fun Matrix.dot(other: Matrix) = + CMMatrix(this.toCM().origin.multiply(other.toCM().origin)) + + override fun Matrix.dot(vector: Point): CMVector = + CMVector(this.toCM().origin.preMultiply(vector.toCM().origin)) + + + override fun Matrix.unaryMinus(): CMMatrix = + produce(rowNum, colNum) { i, j -> -get(i, j) } + + override fun Matrix.plus(b: Matrix) = + CMMatrix(this.toCM().origin.multiply(b.toCM().origin)) + + override fun Matrix.minus(b: Matrix) = + CMMatrix(this.toCM().origin.subtract(b.toCM().origin)) + + override fun Matrix.times(value: Double) = + CMMatrix(this.toCM().origin.scalarMultiply(value.toDouble())) } operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin)) operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.subtract(other.origin)) -infix fun CMMatrix.dot(other: CMMatrix): CMMatrix = CMMatrix(this.origin.multiply(other.origin)) - -object CMSolver : LinearSolver { - - override fun solve(a: Matrix, b: Matrix): Matrix { - val decomposition = LUDecomposition(a.toCM().origin) - return decomposition.solver.solve(b.toCM().origin).toMatrix() - } - - override fun solve(a: Matrix, b: Point): Point { - val decomposition = LUDecomposition(a.toCM().origin) - return decomposition.solver.solve(b.toCM().origin).toPoint() - } - - override fun inverse(a: Matrix): Matrix { - val decomposition = LUDecomposition(a.toCM().origin) - return decomposition.solver.inverse.toMatrix() - } - -} \ No newline at end of file +infix fun CMMatrix.dot(other: CMMatrix): CMMatrix = CMMatrix(this.origin.multiply(other.origin)) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt index 1b23b553d..5b65821db 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/BufferMatrix.kt @@ -9,7 +9,7 @@ import scientifik.kmath.structures.* class BufferMatrixContext>( override val elementContext: R, private val bufferFactory: BufferFactory -) : MatrixContext { +) : GenericMatrixContext { override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): BufferMatrix { val buffer = bufferFactory(rows * columns) { offset -> initializer(offset / columns, offset % columns) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt index 79973bd37..d2de7609d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUPDecomposition.kt @@ -65,7 +65,7 @@ class LUPDecomposition>( class LUSolver, F : Field>( - val context: MatrixContext, + val context: GenericMatrixContext, val bufferFactory: MutableBufferFactory = ::boxing, val singularityCheck: (T) -> Boolean ) : LinearSolver { @@ -201,6 +201,6 @@ class LUSolver, F : Field>( override fun inverse(a: Matrix): Matrix = solve(a, context.one(a.rowNum, a.colNum)) companion object { - val real = LUSolver(MatrixContext.real, MutableBuffer.Companion::auto) { it < 1e-11 } + val real = LUSolver(GenericMatrixContext.real, MutableBuffer.Companion::auto) { it < 1e-11 } } } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt index 4c56b1e63..51f61e030 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt @@ -3,7 +3,6 @@ package scientifik.kmath.linear import scientifik.kmath.operations.Field import scientifik.kmath.operations.Norm import scientifik.kmath.operations.RealField -import scientifik.kmath.operations.Ring import scientifik.kmath.structures.VirtualBuffer import scientifik.kmath.structures.asSequence @@ -11,7 +10,7 @@ import scientifik.kmath.structures.asSequence /** * A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors */ -interface LinearSolver> { +interface LinearSolver { fun solve(a: Matrix, b: Matrix): Matrix fun solve(a: Matrix, b: Point): Point = solve(a, b.toMatrix()).toVector() fun inverse(a: Matrix): Matrix diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt index 0c606e293..5b4f3f12d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt @@ -8,28 +8,61 @@ import scientifik.kmath.structures.Buffer.Companion.boxing import kotlin.math.sqrt -interface MatrixContext> { +interface MatrixContext { + /** + * Produce a matrix with this context and given dimensions + */ + fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix + + infix fun Matrix.dot(other: Matrix): Matrix + + infix fun Matrix.dot(vector: Point): Point + + operator fun Matrix.unaryMinus(): Matrix + + operator fun Matrix.plus(b: Matrix): Matrix + + operator fun Matrix.minus(b: Matrix): Matrix + + operator fun Matrix.times(value: T): Matrix + + operator fun T.times(m: Matrix): Matrix = m * this + + companion object { + /** + * Non-boxing double matrix + */ + val real = BufferMatrixContext(RealField, DoubleBufferFactory) + + /** + * A structured matrix with custom buffer + */ + fun > buffered( + ring: R, + bufferFactory: BufferFactory = ::boxing + ): GenericMatrixContext = + BufferMatrixContext(ring, bufferFactory) + + /** + * Automatic buffered matrix, unboxed if it is possible + */ + inline fun > auto(ring: R): GenericMatrixContext = + buffered(ring, Buffer.Companion::auto) + } +} + +interface GenericMatrixContext> : MatrixContext { /** * The ring context for matrix elements */ val elementContext: R - /** - * Produce a matrix with this context and given dimensions - */ - fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix - /** * Produce a point compatible with matrix space */ fun point(size: Int, initializer: (Int) -> T): Point - fun scale(a: Matrix, k: Number): Matrix { - //TODO create a special wrapper class for scaled matrices - return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } } - } - - infix fun Matrix.dot(other: Matrix): Matrix { + override infix fun Matrix.dot(other: Matrix): Matrix { //TODO add typed error if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})") return produce(rowNum, other.colNum) { i, j -> @@ -41,7 +74,7 @@ interface MatrixContext> { } } - infix fun Matrix.dot(vector: Point): Point { + override infix fun Matrix.dot(vector: Point): Point { //TODO add typed error if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})") return point(rowNum) { i -> @@ -52,15 +85,15 @@ interface MatrixContext> { } } - operator fun Matrix.unaryMinus() = + override operator fun Matrix.unaryMinus() = produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } } - operator fun Matrix.plus(b: Matrix): Matrix { + override operator fun Matrix.plus(b: Matrix): Matrix { if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] + [${b.rowNum},${b.colNum}]") return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } } } - operator fun Matrix.minus(b: Matrix): Matrix { + override operator fun Matrix.minus(b: Matrix): Matrix { if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]") return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } } } @@ -68,27 +101,10 @@ interface MatrixContext> { operator fun Matrix.times(number: Number): Matrix = produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } } - operator fun Number.times(m: Matrix): Matrix = m * this + operator fun Number.times(matrix: Matrix): Matrix = matrix * this - - companion object { - /** - * Non-boxing double matrix - */ - val real = BufferMatrixContext(RealField, DoubleBufferFactory) - - /** - * A structured matrix with custom buffer - */ - fun > buffered(ring: R, bufferFactory: BufferFactory = ::boxing): MatrixContext = - BufferMatrixContext(ring, bufferFactory) - - /** - * Automatic buffered matrix, unboxed if it is possible - */ - inline fun > auto(ring: R): MatrixContext = - buffered(ring, Buffer.Companion::auto) - } + override fun Matrix.times(value: T): Matrix = + produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } } } /** @@ -173,7 +189,7 @@ inline fun Matrix<*>.getFeature(): T? = features.filterIsInsta /** * Diagonal matrix of ones. The matrix is virtual no actual matrix is created */ -fun > MatrixContext.one(rows: Int, columns: Int): Matrix = +fun > GenericMatrixContext.one(rows: Int, columns: Int): Matrix = VirtualMatrix(rows, columns) { i, j -> if (i == j) elementContext.one else elementContext.zero } @@ -182,7 +198,7 @@ fun > MatrixContext.one(rows: Int, columns: Int): Mat /** * A virtual matrix of zeroes */ -fun > MatrixContext.zero(rows: Int, columns: Int): Matrix = +fun > GenericMatrixContext.zero(rows: Int, columns: Int): Matrix = VirtualMatrix(rows, columns) { i, j -> elementContext.zero } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt index 101aae3b4..eaf00c59b 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/MatrixTest.kt @@ -22,7 +22,7 @@ class MatrixTest { @Test fun testTranspose() { - val matrix = MatrixContext.real.one(3, 3) + val matrix = GenericMatrixContext.real.one(3, 3) val transposed = matrix.transpose() assertEquals((matrix as BufferMatrix).buffer, (transposed as BufferMatrix).buffer) assertEquals(matrix, transposed) @@ -36,7 +36,7 @@ class MatrixTest { val matrix1 = vector1.toMatrix() val matrix2 = vector2.toMatrix().transpose() - val product = MatrixContext.real.run { matrix1 dot matrix2 } + val product = GenericMatrixContext.real.run { matrix1 dot matrix2 } assertEquals(5.0, product[1, 0]) diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt index ea104355c..32b07eb86 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt @@ -6,7 +6,7 @@ import kotlin.test.assertEquals class RealLUSolverTest { @Test fun testInvertOne() { - val matrix = MatrixContext.real.one(2, 2) + val matrix = GenericMatrixContext.real.one(2, 2) val inverted = LUSolver.real.inverse(matrix) assertEquals(matrix, inverted) } @@ -25,7 +25,7 @@ class RealLUSolverTest { assertEquals(8.0, decomposition.determinant) //Check decomposition - with(MatrixContext.real) { + with(GenericMatrixContext.real) { assertEquals(decomposition.p dot matrix, decomposition.l dot decomposition.u) } diff --git a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt index 18839ec88..3b7894d18 100644 --- a/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt +++ b/kmath-koma/src/commonMain/kotlin/scientifik.kmath.linear/KomaMatrix.kt @@ -1,27 +1,70 @@ package scientifik.kmath.linear -import scientifik.kmath.operations.Ring +import koma.extensions.fill +import koma.matrix.MatrixFactory -class KomaMatrixContext> : MatrixContext { - override val elementContext: R - get() = TODO("not implemented") //To change initializer of created properties use File | Settings | File Templates. +class KomaMatrixContext(val factory: MatrixFactory>) : MatrixContext, + LinearSolver { - override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix { - TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T) = + KomaMatrix(factory.zeros(rows, columns).fill(initializer)) + + fun Matrix.toKoma(): KomaMatrix = if (this is KomaMatrix) { + this + } else { + produce(rowNum, colNum) { i, j -> get(i, j) } } - override fun point(size: Int, initializer: (Int) -> T): Point { - TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + fun Point.toKoma(): KomaVector = if (this is KomaVector) { + this + } else { + KomaVector(factory.zeros(size, 1).fill { i, _ -> get(i) }) } + + override fun Matrix.dot(other: Matrix) = + KomaMatrix(this.toKoma().origin * other.toKoma().origin) + + override fun Matrix.dot(vector: Point) = + KomaVector(this.toKoma().origin * vector.toKoma().origin) + + override fun Matrix.unaryMinus() = + KomaMatrix(this.toKoma().origin.unaryMinus()) + + override fun Matrix.plus(b: Matrix) = + KomaMatrix(this.toKoma().origin + b.toKoma().origin) + + override fun Matrix.minus(b: Matrix) = + KomaMatrix(this.toKoma().origin - b.toKoma().origin) + + override fun Matrix.times(value: T) = + KomaMatrix(this.toKoma().origin * value) + + + override fun solve(a: Matrix, b: Matrix) = + KomaMatrix(a.toKoma().origin.solve(b.toKoma().origin)) + + override fun inverse(a: Matrix) = + KomaMatrix(a.toKoma().origin.inv()) } -inline class KomaMatrix(val matrix: koma.matrix.Matrix) : Matrix { - override val rowNum: Int get() = matrix.numRows() - override val colNum: Int get() = matrix.numCols() +inline class KomaMatrix(val origin: koma.matrix.Matrix) : Matrix { + override val rowNum: Int get() = origin.numRows() + override val colNum: Int get() = origin.numCols() override val features: Set get() = emptySet() - @Suppress("OVERRIDE_BY_INLINE") - override inline fun get(i: Int, j: Int): T = matrix.getGeneric(i, j) + override fun get(i: Int, j: Int): T = origin.getGeneric(i, j) +} + +class KomaVector internal constructor(val origin: koma.matrix.Matrix) : Point { + init { + if (origin.numCols() != 1) error("Only single column matrices are allowed") + } + + override val size: Int get() = origin.numRows() + + override fun get(index: Int): T = origin.getGeneric(index) + + override fun iterator(): Iterator = origin.toIterable().iterator() +} -} \ No newline at end of file