2019-01-21 10:22:37 +03:00
|
|
|
package scientifik.kmath.linear
|
|
|
|
|
2019-01-21 15:05:44 +03:00
|
|
|
import koma.extensions.fill
|
|
|
|
import koma.matrix.MatrixFactory
|
2019-04-15 20:50:50 +03:00
|
|
|
import scientifik.kmath.operations.Space
|
2019-04-07 12:27:49 +03:00
|
|
|
import scientifik.kmath.structures.Matrix
|
2020-05-06 10:50:08 +03:00
|
|
|
import scientifik.kmath.structures.NDStructure
|
2019-01-21 10:22:37 +03:00
|
|
|
|
2019-04-15 20:50:50 +03:00
|
|
|
/**
 * [MatrixContext] whose matrix operations are carried out by the koma library.
 *
 * Operands that are not already [KomaMatrix] / [KomaVector] instances are first copied element by element into
 * koma structures allocated with [factory]. Results are returned wrapped in [KomaMatrix] or [KomaVector].
 *
 * @param factory koma factory used to allocate zero-filled matrices.
 * @param space element algebra, used here only to scale elements by a [Number] in [multiply].
 */
class KomaMatrixContext<T : Any>(
    private val factory: MatrixFactory<koma.matrix.Matrix<T>>,
    private val space: Space<T>
) : MatrixContext<T> {

    /** Builds a `rows x columns` koma-backed matrix whose cell (i, j) holds `initializer(i, j)`. */
    override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): KomaMatrix<T> {
        val target = factory.zeros(rows, columns)
        return KomaMatrix(target.fill(initializer))
    }

    /** Returns this matrix as-is when it is already a [KomaMatrix], otherwise an element-wise koma copy. */
    fun Matrix<T>.toKoma(): KomaMatrix<T> = when (this) {
        is KomaMatrix -> this
        else -> produce(rowNum, colNum) { row, col -> this[row, col] }
    }

    /** Returns this point as-is when it is already a [KomaVector], otherwise a single-column koma copy. */
    fun Point<T>.toKoma(): KomaVector<T> = when (this) {
        is KomaVector -> this
        else -> {
            val column = factory.zeros(size, 1)
            KomaVector(column.fill { row, _ -> get(row) })
        }
    }

    /** Matrix product, computed by koma's `*` on the underlying matrices. */
    override fun Matrix<T>.dot(other: Matrix<T>): KomaMatrix<T> {
        val left = toKoma().origin
        val right = other.toKoma().origin
        return KomaMatrix(left * right)
    }

    /** Matrix-vector product; the vector is handled as a single-column koma matrix. */
    override fun Matrix<T>.dot(vector: Point<T>): KomaVector<T> {
        val left = toKoma().origin
        val right = vector.toKoma().origin
        return KomaVector(left * right)
    }

    /** Element-wise negation, delegated to koma. */
    override fun Matrix<T>.unaryMinus(): KomaMatrix<T> {
        val source = toKoma().origin
        return KomaMatrix(source.unaryMinus())
    }

    /** Element-wise sum, delegated to koma. */
    override fun add(a: Matrix<T>, b: Matrix<T>): KomaMatrix<T> {
        val left = a.toKoma().origin
        val right = b.toKoma().origin
        return KomaMatrix(left + right)
    }

    /** Element-wise difference, delegated to koma. */
    override fun Matrix<T>.minus(b: Matrix<T>): KomaMatrix<T> {
        val left = toKoma().origin
        val right = b.toKoma().origin
        return KomaMatrix(left - right)
    }

    /** Scales every element of [a] by [k] using the element [space], producing a new koma-backed matrix. */
    override fun multiply(a: Matrix<T>, k: Number): Matrix<T> =
        produce(a.rowNum, a.colNum) { row, col -> with(space) { a[row, col] * k } }

    /** Multiplies every element by the scalar [value], delegated to koma. */
    override fun Matrix<T>.times(value: T): KomaMatrix<T> {
        val source = toKoma().origin
        return KomaMatrix(source * value)
    }

    companion object
}
|
|
|
|
|
2019-04-15 20:50:50 +03:00
|
|
|
/**
 * Solves the linear system `a * x = b` for the matrix `x` using koma's `solve`.
 *
 * Both operands are converted to koma-backed matrices first (copied if they are not already [KomaMatrix]).
 */
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Matrix<T>): KomaMatrix<T> {
    val lhs = a.toKoma().origin
    val rhs = b.toKoma().origin
    return KomaMatrix(lhs.solve(rhs))
}
|
|
|
|
|
|
|
|
/**
 * Solves the linear system `a * x = b` for the vector `x` using koma's `solve`.
 *
 * The right-hand side is handled as a single-column koma matrix, and the single-column koma result is wrapped
 * as a [KomaVector].
 */
