diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt index c9b3f7117..16034e5f9 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt @@ -77,9 +77,9 @@ public object CMMatrixContext : MatrixContext { return CMMatrix(Array2DRowRealMatrix(array)) } - public fun Matrix.toCM(): CMMatrix = when { - this is CMMatrix -> this - this is MatrixWrapper && matrix is CMMatrix -> matrix as CMMatrix + @OptIn(UnstableKMathAPI::class) + public fun Matrix.toCM(): CMMatrix = when (val matrix = origin) { + is CMMatrix -> matrix else -> { //TODO add feature analysis val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt index bbe9c1195..362db1fe7 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt @@ -15,29 +15,37 @@ import kotlin.reflect.safeCast * * @param T the type of items. */ -public class MatrixWrapper( - public val matrix: Matrix, +public class MatrixWrapper internal constructor( + public val origin: Matrix, public val features: Set, -) : Matrix by matrix { +) : Matrix by origin { /** * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria */ @UnstableKMathAPI override fun getFeature(type: KClass): T? = type.safeCast(features.find { type.isInstance(it) }) + ?: origin.getFeature(type) - override fun equals(other: Any?): Boolean = matrix == other - override fun hashCode(): Int = matrix.hashCode() + override fun equals(other: Any?): Boolean = origin == other + override fun hashCode(): Int = origin.hashCode() override fun toString(): String { - return "MatrixWrapper(matrix=$matrix, features=$features)" + return "MatrixWrapper(matrix=$origin, features=$features)" } } +/** + * Return the original matrix. If this is a wrapper, return its origin. If not, this matrix. + * Origin does not necessary store all features. + */ +@UnstableKMathAPI +public val Matrix.origin: Matrix get() = (this as? MatrixWrapper)?.origin ?: this + /** * Add a single feature to a [Matrix] */ public operator fun Matrix.plus(newFeature: MatrixFeature): MatrixWrapper = if (this is MatrixWrapper) { - MatrixWrapper(matrix, features + newFeature) + MatrixWrapper(origin, features + newFeature) } else { MatrixWrapper(this, setOf(newFeature)) } @@ -47,7 +55,7 @@ public operator fun Matrix.plus(newFeature: MatrixFeature): MatrixW */ public operator fun Matrix.plus(newFeatures: Collection): MatrixWrapper = if (this is MatrixWrapper) { - MatrixWrapper(matrix, features + newFeatures) + MatrixWrapper(origin, features + newFeatures) } else { MatrixWrapper(this, newFeatures.toSet()) } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt index 2bafd377e..e7eb2770d 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt @@ -92,7 +92,7 @@ public interface Algebra { * Call a block with an [Algebra] as receiver. */ // TODO add contract when KT-32313 is fixed -public inline operator fun , R> A.invoke(block: A.() -> R): R = block() +public inline operator fun , R> A.invoke(block: A.() -> R): R = run(block) /** * Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt index 7198bbd0d..8184d0110 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt @@ -2,8 +2,8 @@ package kscience.kmath.ejml import kscience.kmath.linear.InverseMatrixFeature import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.MatrixWrapper import kscience.kmath.linear.Point +import kscience.kmath.linear.origin import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Matrix import kscience.kmath.structures.getFeature @@ -19,9 +19,9 @@ public object EjmlMatrixContext : MatrixContext { /** * Converts this matrix to EJML one. */ - public fun Matrix.toEjml(): EjmlMatrix = when { - this is EjmlMatrix -> this - this is MatrixWrapper && matrix is EjmlMatrix -> matrix as EjmlMatrix + @OptIn(UnstableKMathAPI::class) + public fun Matrix.toEjml(): EjmlMatrix = when (val matrix = origin) { + is EjmlMatrix -> matrix else -> produce(rowNum, colNum) { i, j -> get(i, j) } } diff --git a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt index 30d146779..455b52d9d 100644 --- a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt +++ b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt @@ -4,6 +4,7 @@ import kscience.kmath.linear.DeterminantFeature import kscience.kmath.linear.LupDecompositionFeature import kscience.kmath.linear.MatrixFeature import kscience.kmath.linear.plus +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.getFeature import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.simple.SimpleMatrix @@ -39,6 +40,7 @@ internal class EjmlMatrixTest { assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList()) } + @OptIn(UnstableKMathAPI::class) @Test fun features() { val m = randomMatrix @@ -57,6 +59,7 @@ internal class EjmlMatrixTest { private object SomeFeature : MatrixFeature {} + @OptIn(UnstableKMathAPI::class) @Test fun suggestFeature() { assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature()) diff --git a/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt index c2304070f..4e29e6105 100644 --- a/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt @@ -62,6 +62,7 @@ class MCScopeTest { } + @OptIn(ObsoleteCoroutinesApi::class) fun compareResult(test: ATest) { val res1 = runBlocking(Dispatchers.Default) { test() } val res2 = runBlocking(newSingleThreadContext("test")) { test() }