package scientifik.kmath.linear

import koma.extensions.fill
import koma.matrix.MatrixFactory
import scientifik.kmath.operations.Space
import scientifik.kmath.operations.invoke
import scientifik.kmath.structures.Matrix
import scientifik.kmath.structures.NDStructure

/**
 * [MatrixContext] backed by the koma linear-algebra library.
 *
 * Every operation converts its operands with [toKoma] and wraps koma's result back into a [KomaMatrix]
 * or [KomaVector]. [toKoma] copies element by element when an operand is not already koma-backed.
 *
 * @param factory koma factory used to allocate new zero-filled koma matrices.
 * @param space element algebra; used by [multiply] to scale each element by a [Number].
 */
class KomaMatrixContext<T : Any>(
    private val factory: MatrixFactory<koma.matrix.Matrix<T>>,
    private val space: Space<T>
) : MatrixContext<T> {

    /** Allocates a `rows`×`columns` koma matrix via [factory] and fills every cell from [initializer]. */
    override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): KomaMatrix<T> =
        KomaMatrix(factory.zeros(rows, columns).fill(initializer))

    /** Returns this matrix unchanged if it is already a [KomaMatrix], otherwise a koma-backed copy of it. */
    fun Matrix<T>.toKoma(): KomaMatrix<T> = if (this is KomaMatrix) {
        this
    } else {
        produce(rowNum, colNum) { i, j -> get(i, j) }
    }

    /** Returns this point unchanged if it is already a [KomaVector], otherwise a single-column koma copy of it. */
    fun Point<T>.toKoma(): KomaVector<T> = if (this is KomaVector) {
        this
    } else {
        KomaVector(factory.zeros(size, 1).fill { i, _ -> get(i) })
    }

    /** Matrix product computed by koma's `*` operator. */
    override fun Matrix<T>.dot(other: Matrix<T>): KomaMatrix<T> =
        KomaMatrix(toKoma().origin * other.toKoma().origin)

    /** Matrix-vector product; the vector is treated as a single-column koma matrix. */
    override fun Matrix<T>.dot(vector: Point<T>): KomaVector<T> =
        KomaVector(toKoma().origin * vector.toKoma().origin)

    /** Element-wise negation delegated to koma. */
    override operator fun Matrix<T>.unaryMinus(): KomaMatrix<T> =
        KomaMatrix(toKoma().origin.unaryMinus())

    /** Element-wise sum delegated to koma's `+`. */
    override fun add(a: Matrix<T>, b: Matrix<T>): KomaMatrix<T> =
        KomaMatrix(a.toKoma().origin + b.toKoma().origin)

    /** Element-wise difference delegated to koma's `-`. */
    override operator fun Matrix<T>.minus(b: Matrix<T>): KomaMatrix<T> =
        KomaMatrix(toKoma().origin - b.toKoma().origin)

    /** Scales each element of [a] by [k] using the element [space] (not koma), producing a new koma matrix. */
    override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
        produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } }

    /** Multiplies by a scalar of the element type using koma's `*`. */
    override operator fun Matrix<T>.times(value: T): KomaMatrix<T> =
        KomaMatrix(toKoma().origin * value)

    companion object
}

/**
 * Delegates to koma's `Matrix.solve` with [a] as receiver and [b] as argument.
 * Presumably solves `a · x = b` for `x` — see koma documentation.
 */
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Matrix<T>): KomaMatrix<T> =
    KomaMatrix(a.toKoma().origin.solve(b.toKoma().origin))

/**
 * Vector right-hand-side variant of [solve]; [b] is converted to a single-column koma matrix first.
 */
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Point<T>): KomaVector<T> =
    KomaVector(a.toKoma().origin.solve(b.toKoma().origin))

/** Matrix inverse delegated to koma's `inv()`. */
fun <T : Any> KomaMatrixContext<T>.inverse(a: Matrix<T>): KomaMatrix<T> =
    KomaMatrix(a.toKoma().origin.inv())

/**
 * [FeaturedMatrix] view over the koma matrix [origin]. Element reads go straight to [origin] (no copy).
 *
 * @param features explicit feature set. When `null`, a default set is attached with two features:
 *   a [DeterminantFeature] that calls `origin.det()` on every access, and an
 *   [LUPDecompositionFeature] whose decomposition is computed once, lazily, via `origin.LU()`.
 */
class KomaMatrix<T : Any>(
    val origin: koma.matrix.Matrix<T>,
    features: Set<MatrixFeature>? = null
) : FeaturedMatrix<T> {
    override val rowNum: Int get() = origin.numRows()
    override val colNum: Int get() = origin.numCols()

    override val shape: IntArray get() = intArrayOf(origin.numRows(), origin.numCols())

    override val features: Set<MatrixFeature> = features ?: hashSetOf<MatrixFeature>(
        object : DeterminantFeature<T> {
            override val determinant: T get() = origin.det()
        },
        object : LUPDecompositionFeature<T> {
            // NOTE(review): positions are mapped as (first, second, third) = (p, l, u);
            // this assumes koma's LU() returns the triple in that order — confirm against koma docs.
            private val lup by lazy { origin.LU() }
            override val l: FeaturedMatrix<T> get() = KomaMatrix(lup.second)
            override val u: FeaturedMatrix<T> get() = KomaMatrix(lup.third)
            override val p: FeaturedMatrix<T> get() = KomaMatrix(lup.first)
        }
    )

    /** Returns a new wrapper over the same [origin] with [features] added to the current set. */
    override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> =
        KomaMatrix(this.origin, this.features + features)

    override operator fun get(i: Int, j: Int): T = origin.getGeneric(i, j)

    /** Structural equality delegated to [NDStructure.equals]; non-[NDStructure] values are never equal. */
    override fun equals(other: Any?): Boolean {
        return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
    }

    /**
     * Hash over the shape and every element, in row-major order.
     *
     * It must not depend on [features]: the default feature objects are anonymous and hash by identity,
     * so two equal matrices would otherwise hash differently. Costs O(rowNum · colNum) per call.
     */
    override fun hashCode(): Int {
        var result = 31 * rowNum + colNum
        for (i in 0 until rowNum) {
            for (j in 0 until colNum) {
                result = 31 * result + get(i, j).hashCode()
            }
        }
        return result
    }
}

/**
 * [Point] view over the single-column koma matrix [origin]; element `index` is row `index`.
 *
 * @throws IllegalArgumentException if [origin] does not have exactly one column.
 */
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {
    override val size: Int get() = origin.numRows()

    init {
        // The lazy-message lambda must *return* the message. Calling error() inside it would throw
        // IllegalStateException from within the lambda instead of require's IllegalArgumentException.
        require(origin.numCols() == 1) { "Only single column matrices are allowed" }
    }

    override operator fun get(index: Int): T = origin.getGeneric(index)

    override operator fun iterator(): Iterator<T> = origin.toIterable().iterator()
}