diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt index e7345163f..68961a2db 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt @@ -55,6 +55,26 @@ abstract class MatrixSpace(val rows: Int, val columns: Int, val field: (0 until a.columns).asSequence().map { k -> field.multiply(a[i, k], b[k, j]) }.reduce { first, second -> field.add(first, second) } } } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is MatrixSpace<*>) return false + + if (rows != other.rows) return false + if (columns != other.columns) return false + if (field != other.field) return false + + return true + } + + override fun hashCode(): Int { + var result = rows + result = 31 * result + columns + result = 31 * result + field.hashCode() + return result + } + + } infix fun Matrix.dot(b: Matrix): Matrix = this.context.multiply(this, b) @@ -90,8 +110,38 @@ interface Matrix : SpaceElement, MatrixSpace> { } companion object { - fun one(rows: Int, columns: Int, field: Field): Matrix { - return matrix(rows, columns, field) { i, j -> if (i == j) field.one else field.zero } + + /** + * Create [ArrayMatrix] with custom field + */ + fun of(rows: Int, columns: Int, field: Field, initializer: (Int, Int) -> T) = + ArrayMatrix(ArrayMatrixSpace(rows, columns, field), initializer) + + /** + * Create [ArrayMatrix] of doubles. The implementation in general should be faster than generic one due to boxing. + */ + fun ofReal(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = + ArrayMatrix(ArrayMatrixSpace(rows, columns, DoubleField, realNDFieldFactory), initializer) + + /** + * Create a diagonal value matrix. By default value equals [Field.one]. + */ + fun diagonal(rows: Int, columns: Int, field: Field, values: (Int) -> T = { field.one }): Matrix { + return of(rows, columns, field) { i, j -> if (i == j) values(i) else field.zero } + } + + /** + * Equality check on two generic matrices + */ + fun equals(mat1: Matrix<*>, mat2: Matrix<*>): Boolean { + if (mat1 === mat2) return true + if (mat1.context != mat2.context) return false + for (i in 0 until mat1.rows) { + for (j in 0 until mat2.columns) { + if (mat1[i, j] != mat2[i, j]) return false + } + } + return true } } } @@ -117,6 +167,30 @@ interface Vector : SpaceElement, VectorSpace> { get() = context.size operator fun get(i: Int): T + + companion object { + /** + * Create vector with custom field + */ + fun of(size: Int, field: Field, initializer: (Int) -> T) = + ArrayVector(ArrayVectorSpace(size, field), initializer) + + /** + * Create vector of [Double] + */ + fun ofReal(size: Int, initializer: (Int) -> Double) = + ArrayVector(ArrayVectorSpace(size, DoubleField, realNDFieldFactory), initializer) + + + fun equals(v1: Vector<*>, v2: Vector<*>): Boolean { + if (v1 === v2) return true + if (v1.context != v2.context) return false + for (i in 0 until v2.size) { + if (v1[i] != v2[i]) return false + } + return true + } + } } @@ -195,40 +269,15 @@ class ArrayVector internal constructor(override val context: ArrayVecto interface LinearSolver { fun solve(a: Matrix, b: Matrix): Matrix fun solve(a: Matrix, b: Vector): Vector = solve(a, b.toMatrix()).toVector() - fun inverse(a: Matrix): Matrix = solve(a, Matrix.one(a.rows, a.columns, a.context.field)) + fun inverse(a: Matrix): Matrix = solve(a, Matrix.diagonal(a.rows, a.columns, a.context.field)) } -/** - * Create vector with custom field - */ -fun vector(size: Int, field: Field, initializer: (Int) -> T) = - ArrayVector(ArrayVectorSpace(size, field), initializer) - -/** - * Create vector of [Double] - */ -fun realVector(size: Int, initializer: (Int) -> Double) = - ArrayVector(ArrayVectorSpace(size, DoubleField, realNDFieldFactory), initializer) - /** * Convert vector to array (copying content of array) */ -fun Array.asVector(field: Field) = vector(size, field) { this[it] } - -fun DoubleArray.asVector() = realVector(this.size) { this[it] } - -/** - * Create [ArrayMatrix] with custom field - */ -fun matrix(rows: Int, columns: Int, field: Field, initializer: (Int, Int) -> T) = - ArrayMatrix(ArrayMatrixSpace(rows, columns, field), initializer) - -/** - * Create [ArrayMatrix] of doubles. - */ -fun realMatrix(rows: Int, columns: Int, initializer: (Int, Int) -> Double) = - ArrayMatrix(ArrayMatrixSpace(rows, columns, DoubleField, realNDFieldFactory), initializer) +fun Array.toVector(field: Field) = Vector.of(size, field) { this[it] } +fun DoubleArray.toVector() = Vector.ofReal(this.size) { this[it] } /** * Convert matrix to vector if it is possible @@ -243,7 +292,7 @@ fun Matrix.toVector(): Vector { // //Generic vector // vector(rows, context.field) { get(it, 0) } // } - vector(rows, context.field) { get(it, 0) } + Vector.of(rows, context.field) { get(it, 0) } } else -> error("Can't convert matrix with more than one column to vector") } @@ -257,6 +306,6 @@ fun Vector.toMatrix(): Matrix { // //Generic vector // matrix(size, 1, context.field) { i, j -> get(i) } // } - return matrix(size, 1, context.field) { i, j -> get(i) } + return Matrix.of(size, 1, context.field) { i, j -> get(i) } } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt index d2237b80a..472d5262f 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/linear/RealLUSolverTest.kt @@ -2,13 +2,20 @@ package scientifik.kmath.linear import scientifik.kmath.operations.DoubleField import kotlin.test.Test -import kotlin.test.assertEquals +import kotlin.test.assertTrue class RealLUSolverTest { @Test - fun testInvert() { - val matrix = Matrix.one(2, 2, DoubleField) + fun testInvertOne() { + val matrix = Matrix.diagonal(2, 2, DoubleField) val inverted = RealLUSolver.inverse(matrix) - assertEquals(1.0, inverted[0, 0]) + assertTrue { Matrix.equals(matrix,inverted) } } + +// @Test +// fun testInvert() { +// val matrix = realMatrix(2,2){} +// val inverted = RealLUSolver.inverse(matrix) +// assertTrue { Matrix.equals(matrix,inverted) } +// } } \ No newline at end of file