Merge branch 'dev' into commandertvis/cm-decompositions

# Conflicts:
#	kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt
This commit is contained in:
Iaroslav Postovalov 2021-01-21 18:43:03 +07:00
commit a3cf13b678
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
6 changed files with 28 additions and 16 deletions

View File

@ -77,9 +77,9 @@ public object CMMatrixContext : MatrixContext<Double, CMMatrix> {
return CMMatrix(Array2DRowRealMatrix(array)) return CMMatrix(Array2DRowRealMatrix(array))
} }
public fun Matrix<Double>.toCM(): CMMatrix = when { @OptIn(UnstableKMathAPI::class)
this is CMMatrix -> this public fun Matrix<Double>.toCM(): CMMatrix = when (val matrix = origin) {
this is MatrixWrapper && matrix is CMMatrix -> matrix as CMMatrix is CMMatrix -> matrix
else -> { else -> {
//TODO add feature analysis //TODO add feature analysis
val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } } val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } }

View File

@ -15,29 +15,37 @@ import kotlin.reflect.safeCast
* *
* @param T the type of items. * @param T the type of items.
*/ */
public class MatrixWrapper<T : Any>( public class MatrixWrapper<T : Any> internal constructor(
public val matrix: Matrix<T>, public val origin: Matrix<T>,
public val features: Set<MatrixFeature>, public val features: Set<MatrixFeature>,
) : Matrix<T> by matrix { ) : Matrix<T> by origin {
/** /**
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
*/ */
@UnstableKMathAPI @UnstableKMathAPI
override fun <T : Any> getFeature(type: KClass<T>): T? = type.safeCast(features.find { type.isInstance(it) }) override fun <T : Any> getFeature(type: KClass<T>): T? = type.safeCast(features.find { type.isInstance(it) })
?: origin.getFeature(type)
override fun equals(other: Any?): Boolean = matrix == other override fun equals(other: Any?): Boolean = origin == other
override fun hashCode(): Int = matrix.hashCode() override fun hashCode(): Int = origin.hashCode()
override fun toString(): String { override fun toString(): String {
return "MatrixWrapper(matrix=$matrix, features=$features)" return "MatrixWrapper(matrix=$origin, features=$features)"
} }
} }
/**
* Return the original matrix. If this is a wrapper, return its origin. If not, this matrix.
* Origin does not necessary store all features.
*/
@UnstableKMathAPI
public val <T : Any> Matrix<T>.origin: Matrix<T> get() = (this as? MatrixWrapper)?.origin ?: this
/** /**
* Add a single feature to a [Matrix] * Add a single feature to a [Matrix]
*/ */
public operator fun <T : Any> Matrix<T>.plus(newFeature: MatrixFeature): MatrixWrapper<T> = if (this is MatrixWrapper) { public operator fun <T : Any> Matrix<T>.plus(newFeature: MatrixFeature): MatrixWrapper<T> = if (this is MatrixWrapper) {
MatrixWrapper(matrix, features + newFeature) MatrixWrapper(origin, features + newFeature)
} else { } else {
MatrixWrapper(this, setOf(newFeature)) MatrixWrapper(this, setOf(newFeature))
} }
@ -47,7 +55,7 @@ public operator fun <T : Any> Matrix<T>.plus(newFeature: MatrixFeature): MatrixW
*/ */
public operator fun <T : Any> Matrix<T>.plus(newFeatures: Collection<MatrixFeature>): MatrixWrapper<T> = public operator fun <T : Any> Matrix<T>.plus(newFeatures: Collection<MatrixFeature>): MatrixWrapper<T> =
if (this is MatrixWrapper) { if (this is MatrixWrapper) {
MatrixWrapper(matrix, features + newFeatures) MatrixWrapper(origin, features + newFeatures)
} else { } else {
MatrixWrapper(this, newFeatures.toSet()) MatrixWrapper(this, newFeatures.toSet())
} }

View File

@ -92,7 +92,7 @@ public interface Algebra<T> {
* Call a block with an [Algebra] as receiver. * Call a block with an [Algebra] as receiver.
*/ */
// TODO add contract when KT-32313 is fixed // TODO add contract when KT-32313 is fixed
public inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = block() public inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = run(block)
/** /**
* Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as * Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as

View File

@ -2,8 +2,8 @@ package kscience.kmath.ejml
import kscience.kmath.linear.InverseMatrixFeature import kscience.kmath.linear.InverseMatrixFeature
import kscience.kmath.linear.MatrixContext import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.MatrixWrapper
import kscience.kmath.linear.Point import kscience.kmath.linear.Point
import kscience.kmath.linear.origin
import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import kscience.kmath.structures.getFeature import kscience.kmath.structures.getFeature
@ -19,9 +19,9 @@ public object EjmlMatrixContext : MatrixContext<Double, EjmlMatrix> {
/** /**
* Converts this matrix to EJML one. * Converts this matrix to EJML one.
*/ */
public fun Matrix<Double>.toEjml(): EjmlMatrix = when { @OptIn(UnstableKMathAPI::class)
this is EjmlMatrix -> this public fun Matrix<Double>.toEjml(): EjmlMatrix = when (val matrix = origin) {
this is MatrixWrapper && matrix is EjmlMatrix -> matrix as EjmlMatrix is EjmlMatrix -> matrix
else -> produce(rowNum, colNum) { i, j -> get(i, j) } else -> produce(rowNum, colNum) { i, j -> get(i, j) }
} }

View File

@ -4,6 +4,7 @@ import kscience.kmath.linear.DeterminantFeature
import kscience.kmath.linear.LupDecompositionFeature import kscience.kmath.linear.LupDecompositionFeature
import kscience.kmath.linear.MatrixFeature import kscience.kmath.linear.MatrixFeature
import kscience.kmath.linear.plus import kscience.kmath.linear.plus
import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.getFeature import kscience.kmath.structures.getFeature
import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.dense.row.factory.DecompositionFactory_DDRM
import org.ejml.simple.SimpleMatrix import org.ejml.simple.SimpleMatrix
@ -39,6 +40,7 @@ internal class EjmlMatrixTest {
assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList()) assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList())
} }
@OptIn(UnstableKMathAPI::class)
@Test @Test
fun features() { fun features() {
val m = randomMatrix val m = randomMatrix
@ -57,6 +59,7 @@ internal class EjmlMatrixTest {
private object SomeFeature : MatrixFeature {} private object SomeFeature : MatrixFeature {}
@OptIn(UnstableKMathAPI::class)
@Test @Test
fun suggestFeature() { fun suggestFeature() {
assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature<SomeFeature>()) assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature<SomeFeature>())

View File

@ -62,6 +62,7 @@ class MCScopeTest {
} }
@OptIn(ObsoleteCoroutinesApi::class)
fun compareResult(test: ATest) { fun compareResult(test: ATest) {
val res1 = runBlocking(Dispatchers.Default) { test() } val res1 = runBlocking(Dispatchers.Default) { test() }
val res2 = runBlocking(newSingleThreadContext("test")) { test() } val res2 = runBlocking(newSingleThreadContext("test")) { test() }