From d5ba816b7daa64a089c3b07e4aa28ab1e10b74f8 Mon Sep 17 00:00:00 2001 From: darksnake Date: Sat, 13 Mar 2021 18:00:47 +0300 Subject: [PATCH] WIP Matrix refactor --- .../kscience/kmath/linear/LinearSpace.kt | 2 +- .../kscience/kmath/linear/LupDecomposition.kt | 4 +- .../space/kscience/kmath/linear/MatrixTest.kt | 3 +- .../kscience/kmath/dimensions/Wrappers.kt | 31 +++++++------ .../kscience/kmath/ejml/EjmlLinearSpace.kt | 39 ++++++---------- .../space/kscience/kmath/real/RealMatrix.kt | 29 ++++++------ .../kotlin/space/kscience/kmath/real/dot.kt | 46 +++++++++---------- 7 files changed, 74 insertions(+), 80 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt index 5800bdd0d..4a9d2c9b3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt @@ -22,7 +22,7 @@ public typealias Vector = Point * @param T the type of items in the matrices. * @param M the type of operated matrices. */ -public interface LinearSpace> { +public interface LinearSpace> { public val elementAlgebra: A /** diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt index 5d68534bc..0937f1e19 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt @@ -228,7 +228,7 @@ public inline fun > LinearSpace>.inverseWi @OptIn(UnstableKMathAPI::class) -public fun RealLinearSpace.solveWithLup(a: Matrix, b: Matrix): Matrix { +public fun LinearSpace.solveWithLup(a: Matrix, b: Matrix): Matrix { // Use existing decomposition if it is provided by matrix val bufferFactory: MutableBufferFactory = MutableBuffer.Companion::real val decomposition: LupDecomposition = a.getFeature() ?: lup(bufferFactory, a) { it < 1e-11 } @@ -238,5 +238,5 @@ public fun RealLinearSpace.solveWithLup(a: Matrix, b: Matrix): M /** * Inverses a square matrix using LUP decomposition. Non square matrix will throw a error. */ -public fun RealLinearSpace.inverseWithLup(matrix: Matrix): Matrix = +public fun LinearSpace.inverseWithLup(matrix: Matrix): Matrix = solveWithLup(matrix, one(matrix.rowNum, matrix.colNum)) \ No newline at end of file diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/MatrixTest.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/MatrixTest.kt index 5cf83889a..27374e93f 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/MatrixTest.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/linear/MatrixTest.kt @@ -2,7 +2,6 @@ package space.kscience.kmath.linear import space.kscience.kmath.nd.NDStructure import space.kscience.kmath.nd.as2D -import space.kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals @@ -17,7 +16,7 @@ class MatrixTest { @Test fun testBuilder() { - val matrix = Matrix.build(2, 3)( + val matrix = LinearSpace.real.matrix(2, 3)( 1.0, 0.0, 0.0, 0.0, 1.0, 2.0 ) diff --git a/kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Wrappers.kt b/kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Wrappers.kt index 89ff97f14..f7e14b29f 100644 --- a/kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Wrappers.kt +++ b/kmath-dimensions/src/commonMain/kotlin/space/kscience/kmath/dimensions/Wrappers.kt @@ -2,7 +2,8 @@ package space.kscience.kmath.dimensions import space.kscience.kmath.linear.* import space.kscience.kmath.nd.Structure2D -import space.kscience.kmath.operations.invoke +import space.kscience.kmath.operations.RealField +import space.kscience.kmath.operations.Ring /** * A matrix with compile-time controlled dimension @@ -77,7 +78,7 @@ public inline class DPointWrapper(public val point: Point) /** * Basic operations on dimension-safe matrices. Operates on [Matrix] */ -public inline class DMatrixContext(public val context: LinearSpace>) { +public inline class DMatrixContext>(public val context: LinearSpace) { public inline fun Matrix.coerce(): DMatrix { require(rowNum == Dimension.dim().toInt()) { "Row number mismatch: expected ${Dimension.dim()} but found $rowNum" @@ -93,13 +94,15 @@ public inline class DMatrixContext(public val context: LinearSpace produce(noinline initializer: (i: Int, j: Int) -> T): DMatrix { + public inline fun produce( + noinline initializer: A.(i: Int, j: Int) -> T + ): DMatrix { val rows = Dimension.dim() val cols = Dimension.dim() return context.buildMatrix(rows.toInt(), cols.toInt(), initializer).coerce() } - public inline fun point(noinline initializer: (Int) -> T): DPoint { + public inline fun point(noinline initializer: A.(Int) -> T): DPoint { val size = Dimension.dim() return DPoint.coerceUnsafe( @@ -112,31 +115,31 @@ public inline class DMatrixContext(public val context: LinearSpace DMatrix.dot( other: DMatrix, - ): DMatrix = context { this@dot dot other }.coerce() + ): DMatrix = context.run { this@dot dot other }.coerce() public inline infix fun DMatrix.dot(vector: DPoint): DPoint = - DPoint.coerceUnsafe(context { this@dot dot vector }) + DPoint.coerceUnsafe(context.run { this@dot dot vector }) public inline operator fun DMatrix.times(value: T): DMatrix = - context { this@times.times(value) }.coerce() + context.run { this@times.times(value) }.coerce() public inline operator fun T.times(m: DMatrix): DMatrix = m * this public inline operator fun DMatrix.plus(other: DMatrix): DMatrix = - context { this@plus + other }.coerce() + context.run { this@plus + other }.coerce() public inline operator fun DMatrix.minus(other: DMatrix): DMatrix = - context { this@minus + other }.coerce() + context.run { this@minus + other }.coerce() public inline operator fun DMatrix.unaryMinus(): DMatrix = - context { this@unaryMinus.unaryMinus() }.coerce() + context.run { this@unaryMinus.unaryMinus() }.coerce() public inline fun DMatrix.transpose(): DMatrix = - context { (this@transpose as Matrix).transpose() }.coerce() + context.run { (this@transpose as Matrix).transpose() }.coerce() public companion object { - public val real: DMatrixContext = DMatrixContext(LinearSpace.real) + public val real: DMatrixContext = DMatrixContext(LinearSpace.real) } } @@ -144,11 +147,11 @@ public inline class DMatrixContext(public val context: LinearSpace DMatrixContext.one(): DMatrix = produce { i, j -> +public inline fun DMatrixContext.one(): DMatrix = produce { i, j -> if (i == j) 1.0 else 0.0 } -public inline fun DMatrixContext.zero(): DMatrix = +public inline fun DMatrixContext.zero(): DMatrix = produce { _, _ -> 0.0 } \ No newline at end of file diff --git a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt index 339e06459..57ce977fd 100644 --- a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt +++ b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt @@ -44,7 +44,7 @@ public object EjmlLinearSpace : LinearSpace, ScaleOperations< override fun buildVector(size: Int, initializer: RealField.(Int) -> Double): Vector = EjmlVector(SimpleMatrix(size, 1).also { - (0 until it.numRows()).forEach { row -> it[row, 0] = initializer(row) } + (0 until it.numRows()).forEach { row -> it[row, 0] = RealField.initializer(row) } }) private fun SimpleMatrix.wrapMatrix() = EjmlMatrix(this) @@ -59,43 +59,34 @@ public object EjmlLinearSpace : LinearSpace, ScaleOperations< EjmlVector(toEjml().origin.mult(vector.toEjml().origin)) public override operator fun Matrix.minus(other: Matrix): EjmlMatrix = - EjmlMatrix(toEjml().origin - other.toEjml().origin) + (toEjml().origin - other.toEjml().origin).wrapMatrix() public override fun scale(a: Matrix, value: Double): EjmlMatrix = a.toEjml().origin.scale(value).wrapMatrix() public override operator fun Matrix.times(value: Double): EjmlMatrix = - EjmlMatrix(toEjml().origin.scale(value)) + toEjml().origin.scale(value).wrapMatrix() override fun Vector.unaryMinus(): EjmlVector = toEjml().origin.negative().wrapVector() - override fun Matrix.plus(other: Matrix): Matrix { - TODO("Not yet implemented") - } + override fun Matrix.plus(other: Matrix): EjmlMatrix = + (toEjml().origin + other.toEjml().origin).wrapMatrix() - override fun Vector.plus(other: Vector): Vector { - TODO("Not yet implemented") - } + override fun Vector.plus(other: Vector): EjmlVector = + (toEjml().origin + other.toEjml().origin).wrapVector() - override fun Vector.minus(other: Vector): Vector { - TODO("Not yet implemented") - } + override fun Vector.minus(other: Vector): EjmlVector = + (toEjml().origin - other.toEjml().origin).wrapVector() - override fun Double.times(m: Matrix): Matrix { - TODO("Not yet implemented") - } + override fun Double.times(m: Matrix): EjmlMatrix = + m.toEjml().origin.scale(this).wrapMatrix() - override fun Vector.times(value: Double): Vector { - TODO("Not yet implemented") - } + override fun Vector.times(value: Double): EjmlVector = + toEjml().origin.scale(value).wrapVector() - override fun Double.times(v: Vector): Vector { - TODO("Not yet implemented") - } - - public override fun add(a: Matrix, b: Matrix): EjmlMatrix = - EjmlMatrix(a.toEjml().origin + b.toEjml().origin) + override fun Double.times(v: Vector): EjmlVector = + v.toEjml().origin.scale(this).wrapVector() } /** diff --git a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealMatrix.kt b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealMatrix.kt index b93c430f9..640615cc9 100644 --- a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/RealMatrix.kt @@ -2,6 +2,7 @@ package space.kscience.kmath.real import space.kscience.kmath.linear.* import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.RealField import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.RealBuffer import space.kscience.kmath.structures.asIterable @@ -21,11 +22,11 @@ import kotlin.math.pow public typealias RealMatrix = Matrix -public fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = +public fun realMatrix(rowNum: Int, colNum: Int, initializer: RealField.(i: Int, j: Int) -> Double): RealMatrix = LinearSpace.real.buildMatrix(rowNum, colNum, initializer) public fun Array.toMatrix(): RealMatrix { - return LinearSpace.real.buildMatrix(size, this[0].size) { row, col -> this[row][col] } + return LinearSpace.real.buildMatrix(size, this[0].size) { row, col -> this@toMatrix[row][col] } } public fun Sequence.toMatrix(): RealMatrix = toList().let { @@ -43,37 +44,37 @@ public fun RealMatrix.repeatStackVertical(n: Int): RealMatrix = public operator fun RealMatrix.times(double: Double): RealMatrix = LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> - this[row, col] * double + get(row, col) * double } public operator fun RealMatrix.plus(double: Double): RealMatrix = LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> - this[row, col] + double + get(row, col) + double } public operator fun RealMatrix.minus(double: Double): RealMatrix = LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> - this[row, col] - double + get(row, col) - double } public operator fun RealMatrix.div(double: Double): RealMatrix = LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> - this[row, col] / double + get(row, col) / double } public operator fun Double.times(matrix: RealMatrix): RealMatrix = LinearSpace.real.buildMatrix(matrix.rowNum, matrix.colNum) { row, col -> - this * matrix[row, col] + this@times * matrix[row, col] } public operator fun Double.plus(matrix: RealMatrix): RealMatrix = LinearSpace.real.buildMatrix(matrix.rowNum, matrix.colNum) { row, col -> - this + matrix[row, col] + this@plus + matrix[row, col] } public operator fun Double.minus(matrix: RealMatrix): RealMatrix = LinearSpace.real.buildMatrix(matrix.rowNum, matrix.colNum) { row, col -> - this - matrix[row, col] + this@minus - matrix[row, col] } // TODO: does this operation make sense? Should it be 'this/matrix[row, col]'? @@ -87,13 +88,13 @@ public operator fun Double.minus(matrix: RealMatrix): RealMatrix = @UnstableKMathAPI public operator fun RealMatrix.times(other: RealMatrix): RealMatrix = - LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> this[row, col] * other[row, col] } + LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> this@times[row, col] * other[row, col] } public operator fun RealMatrix.plus(other: RealMatrix): RealMatrix = - LinearSpace.real.add(this, other) + LinearSpace.real.run { this@plus + other } public operator fun RealMatrix.minus(other: RealMatrix): RealMatrix = - LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> this[row, col] - other[row, col] } + LinearSpace.real.buildMatrix(rowNum, colNum) { row, col -> this@minus[row, col] - other[row, col] } /* * Operations on columns @@ -102,14 +103,14 @@ public operator fun RealMatrix.minus(other: RealMatrix): RealMatrix = public inline fun RealMatrix.appendColumn(crossinline mapper: (Buffer) -> Double): RealMatrix = LinearSpace.real.buildMatrix(rowNum, colNum + 1) { row, col -> if (col < colNum) - this[row, col] + get(row, col) else mapper(rows[row]) } public fun RealMatrix.extractColumns(columnRange: IntRange): RealMatrix = LinearSpace.real.buildMatrix(rowNum, columnRange.count()) { row, col -> - this[row, columnRange.first + col] + this@extractColumns[row, columnRange.first + col] } public fun RealMatrix.extractColumn(columnIndex: Int): RealMatrix = diff --git a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/dot.kt b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/dot.kt index cbfb364c1..5a66b55bd 100644 --- a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/dot.kt +++ b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/dot.kt @@ -1,31 +1,31 @@ package space.kscience.kmath.real -import space.kscience.kmath.linear.BufferMatrix -import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.RealBuffer +import space.kscience.kmath.linear.LinearSpace +import space.kscience.kmath.nd.Matrix /** * Optimized dot product for real matrices */ -public infix fun BufferMatrix.dot(other: BufferMatrix): BufferMatrix { - require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } - val resultArray = DoubleArray(this.rowNum * other.colNum) - - //convert to array to insure there is no memory indirection - fun Buffer.unsafeArray() = if (this is RealBuffer) - this.array - else - DoubleArray(size) { get(it) } - - val a = this.buffer.unsafeArray() - val b = other.buffer.unsafeArray() - - for (i in (0 until rowNum)) - for (j in (0 until other.colNum)) - for (k in (0 until colNum)) - resultArray[i * other.colNum + j] += a[i * colNum + k] * b[k * other.colNum + j] - - val buffer = RealBuffer(resultArray) - return BufferMatrix(rowNum, other.colNum, buffer) +public infix fun Matrix.dot(other: Matrix): Matrix = LinearSpace.real.run{ + this@dot dot other +// require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } +// val resultArray = DoubleArray(this.rowNum * other.colNum) +// +// //convert to array to insure there is no memory indirection +// fun Buffer.unsafeArray() = if (this is RealBuffer) +// this.array +// else +// DoubleArray(size) { get(it) } +// +// val a = this.buffer.unsafeArray() +// val b = other.buffer.unsafeArray() +// +// for (i in (0 until rowNum)) +// for (j in (0 until other.colNum)) +// for (k in (0 until colNum)) +// resultArray[i * other.colNum + j] += a[i * colNum + k] * b[k * other.colNum + j] +// +// val buffer = RealBuffer(resultArray) +// return BufferMatrix(rowNum, other.colNum, buffer) } \ No newline at end of file