package scientifik.kmath.linear

import koma.extensions.fill
import koma.matrix.MatrixFactory

/**
 * [MatrixContext] and [LinearSolver] implementation backed by koma.
 *
 * Every operation first converts its operands with [toKoma] and then delegates the arithmetic to the
 * corresponding koma operator. Operands that are already koma-backed are used as-is; any other [Matrix] or
 * [Point] is copied element by element into a freshly allocated koma matrix.
 *
 * @param factory koma factory used to allocate new matrices (via `zeros`) for [produce] and for conversions.
 */
class KomaMatrixContext<T : Any>(val factory: MatrixFactory<koma.matrix.Matrix<T>>) :
    MatrixContext<T>,
    LinearSolver<T> {

    /**
     * Allocates a `rows x columns` koma matrix with `factory.zeros` and sets element `(i, j)` to
     * `initializer(i, j)`.
     */
    override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): KomaMatrix<T> =
        KomaMatrix(factory.zeros(rows, columns).fill(initializer))

    /**
     * Returns this matrix itself when it is already a [KomaMatrix]; otherwise returns a koma-backed copy
     * built element by element through [produce].
     */
    fun Matrix<T>.toKoma(): KomaMatrix<T> =
        if (this is KomaMatrix) this else produce(rowNum, colNum) { i, j -> get(i, j) }

    /**
     * Returns this point itself when it is already a [KomaVector]; otherwise copies its elements into a new
     * `size x 1` koma column matrix.
     */
    fun Point<T>.toKoma(): KomaVector<T> =
        if (this is KomaVector) this else KomaVector(factory.zeros(size, 1).fill { i, _ -> get(i) })

    /** Matrix-matrix product, delegated to koma's `*` on the underlying matrices. */
    override fun Matrix<T>.dot(other: Matrix<T>): KomaMatrix<T> =
        KomaMatrix(this.toKoma().origin * other.toKoma().origin)

    /**
     * Matrix-vector product, delegated to koma's `*`. The point is converted to a single-column matrix, and the
     * result is wrapped as a [KomaVector], which requires koma to return a single column.
     */
    override fun Matrix<T>.dot(vector: Point<T>): KomaVector<T> =
        KomaVector(this.toKoma().origin * vector.toKoma().origin)

    /** Element-wise negation, delegated to koma. */
    override fun Matrix<T>.unaryMinus(): KomaMatrix<T> =
        KomaMatrix(this.toKoma().origin.unaryMinus())

    /** Element-wise sum, delegated to koma's `+`. */
    override fun Matrix<T>.plus(b: Matrix<T>): KomaMatrix<T> =
        KomaMatrix(this.toKoma().origin + b.toKoma().origin)

    /** Element-wise difference, delegated to koma's `-`. */
    override fun Matrix<T>.minus(b: Matrix<T>): KomaMatrix<T> =
        KomaMatrix(this.toKoma().origin - b.toKoma().origin)

    /** Multiplies every element by the scalar [value], delegated to koma's `*`. */
    override fun Matrix<T>.times(value: T): KomaMatrix<T> =
        KomaMatrix(this.toKoma().origin * value)

    /**
     * Delegates to koma's `Matrix.solve(b)` on [a].
     * NOTE(review): expected to return x with a · x = b per the [LinearSolver] contract; confirm against koma.
     */
    override fun solve(a: Matrix<T>, b: Matrix<T>): KomaMatrix<T> =
        KomaMatrix(a.toKoma().origin.solve(b.toKoma().origin))

    /** Delegates to koma's `inv()` on [a]. */
    override fun inverse(a: Matrix<T>): KomaMatrix<T> =
        KomaMatrix(a.toKoma().origin.inv())
}

/**
 * [Matrix] view over a koma matrix. Elements are read directly from [origin]; nothing is copied.
 *
 * @param origin the wrapped koma matrix.
 * @param features explicit feature set. When `null`, a default set is used, containing a [DeterminantFeature]
 *   and an [LUPDecompositionFeature], both computed from [origin].
 */
class KomaMatrix<T : Any>(
    val origin: koma.matrix.Matrix<T>,
    features: Set<MatrixFeature>? = null,
) : Matrix<T> {
    override val rowNum: Int get() = origin.numRows()
    override val colNum: Int get() = origin.numCols()

    // NOTE(review): the default features are attached regardless of shape; accessing them on a non-square
    // matrix presumably fails inside koma — confirm.
    override val features: Set<MatrixFeature> = features ?: setOf(
        object : DeterminantFeature<T> {
            // A getter, so koma recomputes the determinant on every access.
            override val determinant: T get() = origin.det()
        },
        object : LUPDecompositionFeature<T> {
            // koma's LU() runs at most once per feature object. Its triple is mapped here as
            // first -> p, second -> l, third -> u.
            private val lup by lazy { origin.LU() }
            override val l: Matrix<T> get() = KomaMatrix(lup.second)
            override val u: Matrix<T> get() = KomaMatrix(lup.third)
            override val p: Matrix<T> get() = KomaMatrix(lup.first)
        },
    )

    /**
     * Returns a new [KomaMatrix] that shares [origin] and whose features are this matrix's features plus
     * [features]. The feature set is passed explicitly, so no new default feature objects are created.
     */
    override fun suggestFeature(vararg features: MatrixFeature): Matrix<T> =
        KomaMatrix(this.origin, this.features + features)

    override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)
}

/**
 * [Point] view over a single-column koma matrix.
 *
 * @param origin the wrapped koma matrix; it must have exactly one column.
 * @throws IllegalStateException if [origin] does not have exactly one column.
 */
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {
    init {
        check(origin.numCols() == 1) { "Only single column matrices are allowed" }
    }

    override val size: Int get() = origin.numRows()

    // Uses koma's single-index accessor. With one column this presumably addresses row `index` (confirm).
    override fun get(index: Int): T = origin.getGeneric(index)

    override fun iterator(): Iterator<T> = origin.toIterable().iterator()
}