diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt index 48b6e0ef1..850446afa 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt @@ -2,8 +2,8 @@ package kscience.kmath.commons.linear import kscience.kmath.linear.DiagonalFeature import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.MatrixWrapper import kscience.kmath.linear.Point +import kscience.kmath.linear.origin import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Matrix import org.apache.commons.math3.linear.* @@ -47,9 +47,9 @@ public object CMMatrixContext : MatrixContext { return CMMatrix(Array2DRowRealMatrix(array)) } - public fun Matrix.toCM(): CMMatrix = when { - this is CMMatrix -> this - this is MatrixWrapper && matrix is CMMatrix -> matrix as CMMatrix + @OptIn(UnstableKMathAPI::class) + public fun Matrix.toCM(): CMMatrix = when (val matrix = origin) { + is CMMatrix -> matrix else -> { //TODO add feature analysis val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt index 1db905bf2..362db1fe7 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt @@ -15,30 +15,37 @@ import kotlin.reflect.safeCast * * @param T the type of items. */ -public class MatrixWrapper( - public val matrix: Matrix, +public class MatrixWrapper internal constructor( + public val origin: Matrix, public val features: Set, -) : Matrix by matrix { +) : Matrix by origin { /** * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria */ @UnstableKMathAPI override fun getFeature(type: KClass): T? = type.safeCast(features.find { type.isInstance(it) }) - ?: matrix.getFeature(type) + ?: origin.getFeature(type) - override fun equals(other: Any?): Boolean = matrix == other - override fun hashCode(): Int = matrix.hashCode() + override fun equals(other: Any?): Boolean = origin == other + override fun hashCode(): Int = origin.hashCode() override fun toString(): String { - return "MatrixWrapper(matrix=$matrix, features=$features)" + return "MatrixWrapper(matrix=$origin, features=$features)" } } +/** + * Return the original matrix. If this is a wrapper, return its origin. If not, this matrix. + * Origin does not necessary store all features. + */ +@UnstableKMathAPI +public val Matrix.origin: Matrix get() = (this as? MatrixWrapper)?.origin ?: this + /** * Add a single feature to a [Matrix] */ public operator fun Matrix.plus(newFeature: MatrixFeature): MatrixWrapper = if (this is MatrixWrapper) { - MatrixWrapper(matrix, features + newFeature) + MatrixWrapper(origin, features + newFeature) } else { MatrixWrapper(this, setOf(newFeature)) } @@ -48,7 +55,7 @@ public operator fun Matrix.plus(newFeature: MatrixFeature): MatrixW */ public operator fun Matrix.plus(newFeatures: Collection): MatrixWrapper = if (this is MatrixWrapper) { - MatrixWrapper(matrix, features + newFeatures) + MatrixWrapper(origin, features + newFeatures) } else { MatrixWrapper(this, newFeatures.toSet()) } diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt index 7198bbd0d..8184d0110 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt @@ -2,8 +2,8 @@ package kscience.kmath.ejml import kscience.kmath.linear.InverseMatrixFeature import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.MatrixWrapper import kscience.kmath.linear.Point +import kscience.kmath.linear.origin import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Matrix import kscience.kmath.structures.getFeature @@ -19,9 +19,9 @@ public object EjmlMatrixContext : MatrixContext { /** * Converts this matrix to EJML one. */ - public fun Matrix.toEjml(): EjmlMatrix = when { - this is EjmlMatrix -> this - this is MatrixWrapper && matrix is EjmlMatrix -> matrix as EjmlMatrix + @OptIn(UnstableKMathAPI::class) + public fun Matrix.toEjml(): EjmlMatrix = when (val matrix = origin) { + is EjmlMatrix -> matrix else -> produce(rowNum, colNum) { i, j -> get(i, j) } }