From d0c9d97706fa21584af1e99984a2763810bcd4f9 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 19 Jan 2021 22:24:42 +0300 Subject: [PATCH 1/5] Minor optimization for RealNDAlgebra --- .../ast/ExpressionsInterpretersBenchmark.kt | 4 +- .../kscience/kmath/benchmarks/DotBenchmark.kt | 20 ++++----- .../kmath/benchmarks/LargeNDBenchmark.kt | 25 ----------- .../benchmarks/LinearAlgebraBenchmark.kt | 8 +--- .../FunctionalExpressionAlgebra.kt | 5 ++- .../kscience/kmath/linear/LupDecomposition.kt | 1 + .../kscience/kmath/operations/Complex.kt | 2 + .../kscience/kmath/structures/RealNDField.kt | 41 +++++++++---------- .../kscience/kmath/linear/MatrixTest.kt | 1 + .../kmath/structures/NumberNDFieldTest.kt | 1 + .../kscience/dimensions/DMatrixContextTest.kt | 1 + 11 files changed, 42 insertions(+), 67 deletions(-) delete mode 100644 examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LargeNDBenchmark.kt diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt index 6acaca84d..c5edcdedf 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt @@ -28,7 +28,7 @@ internal class ExpressionsInterpretersBenchmark { @Benchmark fun mstExpression() { val expr = algebra.mstInField { - symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + symbol("x") * 2.0 + 2.0 / symbol("x") - 16.0 } invokeAndSum(expr) @@ -37,7 +37,7 @@ internal class ExpressionsInterpretersBenchmark { @Benchmark fun asmExpression() { val expr = algebra.mstInField { - symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) + symbol("x") * 2.0 + 2.0 / symbol("x") - 16.0 }.compile() invokeAndSum(expr) diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt index 8823e86db..5c59afaee 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt @@ -2,9 +2,8 @@ package kscience.kmath.benchmarks import kotlinx.benchmark.Benchmark import kscience.kmath.commons.linear.CMMatrixContext -import kscience.kmath.commons.linear.toCM import kscience.kmath.ejml.EjmlMatrixContext -import kscience.kmath.ejml.toEjml + import kscience.kmath.linear.BufferMatrixContext import kscience.kmath.linear.RealMatrixContext import kscience.kmath.linear.real @@ -26,11 +25,11 @@ class DotBenchmark { val matrix1 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } val matrix2 = Matrix.real(dim, dim) { i, j -> if (i <= j) random.nextDouble() else 0.0 } - val cmMatrix1 = matrix1.toCM() - val cmMatrix2 = matrix2.toCM() + val cmMatrix1 = CMMatrixContext { matrix1.toCM() } + val cmMatrix2 = CMMatrixContext { matrix2.toCM() } - val ejmlMatrix1 = matrix1.toEjml() - val ejmlMatrix2 = matrix2.toEjml() + val ejmlMatrix1 = EjmlMatrixContext { matrix1.toEjml() } + val ejmlMatrix2 = EjmlMatrixContext { matrix2.toEjml() } } @Benchmark @@ -49,22 +48,23 @@ class DotBenchmark { @Benchmark fun ejmlMultiplicationwithConversion() { - val ejmlMatrix1 = matrix1.toEjml() - val ejmlMatrix2 = matrix2.toEjml() EjmlMatrixContext { + val ejmlMatrix1 = matrix1.toEjml() + val ejmlMatrix2 = matrix2.toEjml() + ejmlMatrix1 dot ejmlMatrix2 } } @Benchmark fun bufferedMultiplication() { - BufferMatrixContext(RealField, Buffer.Companion::real).invoke{ + BufferMatrixContext(RealField, Buffer.Companion::real).invoke { matrix1 dot matrix2 } } @Benchmark - fun realMultiplication(){ + fun realMultiplication() { RealMatrixContext { matrix1 dot matrix2 } diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LargeNDBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LargeNDBenchmark.kt deleted file mode 100644 index 395fde619..000000000 --- a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LargeNDBenchmark.kt +++ /dev/null @@ -1,25 +0,0 @@ -package kscience.kmath.benchmarks - -import kscience.kmath.structures.NDField -import org.openjdk.jmh.annotations.Benchmark -import org.openjdk.jmh.annotations.Scope -import org.openjdk.jmh.annotations.State -import org.openjdk.jmh.infra.Blackhole -import kotlin.random.Random - -@State(Scope.Benchmark) -class LargeNDBenchmark { - val arraySize = 10000 - val RANDOM = Random(222) - val src1 = DoubleArray(arraySize) { RANDOM.nextDouble() } - val src2 = DoubleArray(arraySize) { RANDOM.nextDouble() } - val field = NDField.real(arraySize) - val kmathArray1 = field.produce { (a) -> src1[a] } - val kmathArray2 = field.produce { (a) -> src2[a] } - - @Benchmark - fun test10000(bh: Blackhole) { - bh.consume(field.add(kmathArray1, kmathArray2)) - } - -} \ No newline at end of file diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LinearAlgebraBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LinearAlgebraBenchmark.kt index ec8714617..5ff43ef80 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LinearAlgebraBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LinearAlgebraBenchmark.kt @@ -5,10 +5,8 @@ import kotlinx.benchmark.Benchmark import kscience.kmath.commons.linear.CMMatrixContext import kscience.kmath.commons.linear.CMMatrixContext.dot import kscience.kmath.commons.linear.inverse -import kscience.kmath.commons.linear.toCM import kscience.kmath.ejml.EjmlMatrixContext import kscience.kmath.ejml.inverse -import kscience.kmath.ejml.toEjml import kscience.kmath.operations.invoke import kscience.kmath.structures.Matrix import org.openjdk.jmh.annotations.Scope @@ -35,16 +33,14 @@ class LinearAlgebraBenchmark { @Benchmark fun cmLUPInversion() { CMMatrixContext { - val cm = matrix.toCM() //avoid overhead on conversion - inverse(cm) + inverse(matrix) } } @Benchmark fun ejmlInverse() { EjmlMatrixContext { - val km = matrix.toEjml() //avoid overhead on conversion - inverse(km) + inverse(matrix) } } } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 880a4e34c..1a3668855 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -95,8 +95,9 @@ public open class FunctionalExpressionRing>( super.binaryOperationFunction(operation) } -public open class FunctionalExpressionField>(algebra: A) : - FunctionalExpressionRing(algebra), Field> { +public open class FunctionalExpressionField>( + algebra: A, +) : FunctionalExpressionRing(algebra), Field> { /** * Builds an Expression of division an expression by another one. */ diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt index f4f998da2..5cf7c8f70 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt @@ -224,6 +224,7 @@ public inline fun , F : Field> GenericMatrixContext ): Matrix = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) +@OptIn(UnstableKMathAPI::class) public fun RealMatrixContext.solveWithLUP(a: Matrix, b: Matrix): Matrix { // Use existing decomposition if it is provided by matrix val bufferFactory: MutableBufferFactory = MutableBuffer.Companion::real diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt index 5695e6696..c6409c015 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt @@ -3,6 +3,7 @@ package kscience.kmath.operations import kscience.kmath.memory.MemoryReader import kscience.kmath.memory.MemorySpec import kscience.kmath.memory.MemoryWriter +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Buffer import kscience.kmath.structures.MemoryBuffer import kscience.kmath.structures.MutableBuffer @@ -41,6 +42,7 @@ private val PI_DIV_2 = Complex(PI / 2, 0) /** * A field of [Complex]. */ +@OptIn(UnstableKMathAPI::class) public object ComplexField : ExtendedField, Norm, RingWithNumbers { override val zero: Complex = 0.0.toComplex() override val one: Complex = 1.0.toComplex() diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt index 3eb1dc4ca..60e6de440 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt @@ -11,7 +11,7 @@ public typealias RealNDElement = BufferedNDFieldElement public class RealNDField(override val shape: IntArray) : BufferedNDField, ExtendedNDField>, - RingWithNumbers>{ + RingWithNumbers> { override val strides: Strides = DefaultStrides(shape) @@ -24,35 +24,31 @@ public class RealNDField(override val shape: IntArray) : return produce { d } } - private inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer = - RealBuffer(DoubleArray(size) { initializer(it) }) - - /** - * Inline transform an NDStructure to - */ - override fun map( + @Suppress("OVERRIDE_BY_INLINE") + override inline fun map( arg: NDBuffer, - transform: RealField.(Double) -> Double + transform: RealField.(Double) -> Double, ): RealNDElement { check(arg) - val array = buildBuffer(arg.strides.linearSize) { offset -> RealField.transform(arg.buffer[offset]) } + val array = RealBuffer(arg.strides.linearSize) { offset -> RealField.transform(arg.buffer[offset]) } return BufferedNDFieldElement(this, array) } - override fun produce(initializer: RealField.(IntArray) -> Double): RealNDElement { - val array = buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) } + @Suppress("OVERRIDE_BY_INLINE") + override inline fun produce(initializer: RealField.(IntArray) -> Double): RealNDElement { + val array = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) } return BufferedNDFieldElement(this, array) } - override fun mapIndexed( + @Suppress("OVERRIDE_BY_INLINE") + override inline fun mapIndexed( arg: NDBuffer, - transform: RealField.(index: IntArray, Double) -> Double + transform: RealField.(index: IntArray, Double) -> Double, ): RealNDElement { check(arg) - return BufferedNDFieldElement( this, - buildBuffer(arg.strides.linearSize) { offset -> + RealBuffer(arg.strides.linearSize) { offset -> elementContext.transform( arg.strides.index(offset), arg.buffer[offset] @@ -60,16 +56,17 @@ public class RealNDField(override val shape: IntArray) : }) } - override fun combine( + @Suppress("OVERRIDE_BY_INLINE") + override inline fun combine( a: NDBuffer, b: NDBuffer, - transform: RealField.(Double, Double) -> Double + transform: RealField.(Double, Double) -> Double, ): RealNDElement { check(a, b) - return BufferedNDFieldElement( - this, - buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) } - ) + val buffer = RealBuffer(strides.linearSize) { offset -> + elementContext.transform(a.buffer[offset], b.buffer[offset]) + } + return BufferedNDFieldElement(this, buffer) } override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> = diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/MatrixTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/MatrixTest.kt index 0a582e339..d7755dcb5 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/MatrixTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/linear/MatrixTest.kt @@ -7,6 +7,7 @@ import kscience.kmath.structures.as2D import kotlin.test.Test import kotlin.test.assertEquals +@Suppress("UNUSED_VARIABLE") class MatrixTest { @Test fun testTranspose() { diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt index f5e008ef3..22a0d3629 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NumberNDFieldTest.kt @@ -8,6 +8,7 @@ import kotlin.math.pow import kotlin.test.Test import kotlin.test.assertEquals +@Suppress("UNUSED_VARIABLE") class NumberNDFieldTest { val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() } val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() } diff --git a/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt b/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt index 5b330fcce..b9193d4dd 100644 --- a/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt +++ b/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt @@ -6,6 +6,7 @@ import kscience.kmath.dimensions.DMatrixContext import kscience.kmath.dimensions.one import kotlin.test.Test +@Suppress("UNUSED_VARIABLE") internal class DMatrixContextTest { @Test fun testDimensionSafeMatrix() { From 1c7bd05c584528cf569aa5f7f537c056e89e4a6e Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 19 Jan 2021 22:48:43 +0300 Subject: [PATCH 2/5] Add proper equality check for EJML matrices --- .../kscience/kmath/linear/BufferMatrix.kt | 2 +- .../kscience/kmath/structures/NDStructure.kt | 4 ++-- .../kmath/structures/LazyNDStructure.kt | 2 +- .../kotlin/kscience/kmath/ejml/EjmlMatrix.kt | 19 +++++++++++++++---- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt index 80460baca..a74d948fc 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt @@ -43,7 +43,7 @@ public class BufferMatrix( if (this === other) return true return when (other) { - is NDStructure<*> -> NDStructure.equals(this, other) + is NDStructure<*> -> NDStructure.contentEquals(this, other) else -> false } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt index 64e723581..e7d89ca7e 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt @@ -54,7 +54,7 @@ public interface NDStructure { /** * Indicates whether some [NDStructure] is equal to another one. */ - public fun equals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean { + public fun contentEquals(st1: NDStructure<*>, st2: NDStructure<*>): Boolean { if (st1 === st2) return true // fast comparison of buffers if possible @@ -275,7 +275,7 @@ public abstract class NDBuffer : NDStructure { override fun elements(): Sequence> = strides.indices().map { it to this[it] } override fun equals(other: Any?): Boolean { - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) + return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false) } override fun hashCode(): Int { diff --git a/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt b/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt index bb0d19c23..7aa746797 100644 --- a/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt +++ b/kmath-coroutines/src/jvmMain/kotlin/kscience/kmath/structures/LazyNDStructure.kt @@ -24,7 +24,7 @@ public class LazyNDStructure( } public override fun equals(other: Any?): Boolean { - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) + return NDStructure.contentEquals(this, other as? NDStructure<*> ?: return false) } public override fun hashCode(): Int { diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt index 1c2ded447..82a5399fd 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt @@ -3,6 +3,7 @@ package kscience.kmath.ejml import kscience.kmath.linear.* import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Matrix +import kscience.kmath.structures.NDStructure import kscience.kmath.structures.RealBuffer import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.simple.SimpleMatrix @@ -15,7 +16,7 @@ import kotlin.reflect.cast * @property origin the underlying [SimpleMatrix]. * @author Iaroslav Postovalov */ -public inline class EjmlMatrix( +public class EjmlMatrix( public val origin: SimpleMatrix, ) : Matrix { public override val rowNum: Int get() = origin.numRows() @@ -49,7 +50,7 @@ public inline class EjmlMatrix( override val q: Matrix by lazy { EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) } override val r: Matrix by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) } } - CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { + CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { override val l: Matrix by lazy { val cholesky = DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) } @@ -57,7 +58,7 @@ public inline class EjmlMatrix( EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature } } - LupDecompositionFeature::class -> object : LupDecompositionFeature { + LupDecompositionFeature::class -> object : LupDecompositionFeature { private val lup by lazy { DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) } } @@ -73,7 +74,17 @@ public inline class EjmlMatrix( override val p: Matrix by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) } } else -> null - }?.let{type.cast(it)} + }?.let { type.cast(it) } public override operator fun get(i: Int, j: Int): Double = origin[i, j] + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is Matrix<*>) return false + return NDStructure.contentEquals(this, other) + } + + override fun hashCode(): Int = origin.hashCode() + + } From d00e7434a4e14a5fbc6039f823b02fffd7b86d47 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 20 Jan 2021 15:07:39 +0300 Subject: [PATCH 3/5] Fix for #193 --- .../src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt | 1 + 1 file changed, 1 insertion(+) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt index bbe9c1195..1db905bf2 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt @@ -25,6 +25,7 @@ public class MatrixWrapper( */ @UnstableKMathAPI override fun getFeature(type: KClass): T? = type.safeCast(features.find { type.isInstance(it) }) + ?: matrix.getFeature(type) override fun equals(other: Any?): Boolean = matrix == other override fun hashCode(): Int = matrix.hashCode() From 881b85a1d9dabc30c5d5baa8821d170800e64700 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 20 Jan 2021 15:32:55 +0300 Subject: [PATCH 4/5] Add `origin` (optin) extension property to expose MatrixWrapper content --- .../kscience/kmath/commons/linear/CMMatrix.kt | 8 +++--- .../kscience/kmath/linear/MatrixWrapper.kt | 25 ++++++++++++------- .../kscience/kmath/ejml/EjmlMatrixContext.kt | 8 +++--- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt index 48b6e0ef1..850446afa 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt @@ -2,8 +2,8 @@ package kscience.kmath.commons.linear import kscience.kmath.linear.DiagonalFeature import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.MatrixWrapper import kscience.kmath.linear.Point +import kscience.kmath.linear.origin import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Matrix import org.apache.commons.math3.linear.* @@ -47,9 +47,9 @@ public object CMMatrixContext : MatrixContext { return CMMatrix(Array2DRowRealMatrix(array)) } - public fun Matrix.toCM(): CMMatrix = when { - this is CMMatrix -> this - this is MatrixWrapper && matrix is CMMatrix -> matrix as CMMatrix + @OptIn(UnstableKMathAPI::class) + public fun Matrix.toCM(): CMMatrix = when (val matrix = origin) { + is CMMatrix -> matrix else -> { //TODO add feature analysis val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt index 1db905bf2..362db1fe7 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt @@ -15,30 +15,37 @@ import kotlin.reflect.safeCast * * @param T the type of items. */ -public class MatrixWrapper( - public val matrix: Matrix, +public class MatrixWrapper internal constructor( + public val origin: Matrix, public val features: Set, -) : Matrix by matrix { +) : Matrix by origin { /** * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria */ @UnstableKMathAPI override fun getFeature(type: KClass): T? = type.safeCast(features.find { type.isInstance(it) }) - ?: matrix.getFeature(type) + ?: origin.getFeature(type) - override fun equals(other: Any?): Boolean = matrix == other - override fun hashCode(): Int = matrix.hashCode() + override fun equals(other: Any?): Boolean = origin == other + override fun hashCode(): Int = origin.hashCode() override fun toString(): String { - return "MatrixWrapper(matrix=$matrix, features=$features)" + return "MatrixWrapper(matrix=$origin, features=$features)" } } +/** + * Return the original matrix. If this is a wrapper, return its origin. If not, this matrix. + * Origin does not necessary store all features. + */ +@UnstableKMathAPI +public val Matrix.origin: Matrix get() = (this as? MatrixWrapper)?.origin ?: this + /** * Add a single feature to a [Matrix] */ public operator fun Matrix.plus(newFeature: MatrixFeature): MatrixWrapper = if (this is MatrixWrapper) { - MatrixWrapper(matrix, features + newFeature) + MatrixWrapper(origin, features + newFeature) } else { MatrixWrapper(this, setOf(newFeature)) } @@ -48,7 +55,7 @@ public operator fun Matrix.plus(newFeature: MatrixFeature): MatrixW */ public operator fun Matrix.plus(newFeatures: Collection): MatrixWrapper = if (this is MatrixWrapper) { - MatrixWrapper(matrix, features + newFeatures) + MatrixWrapper(origin, features + newFeatures) } else { MatrixWrapper(this, newFeatures.toSet()) } diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt index 7198bbd0d..8184d0110 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt @@ -2,8 +2,8 @@ package kscience.kmath.ejml import kscience.kmath.linear.InverseMatrixFeature import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.MatrixWrapper import kscience.kmath.linear.Point +import kscience.kmath.linear.origin import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Matrix import kscience.kmath.structures.getFeature @@ -19,9 +19,9 @@ public object EjmlMatrixContext : MatrixContext { /** * Converts this matrix to EJML one. */ - public fun Matrix.toEjml(): EjmlMatrix = when { - this is EjmlMatrix -> this - this is MatrixWrapper && matrix is EjmlMatrix -> matrix as EjmlMatrix + @OptIn(UnstableKMathAPI::class) + public fun Matrix.toEjml(): EjmlMatrix = when (val matrix = origin) { + is EjmlMatrix -> matrix else -> produce(rowNum, colNum) { i, j -> get(i, j) } } From d10ae66e580f7d87410d44cc082b5adb79554769 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 20 Jan 2021 17:08:29 +0300 Subject: [PATCH 5/5] Deploy fixes for 0.2.0-dev-5 --- .../src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt | 2 +- .../src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt | 3 +++ .../src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt index 2bafd377e..e7eb2770d 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt @@ -92,7 +92,7 @@ public interface Algebra { * Call a block with an [Algebra] as receiver. */ // TODO add contract when KT-32313 is fixed -public inline operator fun , R> A.invoke(block: A.() -> R): R = block() +public inline operator fun , R> A.invoke(block: A.() -> R): R = run(block) /** * Represents "semispace", i.e. algebraic structure with associative binary operation called "addition" as well as diff --git a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt index 30d146779..455b52d9d 100644 --- a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt +++ b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt @@ -4,6 +4,7 @@ import kscience.kmath.linear.DeterminantFeature import kscience.kmath.linear.LupDecompositionFeature import kscience.kmath.linear.MatrixFeature import kscience.kmath.linear.plus +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.getFeature import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.simple.SimpleMatrix @@ -39,6 +40,7 @@ internal class EjmlMatrixTest { assertEquals(listOf(m.numRows(), m.numCols()), w.shape.toList()) } + @OptIn(UnstableKMathAPI::class) @Test fun features() { val m = randomMatrix @@ -57,6 +59,7 @@ internal class EjmlMatrixTest { private object SomeFeature : MatrixFeature {} + @OptIn(UnstableKMathAPI::class) @Test fun suggestFeature() { assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature()) diff --git a/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt index c2304070f..4e29e6105 100644 --- a/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt +++ b/kmath-stat/src/jvmTest/kotlin/kscience/kmath/stat/MCScopeTest.kt @@ -62,6 +62,7 @@ class MCScopeTest { } + @OptIn(ObsoleteCoroutinesApi::class) fun compareResult(test: ATest) { val res1 = runBlocking(Dispatchers.Default) { test() } val res2 = runBlocking(newSingleThreadContext("test")) { test() }