forked from kscience/kmath
Basic matrix inversion
This commit is contained in:
parent
0cde0cfea5
commit
f894175897
@ -55,6 +55,26 @@ abstract class MatrixSpace<T : Any>(val rows: Int, val columns: Int, val field:
|
|||||||
(0 until a.columns).asSequence().map { k -> field.multiply(a[i, k], b[k, j]) }.reduce { first, second -> field.add(first, second) }
|
(0 until a.columns).asSequence().map { k -> field.multiply(a[i, k], b[k, j]) }.reduce { first, second -> field.add(first, second) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other !is MatrixSpace<*>) return false
|
||||||
|
|
||||||
|
if (rows != other.rows) return false
|
||||||
|
if (columns != other.columns) return false
|
||||||
|
if (field != other.field) return false
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = rows
|
||||||
|
result = 31 * result + columns
|
||||||
|
result = 31 * result + field.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
infix fun <T : Any> Matrix<T>.dot(b: Matrix<T>): Matrix<T> = this.context.multiply(this, b)
|
infix fun <T : Any> Matrix<T>.dot(b: Matrix<T>): Matrix<T> = this.context.multiply(this, b)
|
||||||
@ -90,8 +110,38 @@ interface Matrix<T : Any> : SpaceElement<Matrix<T>, MatrixSpace<T>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
fun <T : Any> one(rows: Int, columns: Int, field: Field<T>): Matrix<T> {
|
|
||||||
return matrix(rows, columns, field) { i, j -> if (i == j) field.one else field.zero }
|
/**
|
||||||
|
* Create [ArrayMatrix] with custom field
|
||||||
|
*/
|
||||||
|
fun <T : Any> of(rows: Int, columns: Int, field: Field<T>, initializer: (Int, Int) -> T) =
|
||||||
|
ArrayMatrix(ArrayMatrixSpace(rows, columns, field), initializer)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create [ArrayMatrix] of doubles. The implementation in general should be faster than generic one due to boxing.
|
||||||
|
*/
|
||||||
|
fun ofReal(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
||||||
|
ArrayMatrix(ArrayMatrixSpace(rows, columns, DoubleField, realNDFieldFactory), initializer)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a diagonal value matrix. By default value equals [Field.one].
|
||||||
|
*/
|
||||||
|
fun <T : Any> diagonal(rows: Int, columns: Int, field: Field<T>, values: (Int) -> T = { field.one }): Matrix<T> {
|
||||||
|
return of(rows, columns, field) { i, j -> if (i == j) values(i) else field.zero }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Equality check on two generic matrices
|
||||||
|
*/
|
||||||
|
fun equals(mat1: Matrix<*>, mat2: Matrix<*>): Boolean {
|
||||||
|
if (mat1 === mat2) return true
|
||||||
|
if (mat1.context != mat2.context) return false
|
||||||
|
for (i in 0 until mat1.rows) {
|
||||||
|
for (j in 0 until mat2.columns) {
|
||||||
|
if (mat1[i, j] != mat2[i, j]) return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -117,6 +167,30 @@ interface Vector<T : Any> : SpaceElement<Vector<T>, VectorSpace<T>> {
|
|||||||
get() = context.size
|
get() = context.size
|
||||||
|
|
||||||
operator fun get(i: Int): T
|
operator fun get(i: Int): T
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
/**
|
||||||
|
* Create vector with custom field
|
||||||
|
*/
|
||||||
|
fun <T : Any> of(size: Int, field: Field<T>, initializer: (Int) -> T) =
|
||||||
|
ArrayVector(ArrayVectorSpace(size, field), initializer)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create vector of [Double]
|
||||||
|
*/
|
||||||
|
fun ofReal(size: Int, initializer: (Int) -> Double) =
|
||||||
|
ArrayVector(ArrayVectorSpace(size, DoubleField, realNDFieldFactory), initializer)
|
||||||
|
|
||||||
|
|
||||||
|
fun equals(v1: Vector<*>, v2: Vector<*>): Boolean {
|
||||||
|
if (v1 === v2) return true
|
||||||
|
if (v1.context != v2.context) return false
|
||||||
|
for (i in 0 until v2.size) {
|
||||||
|
if (v1[i] != v2[i]) return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -195,40 +269,15 @@ class ArrayVector<T : Any> internal constructor(override val context: ArrayVecto
|
|||||||
interface LinearSolver<T : Any> {
|
interface LinearSolver<T : Any> {
|
||||||
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
|
fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T>
|
||||||
fun solve(a: Matrix<T>, b: Vector<T>): Vector<T> = solve(a, b.toMatrix()).toVector()
|
fun solve(a: Matrix<T>, b: Vector<T>): Vector<T> = solve(a, b.toMatrix()).toVector()
|
||||||
fun inverse(a: Matrix<T>): Matrix<T> = solve(a, Matrix.one(a.rows, a.columns, a.context.field))
|
fun inverse(a: Matrix<T>): Matrix<T> = solve(a, Matrix.diagonal(a.rows, a.columns, a.context.field))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Create vector with custom field
|
|
||||||
*/
|
|
||||||
fun <T : Any> vector(size: Int, field: Field<T>, initializer: (Int) -> T) =
|
|
||||||
ArrayVector(ArrayVectorSpace(size, field), initializer)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create vector of [Double]
|
|
||||||
*/
|
|
||||||
fun realVector(size: Int, initializer: (Int) -> Double) =
|
|
||||||
ArrayVector(ArrayVectorSpace(size, DoubleField, realNDFieldFactory), initializer)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert vector to array (copying content of array)
|
* Convert vector to array (copying content of array)
|
||||||
*/
|
*/
|
||||||
fun <T : Any> Array<T>.asVector(field: Field<T>) = vector(size, field) { this[it] }
|
fun <T : Any> Array<T>.toVector(field: Field<T>) = Vector.of(size, field) { this[it] }
|
||||||
|
|
||||||
fun DoubleArray.asVector() = realVector(this.size) { this[it] }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create [ArrayMatrix] with custom field
|
|
||||||
*/
|
|
||||||
fun <T : Any> matrix(rows: Int, columns: Int, field: Field<T>, initializer: (Int, Int) -> T) =
|
|
||||||
ArrayMatrix(ArrayMatrixSpace(rows, columns, field), initializer)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create [ArrayMatrix] of doubles.
|
|
||||||
*/
|
|
||||||
fun realMatrix(rows: Int, columns: Int, initializer: (Int, Int) -> Double) =
|
|
||||||
ArrayMatrix(ArrayMatrixSpace(rows, columns, DoubleField, realNDFieldFactory), initializer)
|
|
||||||
|
|
||||||
|
fun DoubleArray.toVector() = Vector.ofReal(this.size) { this[it] }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert matrix to vector if it is possible
|
* Convert matrix to vector if it is possible
|
||||||
@ -243,7 +292,7 @@ fun <T : Any> Matrix<T>.toVector(): Vector<T> {
|
|||||||
// //Generic vector
|
// //Generic vector
|
||||||
// vector(rows, context.field) { get(it, 0) }
|
// vector(rows, context.field) { get(it, 0) }
|
||||||
// }
|
// }
|
||||||
vector(rows, context.field) { get(it, 0) }
|
Vector.of(rows, context.field) { get(it, 0) }
|
||||||
}
|
}
|
||||||
else -> error("Can't convert matrix with more than one column to vector")
|
else -> error("Can't convert matrix with more than one column to vector")
|
||||||
}
|
}
|
||||||
@ -257,6 +306,6 @@ fun <T : Any> Vector<T>.toMatrix(): Matrix<T> {
|
|||||||
// //Generic vector
|
// //Generic vector
|
||||||
// matrix(size, 1, context.field) { i, j -> get(i) }
|
// matrix(size, 1, context.field) { i, j -> get(i) }
|
||||||
// }
|
// }
|
||||||
return matrix(size, 1, context.field) { i, j -> get(i) }
|
return Matrix.of(size, 1, context.field) { i, j -> get(i) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,13 +2,20 @@ package scientifik.kmath.linear
|
|||||||
|
|
||||||
import scientifik.kmath.operations.DoubleField
|
import scientifik.kmath.operations.DoubleField
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
class RealLUSolverTest {
|
class RealLUSolverTest {
|
||||||
@Test
|
@Test
|
||||||
fun testInvert() {
|
fun testInvertOne() {
|
||||||
val matrix = Matrix.one(2, 2, DoubleField)
|
val matrix = Matrix.diagonal(2, 2, DoubleField)
|
||||||
val inverted = RealLUSolver.inverse(matrix)
|
val inverted = RealLUSolver.inverse(matrix)
|
||||||
assertEquals(1.0, inverted[0, 0])
|
assertTrue { Matrix.equals(matrix,inverted) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// @Test
|
||||||
|
// fun testInvert() {
|
||||||
|
// val matrix = realMatrix(2,2){}
|
||||||
|
// val inverted = RealLUSolver.inverse(matrix)
|
||||||
|
// assertTrue { Matrix.equals(matrix,inverted) }
|
||||||
|
// }
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user