forked from kscience/kmath
Koma integration + multiple fixes to Matrix API
This commit is contained in:
parent
a9e06c261a
commit
baae18b223
@ -5,9 +5,10 @@ plugins {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api project(":kmath-core")
|
implementation project(":kmath-core")
|
||||||
api project(":kmath-coroutines")
|
implementation project(":kmath-coroutines")
|
||||||
api project(":kmath-commons")
|
implementation project(":kmath-commons")
|
||||||
|
implementation project(":kmath-koma")
|
||||||
implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
|
implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
|
||||||
//jmh project(':kmath-core')
|
//jmh project(':kmath-core')
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,7 @@ fun main() {
|
|||||||
|
|
||||||
//commons-math
|
//commons-math
|
||||||
|
|
||||||
val cmSolver = CMSolver
|
val cmSolver = CMMatrixContext
|
||||||
|
|
||||||
val commonsTime = measureTimeMillis {
|
val commonsTime = measureTimeMillis {
|
||||||
val cm = matrix.toCM() //avoid overhead on conversion
|
val cm = matrix.toCM() //avoid overhead on conversion
|
||||||
|
@ -3,8 +3,6 @@ package scientifik.kmath.linear
|
|||||||
import org.apache.commons.math3.linear.*
|
import org.apache.commons.math3.linear.*
|
||||||
import org.apache.commons.math3.linear.RealMatrix
|
import org.apache.commons.math3.linear.RealMatrix
|
||||||
import org.apache.commons.math3.linear.RealVector
|
import org.apache.commons.math3.linear.RealVector
|
||||||
import scientifik.kmath.operations.RealField
|
|
||||||
import scientifik.kmath.structures.DoubleBuffer
|
|
||||||
|
|
||||||
inline class CMMatrix(val origin: RealMatrix) : Matrix<Double> {
|
inline class CMMatrix(val origin: RealMatrix) : Matrix<Double> {
|
||||||
override val rowNum: Int get() = origin.rowDimension
|
override val rowNum: Int get() = origin.rowDimension
|
||||||
@ -40,44 +38,51 @@ fun Point<Double>.toCM(): CMVector = if (this is CMVector) {
|
|||||||
CMVector(ArrayRealVector(array))
|
CMVector(ArrayRealVector(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
fun RealVector.toPoint() = DoubleBuffer(this.toArray())
|
fun RealVector.toPoint() = CMVector(this)
|
||||||
|
|
||||||
object CMMatrixContext : MatrixContext<Double, RealField> {
|
object CMMatrixContext : MatrixContext<Double>, LinearSolver<Double> {
|
||||||
override val elementContext: RealField get() = RealField
|
|
||||||
|
|
||||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): Matrix<Double> {
|
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): CMMatrix {
|
||||||
val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } }
|
val array = Array(rows) { i -> DoubleArray(columns) { j -> initializer(i, j) } }
|
||||||
return CMMatrix(Array2DRowRealMatrix(array))
|
return CMMatrix(Array2DRowRealMatrix(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun point(size: Int, initializer: (Int) -> Double): Point<Double> {
|
override fun solve(a: Matrix<Double>, b: Matrix<Double>): CMMatrix {
|
||||||
val array = DoubleArray(size, initializer)
|
val decomposition = LUDecomposition(a.toCM().origin)
|
||||||
return CMVector(ArrayRealVector(array))
|
return decomposition.solver.solve(b.toCM().origin).toMatrix()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<Double>.dot(other: Matrix<Double>): Matrix<Double> = CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
|
override fun solve(a: Matrix<Double>, b: Point<Double>): CMVector {
|
||||||
|
val decomposition = LUDecomposition(a.toCM().origin)
|
||||||
|
return decomposition.solver.solve(b.toCM().origin).toPoint()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun inverse(a: Matrix<Double>): CMMatrix {
|
||||||
|
val decomposition = LUDecomposition(a.toCM().origin)
|
||||||
|
return decomposition.solver.inverse.toMatrix()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Matrix<Double>.dot(other: Matrix<Double>) =
|
||||||
|
CMMatrix(this.toCM().origin.multiply(other.toCM().origin))
|
||||||
|
|
||||||
|
override fun Matrix<Double>.dot(vector: Point<Double>): CMVector =
|
||||||
|
CMVector(this.toCM().origin.preMultiply(vector.toCM().origin))
|
||||||
|
|
||||||
|
|
||||||
|
override fun Matrix<Double>.unaryMinus(): CMMatrix =
|
||||||
|
produce(rowNum, colNum) { i, j -> -get(i, j) }
|
||||||
|
|
||||||
|
override fun Matrix<Double>.plus(b: Matrix<Double>) =
|
||||||
|
CMMatrix(this.toCM().origin.multiply(b.toCM().origin))
|
||||||
|
|
||||||
|
override fun Matrix<Double>.minus(b: Matrix<Double>) =
|
||||||
|
CMMatrix(this.toCM().origin.subtract(b.toCM().origin))
|
||||||
|
|
||||||
|
override fun Matrix<Double>.times(value: Double) =
|
||||||
|
CMMatrix(this.toCM().origin.scalarMultiply(value.toDouble()))
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin))
|
operator fun CMMatrix.plus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.add(other.origin))
|
||||||
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.subtract(other.origin))
|
operator fun CMMatrix.minus(other: CMMatrix): CMMatrix = CMMatrix(this.origin.subtract(other.origin))
|
||||||
|
|
||||||
infix fun CMMatrix.dot(other: CMMatrix): CMMatrix = CMMatrix(this.origin.multiply(other.origin))
|
infix fun CMMatrix.dot(other: CMMatrix): CMMatrix = CMMatrix(this.origin.multiply(other.origin))
|
||||||
|
|
||||||
object CMSolver : LinearSolver<Double, RealField> {
|
|
||||||
|
|
||||||
override fun solve(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> {
|
|
||||||
val decomposition = LUDecomposition(a.toCM().origin)
|
|
||||||
return decomposition.solver.solve(b.toCM().origin).toMatrix()
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun solve(a: Matrix<Double>, b: Point<Double>): Point<Double> {
|
|
||||||
val decomposition = LUDecomposition(a.toCM().origin)
|
|
||||||
return decomposition.solver.solve(b.toCM().origin).toPoint()
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun inverse(a: Matrix<Double>): Matrix<Double> {
|
|
||||||
val decomposition = LUDecomposition(a.toCM().origin)
|
|
||||||
return decomposition.solver.inverse.toMatrix()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -9,7 +9,7 @@ import scientifik.kmath.structures.*
|
|||||||
class BufferMatrixContext<T : Any, R : Ring<T>>(
|
class BufferMatrixContext<T : Any, R : Ring<T>>(
|
||||||
override val elementContext: R,
|
override val elementContext: R,
|
||||||
private val bufferFactory: BufferFactory<T>
|
private val bufferFactory: BufferFactory<T>
|
||||||
) : MatrixContext<T, R> {
|
) : GenericMatrixContext<T, R> {
|
||||||
|
|
||||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): BufferMatrix<T> {
|
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): BufferMatrix<T> {
|
||||||
val buffer = bufferFactory(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
val buffer = bufferFactory(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||||
|
@ -65,7 +65,7 @@ class LUPDecomposition<T : Comparable<T>>(
|
|||||||
|
|
||||||
|
|
||||||
class LUSolver<T : Comparable<T>, F : Field<T>>(
|
class LUSolver<T : Comparable<T>, F : Field<T>>(
|
||||||
val context: MatrixContext<T, F>,
|
val context: GenericMatrixContext<T, F>,
|
||||||
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
val bufferFactory: MutableBufferFactory<T> = ::boxing,
|
||||||
val singularityCheck: (T) -> Boolean
|
val singularityCheck: (T) -> Boolean
|
||||||
) : LinearSolver<T, F> {
|
) : LinearSolver<T, F> {
|
||||||
@ -201,6 +201,6 @@ class LUSolver<T : Comparable<T>, F : Field<T>>(
|
|||||||
override fun inverse(a: Matrix<T>): Matrix<T> = solve(a, context.one(a.rowNum, a.colNum))
|
override fun inverse(a: Matrix<T>): Matrix<T> = solve(a, context.one(a.rowNum, a.colNum))
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
val real = LUSolver(MatrixContext.real, MutableBuffer.Companion::auto) { it < 1e-11 }
|
val real = LUSolver(GenericMatrixContext.real, MutableBuffer.Companion::auto) { it < 1e-11 }
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -3,7 +3,6 @@ package scientifik.kmath.linear
|
|||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.Norm
|
import scientifik.kmath.operations.Norm
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import scientifik.kmath.operations.Ring
|
|
||||||
import scientifik.kmath.structures.VirtualBuffer
|
import scientifik.kmath.structures.VirtualBuffer
|
||||||
import scientifik.kmath.structures.asSequence
|
import scientifik.kmath.structures.asSequence
|
||||||
|
|
||||||
@ -11,7 +10,7 @@ import scientifik.kmath.structures.asSequence
|
|||||||
/**
|
/**
|
||||||
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors
|
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors
|
||||||
*/
|
*/
|
||||||
interface LinearSolver<T : Any, R : Ring<T>> {
|
interface LinearSolver<T : Any> {
|
||||||
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
|
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
|
||||||
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).toVector()
|
fun solve(a: Matrix<T>, b: Point<T>): Point<T> = solve(a, b.toMatrix()).toVector()
|
||||||
fun inverse(a: Matrix<T>): Matrix<T>
|
fun inverse(a: Matrix<T>): Matrix<T>
|
||||||
|
@ -8,28 +8,61 @@ import scientifik.kmath.structures.Buffer.Companion.boxing
|
|||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
|
||||||
interface MatrixContext<T : Any, R : Ring<T>> {
|
interface MatrixContext<T : Any> {
|
||||||
|
/**
|
||||||
|
* Produce a matrix with this context and given dimensions
|
||||||
|
*/
|
||||||
|
fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
||||||
|
|
||||||
|
infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T>
|
||||||
|
|
||||||
|
infix fun Matrix<T>.dot(vector: Point<T>): Point<T>
|
||||||
|
|
||||||
|
operator fun Matrix<T>.unaryMinus(): Matrix<T>
|
||||||
|
|
||||||
|
operator fun Matrix<T>.plus(b: Matrix<T>): Matrix<T>
|
||||||
|
|
||||||
|
operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T>
|
||||||
|
|
||||||
|
operator fun Matrix<T>.times(value: T): Matrix<T>
|
||||||
|
|
||||||
|
operator fun T.times(m: Matrix<T>): Matrix<T> = m * this
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
/**
|
||||||
|
* Non-boxing double matrix
|
||||||
|
*/
|
||||||
|
val real = BufferMatrixContext(RealField, DoubleBufferFactory)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A structured matrix with custom buffer
|
||||||
|
*/
|
||||||
|
fun <T : Any, R : Ring<T>> buffered(
|
||||||
|
ring: R,
|
||||||
|
bufferFactory: BufferFactory<T> = ::boxing
|
||||||
|
): GenericMatrixContext<T, R> =
|
||||||
|
BufferMatrixContext(ring, bufferFactory)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Automatic buffered matrix, unboxed if it is possible
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any, R : Ring<T>> auto(ring: R): GenericMatrixContext<T, R> =
|
||||||
|
buffered(ring, Buffer.Companion::auto)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface GenericMatrixContext<T : Any, R : Ring<T>> : MatrixContext<T> {
|
||||||
/**
|
/**
|
||||||
* The ring context for matrix elements
|
* The ring context for matrix elements
|
||||||
*/
|
*/
|
||||||
val elementContext: R
|
val elementContext: R
|
||||||
|
|
||||||
/**
|
|
||||||
* Produce a matrix with this context and given dimensions
|
|
||||||
*/
|
|
||||||
fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Produce a point compatible with matrix space
|
* Produce a point compatible with matrix space
|
||||||
*/
|
*/
|
||||||
fun point(size: Int, initializer: (Int) -> T): Point<T>
|
fun point(size: Int, initializer: (Int) -> T): Point<T>
|
||||||
|
|
||||||
fun scale(a: Matrix<T>, k: Number): Matrix<T> {
|
override infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
||||||
//TODO create a special wrapper class for scaled matrices
|
|
||||||
return produce(a.rowNum, a.colNum) { i, j -> elementContext.run { a[i, j] * k } }
|
|
||||||
}
|
|
||||||
|
|
||||||
infix fun Matrix<T>.dot(other: Matrix<T>): Matrix<T> {
|
|
||||||
//TODO add typed error
|
//TODO add typed error
|
||||||
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})")
|
if (this.colNum != other.rowNum) error("Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})")
|
||||||
return produce(rowNum, other.colNum) { i, j ->
|
return produce(rowNum, other.colNum) { i, j ->
|
||||||
@ -41,7 +74,7 @@ interface MatrixContext<T : Any, R : Ring<T>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
override infix fun Matrix<T>.dot(vector: Point<T>): Point<T> {
|
||||||
//TODO add typed error
|
//TODO add typed error
|
||||||
if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})")
|
if (this.colNum != vector.size) error("Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})")
|
||||||
return point(rowNum) { i ->
|
return point(rowNum) { i ->
|
||||||
@ -52,15 +85,15 @@ interface MatrixContext<T : Any, R : Ring<T>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun Matrix<T>.unaryMinus() =
|
override operator fun Matrix<T>.unaryMinus() =
|
||||||
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
|
produce(rowNum, colNum) { i, j -> elementContext.run { -get(i, j) } }
|
||||||
|
|
||||||
operator fun Matrix<T>.plus(b: Matrix<T>): Matrix<T> {
|
override operator fun Matrix<T>.plus(b: Matrix<T>): Matrix<T> {
|
||||||
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] + [${b.rowNum},${b.colNum}]")
|
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] + [${b.rowNum},${b.colNum}]")
|
||||||
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
||||||
}
|
}
|
||||||
|
|
||||||
operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
override operator fun Matrix<T>.minus(b: Matrix<T>): Matrix<T> {
|
||||||
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]")
|
if (rowNum != b.rowNum || colNum != b.colNum) error("Matrix operation dimension mismatch. [$rowNum,$colNum] - [${b.rowNum},${b.colNum}]")
|
||||||
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
return produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) + b[i, j] } }
|
||||||
}
|
}
|
||||||
@ -68,27 +101,10 @@ interface MatrixContext<T : Any, R : Ring<T>> {
|
|||||||
operator fun Matrix<T>.times(number: Number): Matrix<T> =
|
operator fun Matrix<T>.times(number: Number): Matrix<T> =
|
||||||
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } }
|
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * number } }
|
||||||
|
|
||||||
operator fun Number.times(m: Matrix<T>): Matrix<T> = m * this
|
operator fun Number.times(matrix: Matrix<T>): Matrix<T> = matrix * this
|
||||||
|
|
||||||
|
override fun Matrix<T>.times(value: T): Matrix<T> =
|
||||||
companion object {
|
produce(rowNum, colNum) { i, j -> elementContext.run { get(i, j) * value } }
|
||||||
/**
|
|
||||||
* Non-boxing double matrix
|
|
||||||
*/
|
|
||||||
val real = BufferMatrixContext(RealField, DoubleBufferFactory)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A structured matrix with custom buffer
|
|
||||||
*/
|
|
||||||
fun <T : Any, R : Ring<T>> buffered(ring: R, bufferFactory: BufferFactory<T> = ::boxing): MatrixContext<T, R> =
|
|
||||||
BufferMatrixContext(ring, bufferFactory)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Automatic buffered matrix, unboxed if it is possible
|
|
||||||
*/
|
|
||||||
inline fun <reified T : Any, R : Ring<T>> auto(ring: R): MatrixContext<T, R> =
|
|
||||||
buffered(ring, Buffer.Companion::auto)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -173,7 +189,7 @@ inline fun <reified T : Any> Matrix<*>.getFeature(): T? = features.filterIsInsta
|
|||||||
/**
|
/**
|
||||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||||
*/
|
*/
|
||||||
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> =
|
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.one(rows: Int, columns: Int): Matrix<T> =
|
||||||
VirtualMatrix<T>(rows, columns) { i, j ->
|
VirtualMatrix<T>(rows, columns) { i, j ->
|
||||||
if (i == j) elementContext.one else elementContext.zero
|
if (i == j) elementContext.one else elementContext.zero
|
||||||
}
|
}
|
||||||
@ -182,7 +198,7 @@ fun <T : Any, R : Ring<T>> MatrixContext<T, R>.one(rows: Int, columns: Int): Mat
|
|||||||
/**
|
/**
|
||||||
* A virtual matrix of zeroes
|
* A virtual matrix of zeroes
|
||||||
*/
|
*/
|
||||||
fun <T : Any, R : Ring<T>> MatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> =
|
fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R>.zero(rows: Int, columns: Int): Matrix<T> =
|
||||||
VirtualMatrix<T>(rows, columns) { i, j ->
|
VirtualMatrix<T>(rows, columns) { i, j ->
|
||||||
elementContext.zero
|
elementContext.zero
|
||||||
}
|
}
|
||||||
|
@ -22,7 +22,7 @@ class MatrixTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testTranspose() {
|
fun testTranspose() {
|
||||||
val matrix = MatrixContext.real.one(3, 3)
|
val matrix = GenericMatrixContext.real.one(3, 3)
|
||||||
val transposed = matrix.transpose()
|
val transposed = matrix.transpose()
|
||||||
assertEquals((matrix as BufferMatrix).buffer, (transposed as BufferMatrix).buffer)
|
assertEquals((matrix as BufferMatrix).buffer, (transposed as BufferMatrix).buffer)
|
||||||
assertEquals(matrix, transposed)
|
assertEquals(matrix, transposed)
|
||||||
@ -36,7 +36,7 @@ class MatrixTest {
|
|||||||
|
|
||||||
val matrix1 = vector1.toMatrix()
|
val matrix1 = vector1.toMatrix()
|
||||||
val matrix2 = vector2.toMatrix().transpose()
|
val matrix2 = vector2.toMatrix().transpose()
|
||||||
val product = MatrixContext.real.run { matrix1 dot matrix2 }
|
val product = GenericMatrixContext.real.run { matrix1 dot matrix2 }
|
||||||
|
|
||||||
|
|
||||||
assertEquals(5.0, product[1, 0])
|
assertEquals(5.0, product[1, 0])
|
||||||
|
@ -6,7 +6,7 @@ import kotlin.test.assertEquals
|
|||||||
class RealLUSolverTest {
|
class RealLUSolverTest {
|
||||||
@Test
|
@Test
|
||||||
fun testInvertOne() {
|
fun testInvertOne() {
|
||||||
val matrix = MatrixContext.real.one(2, 2)
|
val matrix = GenericMatrixContext.real.one(2, 2)
|
||||||
val inverted = LUSolver.real.inverse(matrix)
|
val inverted = LUSolver.real.inverse(matrix)
|
||||||
assertEquals(matrix, inverted)
|
assertEquals(matrix, inverted)
|
||||||
}
|
}
|
||||||
@ -25,7 +25,7 @@ class RealLUSolverTest {
|
|||||||
assertEquals(8.0, decomposition.determinant)
|
assertEquals(8.0, decomposition.determinant)
|
||||||
|
|
||||||
//Check decomposition
|
//Check decomposition
|
||||||
with(MatrixContext.real) {
|
with(GenericMatrixContext.real) {
|
||||||
assertEquals(decomposition.p dot matrix, decomposition.l dot decomposition.u)
|
assertEquals(decomposition.p dot matrix, decomposition.l dot decomposition.u)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,27 +1,70 @@
|
|||||||
package scientifik.kmath.linear
|
package scientifik.kmath.linear
|
||||||
|
|
||||||
import scientifik.kmath.operations.Ring
|
import koma.extensions.fill
|
||||||
|
import koma.matrix.MatrixFactory
|
||||||
|
|
||||||
class KomaMatrixContext<T: Any, R: Ring<T>> : MatrixContext<T,R> {
|
class KomaMatrixContext<T : Any>(val factory: MatrixFactory<koma.matrix.Matrix<T>>) : MatrixContext<T>,
|
||||||
override val elementContext: R
|
LinearSolver<T> {
|
||||||
get() = TODO("not implemented") //To change initializer of created properties use File | Settings | File Templates.
|
|
||||||
|
|
||||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T> {
|
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T) =
|
||||||
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
KomaMatrix(factory.zeros(rows, columns).fill(initializer))
|
||||||
|
|
||||||
|
fun Matrix<T>.toKoma(): KomaMatrix<T> = if (this is KomaMatrix) {
|
||||||
|
this
|
||||||
|
} else {
|
||||||
|
produce(rowNum, colNum) { i, j -> get(i, j) }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun point(size: Int, initializer: (Int) -> T): Point<T> {
|
fun Point<T>.toKoma(): KomaVector<T> = if (this is KomaVector) {
|
||||||
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
this
|
||||||
|
} else {
|
||||||
|
KomaVector(factory.zeros(size, 1).fill { i, _ -> get(i) })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
override fun Matrix<T>.dot(other: Matrix<T>) =
|
||||||
|
KomaMatrix(this.toKoma().origin * other.toKoma().origin)
|
||||||
|
|
||||||
|
override fun Matrix<T>.dot(vector: Point<T>) =
|
||||||
|
KomaVector(this.toKoma().origin * vector.toKoma().origin)
|
||||||
|
|
||||||
|
override fun Matrix<T>.unaryMinus() =
|
||||||
|
KomaMatrix(this.toKoma().origin.unaryMinus())
|
||||||
|
|
||||||
|
override fun Matrix<T>.plus(b: Matrix<T>) =
|
||||||
|
KomaMatrix(this.toKoma().origin + b.toKoma().origin)
|
||||||
|
|
||||||
|
override fun Matrix<T>.minus(b: Matrix<T>) =
|
||||||
|
KomaMatrix(this.toKoma().origin - b.toKoma().origin)
|
||||||
|
|
||||||
|
override fun Matrix<T>.times(value: T) =
|
||||||
|
KomaMatrix(this.toKoma().origin * value)
|
||||||
|
|
||||||
|
|
||||||
|
override fun solve(a: Matrix<T>, b: Matrix<T>) =
|
||||||
|
KomaMatrix(a.toKoma().origin.solve(b.toKoma().origin))
|
||||||
|
|
||||||
|
override fun inverse(a: Matrix<T>) =
|
||||||
|
KomaMatrix(a.toKoma().origin.inv())
|
||||||
}
|
}
|
||||||
|
|
||||||
inline class KomaMatrix<T : Any>(val matrix: koma.matrix.Matrix<T>) : Matrix<T> {
|
inline class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>) : Matrix<T> {
|
||||||
override val rowNum: Int get() = matrix.numRows()
|
override val rowNum: Int get() = origin.numRows()
|
||||||
override val colNum: Int get() = matrix.numCols()
|
override val colNum: Int get() = origin.numCols()
|
||||||
override val features: Set<MatrixFeature> get() = emptySet()
|
override val features: Set<MatrixFeature> get() = emptySet()
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
|
||||||
override inline fun get(i: Int, j: Int): T = matrix.getGeneric(i, j)
|
}
|
||||||
|
|
||||||
|
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {
|
||||||
|
init {
|
||||||
|
if (origin.numCols() != 1) error("Only single column matrices are allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
override val size: Int get() = origin.numRows()
|
||||||
|
|
||||||
|
override fun get(index: Int): T = origin.getGeneric(index)
|
||||||
|
|
||||||
|
override fun iterator(): Iterator<T> = origin.toIterable().iterator()
|
||||||
|
}
|
||||||
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user