fun <T : Any> KomaMatrixContext<T>.solve(a: Matrix<T>, b: Point<T>): KomaVector<T> {
    val lhs = a.toKoma().origin
    val rhs = b.toKoma().origin
    return KomaVector(lhs.solve(rhs))
}
|
|
|
|
|
|
|
|
/** Computes the inverse of [a] with koma's `inv()`, returning it wrapped as a [KomaMatrix]. */
fun <T : Any> KomaMatrixContext<T>.inverse(a: Matrix<T>): KomaMatrix<T> {
    val source = a.toKoma().origin
    return KomaMatrix(source.inv())
}
|
|
|
|
|
2019-04-07 12:27:49 +03:00
|
|
|
/**
 * [FeaturedMatrix] view over a koma matrix.
 *
 * When no explicit feature set is supplied, the matrix advertises two default features:
 * - a [DeterminantFeature] backed by koma's `det()`;
 * - an [LUPDecompositionFeature] whose factors come from koma's `LU()`, evaluated lazily and at most once per
 *   feature instance.
 *
 * @param origin the wrapped koma matrix. It is used directly, not copied.
 * @param features explicit feature set, or `null` to use the default features described above.
 */
class KomaMatrix<T : Any>(val origin: koma.matrix.Matrix<T>, features: Set<MatrixFeature>? = null) : FeaturedMatrix<T> {
    override val rowNum: Int get() = origin.numRows()
    override val colNum: Int get() = origin.numCols()

    override val shape: IntArray get() = intArrayOf(origin.numRows(), origin.numCols())

    override val features: Set<MatrixFeature> = features ?: setOf(
        object : DeterminantFeature<T> {
            override val determinant: T get() = origin.det()
        },
        object : LUPDecompositionFeature<T> {
            // The Triple from LU() is unpacked below as (first = P, second = L, third = U).
            private val lup by lazy { origin.LU() }
            override val l: FeaturedMatrix<T> get() = KomaMatrix(lup.second)
            override val u: FeaturedMatrix<T> get() = KomaMatrix(lup.third)
            override val p: FeaturedMatrix<T> get() = KomaMatrix(lup.first)
        }
    )

    /** Returns a new view over the same [origin] carrying this matrix's features plus [features]. */
    override fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T> =
        KomaMatrix(this.origin, this.features + features)

    override fun get(i: Int, j: Int): T = origin.getGeneric(i, j)

    /**
     * Structural equality with any [NDStructure].
     *
     * Shapes must match and the element comparison is delegated to [NDStructure.equals]. Features are not
     * compared.
     */
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is NDStructure<*>) return false
        // Require equal shapes explicitly so equality is symmetric and agrees with hashCode below.
        return shape.contentEquals(other.shape) && NDStructure.equals(this, other)
    }

    /**
     * Hash code consistent with [equals]: derived only from the shape and the element values.
     *
     * Features are deliberately excluded. The default feature set consists of freshly created anonymous objects
     * with identity hash codes, so including them would give content-equal matrices different hashes.
     */
    override fun hashCode(): Int {
        var result = shape.contentHashCode()
        for (i in 0 until rowNum) {
            for (j in 0 until colNum) {
                result = 31 * result + get(i, j).hashCode()
            }
        }
        return result
    }
}
|
|
|
|
|
|
|
|
/**
 * [Point] backed by a single-column koma matrix.
 *
 * @param origin koma matrix that must have exactly one column. It is used directly, not copied.
 * @throws IllegalStateException if [origin] does not have exactly one column.
 */
class KomaVector<T : Any> internal constructor(val origin: koma.matrix.Matrix<T>) : Point<T> {
    init {
        check(origin.numCols() == 1) { "Only single column matrices are allowed" }
    }

    /** Number of entries, i.e. the row count of [origin]. */
    override val size: Int
        get() = origin.numRows()

    /** Reads entry [index] through koma's single-index `getGeneric` accessor. */
    override fun get(index: Int): T = origin.getGeneric(index)

    /** Iterates the entries in the order koma's `toIterable()` yields them. */
    override fun iterator(): Iterator<T> {
        val entries = origin.toIterable()
        return entries.iterator()
    }
}
|
2019-01-21 10:22:37 +03:00
|
|
|
|