forked from kscience/kmath
Koma integration + multiple fixes to Matrix API
This commit is contained in:
parent
a9e06c261a
commit
baae18b223
@ -5,9 +5,10 @@ plugins {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
api project(":kmath-core")
|
||||
api project(":kmath-coroutines")
|
||||
api project(":kmath-commons")
|
||||
implementation project(":kmath-core")
|
||||
implementation project(":kmath-coroutines")
|
||||
implementation project(":kmath-commons")
|
||||
implementation project(":kmath-koma")
|
||||
implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
|
||||
//jmh project(':kmath-core')
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ fun main() {
|
||||
|
||||
//commons-math
|
||||
|
||||
val cmSolver = CMSolver
|
||||
val cmSolver = CMMatrixContext
|
||||
|
||||
val commonsTime = measureTimeMillis {
|
||||
val cm = matrix.toCM() //avoid overhead on conversion
|
||||
|
@ -3,8 +3,6 @@ package scientifik.kmath.linear
|
||||
import org.apache.commons.math3.linear.*
|
||||
import org.apache.commons.math3.linear.RealMatrix
|
||||
import org.apache.commons.math3.linear.RealVector
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.structures.DoubleBuffer
|
||||
|
||||
inline class CMMatrix(val origin: RealMatrix) : Matrix<Double> {
|
||||
override val rowNum: Int get() = origin.rowDimension
|
||||
@ -40,44 +38,51 @@ fun Point<Double>.toCM(): CMVector = if (this is CMVector) {
|
||||
CMVector(ArrayRealVector(array))
|
||||
}
|
||||
|
||||
fun RealVector.toPoint() = DoubleBuffer(this.toArray())
|
||||
fun RealVector.toPoint() = CMVector(this)
|
||||
|
||||
object CMMatrixContext : MatrixContext<Double, RealField> {
|
||||
override val elementContext: RealField get() = RealField
|
||||
object CMMatrixContext : MatrixContext<Double>, LinearSolver<Double> {
|
||||
|
||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
|
||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
|
||||
val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } }
|
||||
return CMMatrix(Array2DRowRealMatrix(array))
|
||||
}
|
||||
|
||||
override fun point(size: Int, initializer: (Int) -> Double): Point<Double> {
|
||||
val array = DoubleArray(size, initializer)
|
||||
return CMVector(ArrayRealVector(array))
|
||||
override fun solve(a: Matrix<Double>, b: Matrix<Double>): CMMatrix {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.solve(b.toCM().origin).toMatrix()
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
|
||||
override fun solve(a: Matrix<Double>, b: Point<Double>): CMVector {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.solve(b.toCM().origin).toPoint()
|
||||
}
|
||||
|
||||
override fun inverse(a: Matrix<Double>): CMMatrix {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.inverse.toMatrix()
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.dot(other: Matrix<Double>) =
|
||||
CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
|
||||
|
||||
override fun Matrix<Double>.dot(vector: Point<Double>): CMVector =
|
||||
CMVector(this.toCM().origin.preMultiply(vector.toCM().origin))
|
||||
|
||||
|
||||
override fun Matrix<Double>.unaryMinus(): CMMatrix =
|
||||
produce(rowNum, colNum) { i, j -> -get(i, j) }
|
||||
|
||||
override fun Matrix<Double>.plus(b: Matrix<Double>) =
|
||||
CMMatrix(this.toCM().origin.multiply(b.toCM().origin))
|
||||
|
||||
override fun Matrix<Double>.minus(b: Matrix<Double>) =
|
||||
CMMatrix(this.toCM().origin.subtract(b.toCM().origin))
|
||||
|
||||
override fun Matrix<Double>.times(value: Double) =
|
||||
CMMatrix(this.toCM().origin.scalarMultiply(value.toDouble()))
|
||||
}
|
||||
|
||||
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin))
|
||||
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.subtract(other.origin))
|
||||
|
||||
infix fun CMMatrix.dot(other: CMMatrix): CMMatrix = CMMatrix(this.origin.multiply(other.origin))
|
||||
|
||||
object CMSolver : LinearSolver<Double, RealField> {
|
||||
|
||||
override fun solve(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.solve(b.toCM().origin).toMatrix()
|
||||
}
|
||||
|
||||
override fun solve(a: Matrix<Double>, b: Point<Double>): Point<Double> {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.solve(b.toCM().origin).toPoint()
|
||||
}
|
||||
|
||||
override fun inverse(a: Matrix<Double>): Matrix<Double> {
|
||||
val decomposition = LUDecomposition(a.toCM().origin)
|
||||
return decomposition.solver.inverse.toMatrix()
|
||||
}
|
||||
|
||||
}
|
||||
infix fun CMMatrix.dot(other: CMMatrix): CMMatrix = CMMatrix(this.origin.multiply(other.origin))
|
@ -9,7 +9,7 @@ import scientifik.kmath.structures.*
|
||||
class BufferMatrixContext<T : Any, R : Ring<T>>(
|
||||
override val elementContext: R,
|
||||
private val bufferFactory: BufferFactory<T>
|
||||
) : MatrixContext<T, R> {
|
||||
) : GenericMatrixContext<T, R> {
|
||||
|
||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): BufferMatrix<T> {
|
||||
val buffer = bufferFactory(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||
|
@ -65,7 +65,7 @@ class LUPDecomposition<T : Comparable<T>>(
|
||||
|
||||
|
||||
class LUSolver<T : Comparable<T>, F : Field<T>>(
|
||||
val context: MatrixContext<T, F>,
|
||||
val context: GenericMatrixContext<T, F>,
|
||||
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
||||
val singularityCheck: (T) -> Boolean
|
||||
) : LinearSolver<T, F> {
|
||||
@ -201,6 +201,6 @@ class LUSolver<T : Comparable<T>, F : Field<T>>(
|
||||
override fun inverse(a: Matrix<T>): Matrix<T> = solve(a, context.one(a.rowNum, a.colNum))
|
||||
|
||||
companion object {
|
||||
val real = LUSolver(MatrixContext.real, MutableBuffer.Companion::auto) { it < 1e-11 }
|
||||
val real = LUSolver(GenericMatrixContext.real, MutableBuffer.Companion::auto) { it < 1e-11 }
|
||||
}
|
||||
}
|
@ -3,7 +3,6 @@ package scientifik.kmath.linear
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Norm
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.structures.VirtualBuffer
|
||||
import scientifik.kmath.structures.asSequence
|
||||
|
||||
@ -11,7 +10,7 @@ import scientifik.kmath.structures.asSequence
|
||||
/**
|
||||
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors
|
||||
*/
|
||||
interface LinearSolver<T : Any, R : Ring<T>> {
|
||||
interface LinearSolver<T : Any> {
|
||||
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
|
||||
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).toVector()
|
||||
fun inverse(a: Matrix<T>): Matrix<T>
|
||||
|
@ -8,28 +8,61 @@ import scientifik.kmath.structures.Buffer.Companion.boxing
|
||||
import kotlin.math.sqrt
|
||||
|
||||
|
||||
interface MatrixContext<T : Any, R : Ring<T>> {
|
||||
interface MatrixContext<T : Any> {
|
||||
/**
|
||||
* Produce a matrix with this context and given dimensions
|
||||
*/
|
||||
fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
||||
|
||||
infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T>
|
||||
|
||||
infix fun Matrix<T>.dot(vector: Point<T>): Point<T>
|
||||
|
||||
operator fun Matrix<T>.unaryMinus(): Matrix<T>
|
||||
|
||||
operator fun Matrix<T>.plus(b: Matrix<T>): Matrix<T>
|
||||
|
||||
operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T>
|
||||
|
||||
operator fun Matrix<T>.times(value: T): Matrix<T>
|
||||
|
||||
operator fun T.times(m: Matrix<T>): Matrix<T> = m * this
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* Non-boxing double matrix
|
||||
*/
|
||||
val real = BufferMatrixContext(RealField, DoubleBufferFactory)
|
||||
|
||||
/**
|
||||
* A structured matrix with custom buffer
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> buffered(
|
||||
ring: R,
|
||||
bufferFactory: BufferFactory<T> = ::boxing
|
||||
): GenericMatrixContext<T, R> =
|
||||
BufferMatrixContext(ring, bufferFactory)
|
||||
|
||||
/**
|
||||
* Automatic buffered matrix, unboxed if it is possible
|
||||
*/
|
||||
inline fun <reified T : Any, R : Ring<T>> auto(ring: R): GenericMatrixContext<T, R> =
|
||||
buffered(ring, Buffer.Companion::auto)
|
||||
}
|
||||
}
|
||||
|
||||
interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
||||
/**
|
||||
* The ring context for matrix elements
|
||||
*/
|
||||
val elementContext: R
|
||||
|
||||
/**
|
||||
* Produce a matrix with this context and given dimensions
|
||||
*/
|
||||
fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
||||
|
||||
/**
|
||||
* Produce a point compatible with matrix space
|
||||
*/
|
||||
fun point(size: Int, initializer: (Int) -> T): Point<T>
|
||||
|
||||
fun scale(a: Matrix<T>, k: Number): Matrix<T> {
|
||||
//TODO create a special wrapper class for scaled matrices
|
||||
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } }
|
||||
}
|
||||
|
||||
infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
||||
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
||||
//TODO add typed error
|
||||
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})")
|
||||
return produce(rowNum, other.colNum) { i, j ->
|
||||
@ -41,7 +74,7 @@ interface MatrixContext<T : Any, R : Ring<T>> {
|
||||
}
|
||||
}
|
||||
|
||||
infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
||||
override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
||||
//TODO add typed error
|
||||
if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})")
|
||||
return point(rowNum) { i ->
|
||||
@ -52,15 +85,15 @@ interface MatrixContext<T : Any, R : Ring<T>> {
|
||||
}
|
||||
}
|
||||
|
||||
operator fun Matrix<T>.unaryMinus() =
|
||||
override operator fun Matrix<T>.unaryMinus() =
|
||||
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
|
||||
|
||||
operator fun Matrix<T>.plus(b: Matrix<T>): Matrix<T> {
|
||||
override operator fun Matrix<T>.plus(b: Matrix<T>): Matrix<T> {
|
||||
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] + [${b.rowNum},${b.colNum}]")
|
||||
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
||||
}
|
||||
|
||||
operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
||||
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
||||
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]")
|
||||
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
||||
}
|
||||
@ -68,27 +101,10 @@ interface MatrixContext<T : Any, R : Ring<T>> {
|
||||
operator fun Matrix<T>.times(number: Number): Matrix<T> =
|
||||
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } }
|
||||
|
||||
operator fun Number.times(m: Matrix<T>): Matrix<T> = m * this
|
||||
operator fun Number.times(matrix: Matrix<T>): Matrix<T> = matrix * this
|
||||
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* Non-boxing double matrix
|
||||
*/
|
||||
val real = BufferMatrixContext(RealField, DoubleBufferFactory)
|
||||
|
||||
/**
|
||||
* A structured matrix with custom buffer
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> buffered(ring: R, bufferFactory: BufferFactory<T> = ::boxing): MatrixContext<T, R> =
|
||||
BufferMatrixContext(ring, bufferFactory)
|
||||
|
||||
/**
|
||||
* Automatic buffered matrix, unboxed if it is possible
|
||||
*/
|
||||
inline fun <reified T : Any, R : Ring<T>> auto(ring: R): MatrixContext<T, R> =
|
||||
buffered(ring, Buffer.Companion::auto)
|
||||
}
|
||||
override fun Matrix<T>.times(value: T): Matrix<T> =
|
||||
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -173,7 +189,7 @@ inline fun <reified T : Any> Matrix<*>.getFeature(): T? = features.filterIsInsta
|
||||
/**
|
||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> =
|
||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> =
|
||||
VirtualMatrix<T>(rows, columns) { i, j ->
|
||||
if (i == j) elementContext.one else elementContext.zero
|
||||
}
|
||||
@ -182,7 +198,7 @@ fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Mat
|
||||
/**
|
||||
* A virtual matrix of zeroes
|
||||
*/
|
||||
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> =
|
||||
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> =
|
||||
VirtualMatrix<T>(rows, columns) { i, j ->
|
||||
elementContext.zero
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ class MatrixTest {
|
||||
|
||||
@Test
|
||||
fun testTranspose() {
|
||||
val matrix = MatrixContext.real.one(3, 3)
|
||||
val matrix = GenericMatrixContext.real.one(3, 3)
|
||||
val transposed = matrix.transpose()
|
||||
assertEquals((matrix as BufferMatrix).buffer, (transposed as BufferMatrix).buffer)
|
||||
assertEquals(matrix, transposed)
|
||||
@ -36,7 +36,7 @@ class MatrixTest {
|
||||
|
||||
val matrix1 = vector1.toMatrix()
|
||||
val matrix2 = vector2.toMatrix().transpose()
|
||||
val product = MatrixContext.real.run { matrix1 dot matrix2 }
|
||||
val product = GenericMatrixContext.real.run { matrix1 dot matrix2 }
|
||||
|
||||
|
||||
assertEquals(5.0, product[1, 0])
|
||||
|
@ -6,7 +6,7 @@ import kotlin.test.assertEquals
|
||||
class RealLUSolverTest {
|
||||
@Test
|
||||
fun testInvertOne() {
|
||||
val matrix = MatrixContext.real.one(2, 2)
|
||||
val matrix = GenericMatrixContext.real.one(2, 2)
|
||||
val inverted = LUSolver.real.inverse(matrix)
|
||||
assertEquals(matrix, inverted)
|
||||
}
|
||||
@ -25,7 +25,7 @@ class RealLUSolverTest {
|
||||
assertEquals(8.0, decomposition.determinant)
|
||||
|
||||
//Check decomposition
|
||||
with(MatrixContext.real) {
|
||||
with(GenericMatrixContext.real) {
|
||||
assertEquals(decomposition.p dot matrix, decomposition.l dot decomposition.u)
|
||||
}
|
||||
|
||||
|
@ -1,27 +1,70 @@
|
||||
package scientifik.kmath.linear
|
||||
|
||||
import scientifik.kmath.operations.Ring
|
||||
import koma.extensions.fill
|
||||
import koma.matrix.MatrixFactory
|
||||
|
||||
class KomaMatrixContext<T: Any, R: Ring<T>> : MatrixContext<T,R> {
|
||||
override val elementContext: R
|
||||
get() = TODO("not implemented") //To change initializer of created properties use File | Settings | File Templates.
|
||||
class KomaMatrixContext<T : Any>(val factory: MatrixFactory<koma.matrix.Matrix<T>>) : MatrixContext<T>,
|
||||
LinearSolver<T> {
|
||||
|
||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T> {
|
||||
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T) =
|
||||
KomaMatrix(factory.zeros(rows, columns).fill(initializer))
|
||||
|
||||
fun Matrix<T>.toKoma(): KomaMatrix<T> = if (this is KomaMatrix) {
|
||||
this
|
||||
} else {
|
||||
produce(rowNum, colNum) { i, j -> get(i, j) }
|
||||
}
|
||||
|
||||
override fun point(size: Int, initializer: (Int) -> T): Point<T> {
|
||||
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
||||
fun Point<T>.toKoma(): KomaVector<T> = if (this is KomaVector) {
|
||||
this
|
||||
} else {
|
||||
KomaVector(factory.zeros(size, 1).fill { i, _ -> get(i) })
|
||||
}
|
||||
|
||||
|
||||
override fun Matrix<T>.dot(other: Matrix<T>) =
|
||||
KomaMatrix(this.toKoma().origin * other.toKoma().origin)
|
||||
|
||||
override fun Matrix<T>.dot(vector: Point<T>) =
|
||||
KomaVector(this.toKoma().origin * vector.toKoma().origin)
|
||||
|
||||
override fun Matrix<T>.unaryMinus() =
|
||||
KomaMatrix(this.toKoma().origin.unaryMinus())
|
||||
|
||||
override fun Matrix<T>.plus(b: Matrix<T>) =
|
||||
KomaMatrix(this.toKoma().origin + b.toKoma().origin)
|
||||
|
||||
override fun Matrix<T>.minus(b: Matrix<T>) =
|
||||
KomaMatrix(this.toKoma().origin - b.toKoma().origin)
|
||||
|
||||
override fun Matrix<T>.times(value: T) =
|
||||
KomaMatrix(this.toKoma().origin * value)
|
||||
|
||||
|
||||
override fun solve(a: Matrix<T>, b: Matrix<T>) =
|
||||
KomaMatrix(a.toKoma().origin.solve(b.toKoma().origin))
|
||||
|
||||
override fun inverse(a: Matrix<T>) =
|
||||
KomaMatrix(a.toKoma().origin.inv())
|
||||
}
|
||||
|
||||
inline class KomaMatrix<T : Any>(val matrix: koma.matrix.Matrix<T>) : Matrix<T> {
|
||||
override val rowNum: Int get() = matrix.numRows()
|
||||
override val colNum: Int get() = matrix.numCols()
|
||||
inline class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>) : Matrix<T> {
|
||||
override val rowNum: Int get() = origin.numRows()
|
||||
override val colNum: Int get() = origin.numCols()
|
||||
override val features: Set<MatrixFeature> get() = emptySet()
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
override inline fun get(i: Int, j: Int): T = matrix.getGeneric(i, j)
|
||||
override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
|
||||
}
|
||||
|
||||
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {
|
||||
init {
|
||||
if (origin.numCols() != 1) error("Only single column matrices are allowed")
|
||||
}
|
||||
|
||||
override val size: Int get() = origin.numRows()
|
||||
|
||||
override fun get(index: Int): T = origin.getGeneric(index)
|
||||
|
||||
override fun iterator(): Iterator<T> = origin.toIterable().iterator()
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user