From b288704528ca54ace552369ca068cd1bf1983992 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 7 Jan 2021 18:07:00 +0300 Subject: [PATCH 1/8] Optimize RealMatrix dot operation --- ...iplicationBenchmark.kt => DotBenchmark.kt} | 24 ++++-- .../kmath/structures/typeSafeDimensions.kt | 5 +- .../kscience/kmath/commons/linear/CMMatrix.kt | 1 + .../kscience/kmath/linear/BufferMatrix.kt | 24 +----- .../kscience/kmath/linear/LUPDecomposition.kt | 11 +-- .../kscience/kmath/linear/MatrixContext.kt | 14 ++-- .../kmath/linear/RealMatrixContext.kt | 84 +++++++++++++++++++ .../kscience/kmath/dimensions/Wrappers.kt | 38 ++++----- .../kscience/dimensions/DMatrixContextTest.kt | 1 + .../kotlin/kscience/kmath/real/RealMatrix.kt | 11 +-- 10 files changed, 135 insertions(+), 78 deletions(-) rename examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/{MultiplicationBenchmark.kt => DotBenchmark.kt} (73%) create mode 100644 kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/MultiplicationBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt similarity index 73% rename from examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/MultiplicationBenchmark.kt rename to examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt index 9d2b02245..8823e86db 100644 --- a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/MultiplicationBenchmark.kt +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/DotBenchmark.kt @@ -2,19 +2,22 @@ package kscience.kmath.benchmarks import kotlinx.benchmark.Benchmark import kscience.kmath.commons.linear.CMMatrixContext -import kscience.kmath.commons.linear.CMMatrixContext.dot import kscience.kmath.commons.linear.toCM import kscience.kmath.ejml.EjmlMatrixContext import kscience.kmath.ejml.toEjml +import kscience.kmath.linear.BufferMatrixContext +import kscience.kmath.linear.RealMatrixContext import kscience.kmath.linear.real +import kscience.kmath.operations.RealField import kscience.kmath.operations.invoke +import kscience.kmath.structures.Buffer import kscience.kmath.structures.Matrix import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.State import kotlin.random.Random @State(Scope.Benchmark) -class MultiplicationBenchmark { +class DotBenchmark { companion object { val random = Random(12224) val dim = 1000 @@ -32,14 +35,14 @@ class MultiplicationBenchmark { @Benchmark fun commonsMathMultiplication() { - CMMatrixContext.invoke { + CMMatrixContext { cmMatrix1 dot cmMatrix2 } } @Benchmark fun ejmlMultiplication() { - EjmlMatrixContext.invoke { + EjmlMatrixContext { ejmlMatrix1 dot ejmlMatrix2 } } @@ -48,13 +51,22 @@ class MultiplicationBenchmark { fun ejmlMultiplicationwithConversion() { val ejmlMatrix1 = matrix1.toEjml() val ejmlMatrix2 = matrix2.toEjml() - EjmlMatrixContext.invoke { + EjmlMatrixContext { ejmlMatrix1 dot ejmlMatrix2 } } @Benchmark fun bufferedMultiplication() { - matrix1 dot matrix2 + BufferMatrixContext(RealField, Buffer.Companion::real).invoke{ + matrix1 dot matrix2 + } + } + + @Benchmark + fun realMultiplication(){ + RealMatrixContext { + matrix1 dot matrix2 + } } } \ No newline at end of file diff --git a/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt b/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt index 987eea16f..96684f7dc 100644 --- a/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/typeSafeDimensions.kt @@ -4,9 +4,8 @@ import kscience.kmath.dimensions.D2 import kscience.kmath.dimensions.D3 import kscience.kmath.dimensions.DMatrixContext import kscience.kmath.dimensions.Dimension -import kscience.kmath.operations.RealField -private fun DMatrixContext.simple() { +private fun DMatrixContext.simple() { val m1 = produce { i, j -> (i + j).toDouble() } val m2 = produce { i, j -> (i + j).toDouble() } @@ -18,7 +17,7 @@ private object D5 : Dimension { override val dim: UInt = 5u } -private fun DMatrixContext.custom() { +private fun DMatrixContext.custom() { val m1 = produce { i, j -> (i + j).toDouble() } val m2 = produce { i, j -> (i - j).toDouble() } val m3 = produce { i, j -> (i - j).toDouble() } diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt index 712927400..49888f8d6 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt @@ -29,6 +29,7 @@ public class CMMatrix(public val origin: RealMatrix, features: Set.toCM(): CMMatrix = if (this is CMMatrix) { this } else { diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt index 8b50bbe33..402161207 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt @@ -1,8 +1,10 @@ package kscience.kmath.linear -import kscience.kmath.operations.RealField import kscience.kmath.operations.Ring -import kscience.kmath.structures.* +import kscience.kmath.structures.Buffer +import kscience.kmath.structures.BufferFactory +import kscience.kmath.structures.NDStructure +import kscience.kmath.structures.asSequence /** * Basic implementation of Matrix space based on [NDStructure] @@ -21,24 +23,6 @@ public class BufferMatrixContext>( public companion object } -@Suppress("OVERRIDE_BY_INLINE") -public object RealMatrixContext : GenericMatrixContext> { - public override val elementContext: RealField - get() = RealField - - public override inline fun produce( - rows: Int, - columns: Int, - initializer: (i: Int, j: Int) -> Double, - ): BufferMatrix { - val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } - return BufferMatrix(rows, columns, buffer) - } - - public override inline fun point(size: Int, initializer: (Int) -> Double): Point = - RealBuffer(size, initializer) -} - public class BufferMatrix( public override val rowNum: Int, public override val colNum: Int, diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt index 099fa1909..bf2a9f59e 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt @@ -213,17 +213,8 @@ public inline fun , F : Field> GenericMatrixContext return decomposition.solveWithLUP(bufferFactory, b) } -public fun RealMatrixContext.solveWithLUP(a: Matrix, b: Matrix): FeaturedMatrix = - solveWithLUP(a, b) { it < 1e-11 } - public inline fun , F : Field> GenericMatrixContext>.inverseWithLUP( matrix: Matrix, noinline bufferFactory: MutableBufferFactory = MutableBuffer.Companion::auto, noinline checkSingular: (T) -> Boolean, -): FeaturedMatrix = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) - -/** - * Inverses a square matrix using LUP decomposition. Non square matrix will throw a error. - */ -public fun RealMatrixContext.inverseWithLUP(matrix: Matrix): FeaturedMatrix = - solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), Buffer.Companion::real) { it < 1e-11 } +): FeaturedMatrix = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt index d9dc57b0f..9bc79e12b 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt @@ -18,6 +18,11 @@ public interface MatrixContext> : SpaceOperations T): M + /** + * Produce a point compatible with matrix space (and possibly optimized for it) + */ + public fun point(size: Int, initializer: (Int) -> T): Point = Buffer.boxing(size, initializer) + @Suppress("UNCHECKED_CAST") public override fun binaryOperation(operation: String, left: Matrix, right: Matrix): M = when (operation) { "dot" -> left dot right @@ -61,10 +66,6 @@ public interface MatrixContext> : SpaceOperations): M = m * this public companion object { - /** - * Non-boxing double matrix - */ - public val real: RealMatrixContext = RealMatrixContext /** * A structured matrix with custom buffer @@ -88,11 +89,6 @@ public interface GenericMatrixContext, out M : Matrix> : */ public val elementContext: R - /** - * Produce a point compatible with matrix space - */ - public fun point(size: Int, initializer: (Int) -> T): Point - public override infix fun Matrix.dot(other: Matrix): M { //TODO add typed error require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt new file mode 100644 index 000000000..772b20f3b --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt @@ -0,0 +1,84 @@ +package kscience.kmath.linear + +import kscience.kmath.operations.RealField +import kscience.kmath.structures.Matrix +import kscience.kmath.structures.MutableBuffer +import kscience.kmath.structures.MutableBufferFactory +import kscience.kmath.structures.RealBuffer + +@Suppress("OVERRIDE_BY_INLINE") +public object RealMatrixContext : MatrixContext> { + + public override inline fun produce( + rows: Int, + columns: Int, + initializer: (i: Int, j: Int) -> Double, + ): BufferMatrix { + val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) } + return BufferMatrix(rows, columns, buffer) + } + + private fun Matrix.wrap(): BufferMatrix = if (this is BufferMatrix) this else { + produce(rowNum, colNum) { i, j -> get(i, j) } + } + + public fun one(rows: Int, columns: Int): FeaturedMatrix = VirtualMatrix(rows, columns, DiagonalFeature) { i, j -> + if (i == j) 1.0 else 0.0 + } + + public override infix fun Matrix.dot(other: Matrix): BufferMatrix { + require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } + return produce(rowNum, other.colNum) { i, j -> + var res = 0.0 + for (l in 0 until colNum) { + res += get(i, l) * other.get(l, j) + } + res + } + } + + public override infix fun Matrix.dot(vector: Point): Point { + require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" } + return RealBuffer(rowNum) { i -> + var res = 0.0 + for (j in 0 until colNum) { + res += get(i, j) * vector[j] + } + res + } + } + + override fun add(a: Matrix, b: Matrix): BufferMatrix { + require(a.rowNum == b.rowNum) { "Row number mismatch in matrix addition. Left side: ${a.rowNum}, right side: ${b.rowNum}" } + require(a.colNum == b.colNum) { "Column number mismatch in matrix addition. Left side: ${a.colNum}, right side: ${b.colNum}" } + return produce(a.rowNum, a.colNum) { i, j -> + a[i, j] + b[i, j] + } + } + + override fun Matrix.times(value: Double): BufferMatrix = + produce(rowNum, colNum) { i, j -> get(i, j) * value } + + + override fun multiply(a: Matrix, k: Number): BufferMatrix = + produce(a.rowNum, a.colNum) { i, j -> a.get(i, j) * k.toDouble() } +} + + +/** + * Partially optimized real-valued matrix + */ +public val MatrixContext.Companion.real: RealMatrixContext get() = RealMatrixContext + +public fun RealMatrixContext.solveWithLUP(a: Matrix, b: Matrix): FeaturedMatrix { + // Use existing decomposition if it is provided by matrix + val bufferFactory: MutableBufferFactory = MutableBuffer.Companion::real + val decomposition = a.getFeature() ?: lup(bufferFactory, RealField, a) { it < 1e-11 } + return decomposition.solveWithLUP(bufferFactory, b) +} + +/** + * Inverses a square matrix using LUP decomposition. Non square matrix will throw a error. + */ +public fun RealMatrixContext.inverseWithLUP(matrix: Matrix): FeaturedMatrix = + solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum)) diff --git a/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt b/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt index 68a5dc262..0422d11b2 100644 --- a/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt +++ b/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt @@ -1,11 +1,6 @@ package kscience.kmath.dimensions -import kscience.kmath.linear.GenericMatrixContext -import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.Point -import kscience.kmath.linear.transpose -import kscience.kmath.operations.RealField -import kscience.kmath.operations.Ring +import kscience.kmath.linear.* import kscience.kmath.operations.invoke import kscience.kmath.structures.Matrix import kscience.kmath.structures.Structure2D @@ -42,7 +37,7 @@ public interface DMatrix : Structure2D { * An inline wrapper for a Matrix */ public inline class DMatrixWrapper( - private val structure: Structure2D + private val structure: Structure2D, ) : DMatrix { override val shape: IntArray get() = structure.shape override operator fun get(i: Int, j: Int): T = structure[i, j] @@ -81,7 +76,7 @@ public inline class DPointWrapper(public val point: Point) /** * Basic operations on dimension-safe matrices. Operates on [Matrix] */ -public inline class DMatrixContext>(public val context: GenericMatrixContext>) { +public inline class DMatrixContext(public val context: MatrixContext>) { public inline fun Matrix.coerce(): DMatrix { require(rowNum == Dimension.dim().toInt()) { "Row number mismatch: expected ${Dimension.dim()} but found $rowNum" @@ -115,7 +110,7 @@ public inline class DMatrixContext>(public val context: Ge } public inline infix fun DMatrix.dot( - other: DMatrix + other: DMatrix, ): DMatrix = context { this@dot dot other }.coerce() public inline infix fun DMatrix.dot(vector: DPoint): DPoint = @@ -139,18 +134,19 @@ public inline class DMatrixContext>(public val context: Ge public inline fun DMatrix.transpose(): DMatrix = context { (this@transpose as Matrix).transpose() }.coerce() - /** - * A square unit matrix - */ - public inline fun one(): DMatrix = produce { i, j -> - if (i == j) context.elementContext.one else context.elementContext.zero - } - - public inline fun zero(): DMatrix = produce { _, _ -> - context.elementContext.zero - } - public companion object { - public val real: DMatrixContext = DMatrixContext(MatrixContext.real) + public val real: DMatrixContext = DMatrixContext(MatrixContext.real) } } + + +/** + * A square unit matrix + */ +public inline fun DMatrixContext.one(): DMatrix = produce { i, j -> + if (i == j) 1.0 else 0.0 +} + +public inline fun DMatrixContext.zero(): DMatrix = produce { _, _ -> + 0.0 +} \ No newline at end of file diff --git a/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt b/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt index f44b16753..5b330fcce 100644 --- a/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt +++ b/kmath-dimensions/src/commonTest/kotlin/kscience/dimensions/DMatrixContextTest.kt @@ -3,6 +3,7 @@ package kscience.dimensions import kscience.kmath.dimensions.D2 import kscience.kmath.dimensions.D3 import kscience.kmath.dimensions.DMatrixContext +import kscience.kmath.dimensions.one import kotlin.test.Test internal class DMatrixContextTest { diff --git a/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt index e8ad835e5..772abfbed 100644 --- a/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt @@ -1,13 +1,7 @@ package kscience.kmath.real -import kscience.kmath.linear.FeaturedMatrix -import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.RealMatrixContext.elementContext -import kscience.kmath.linear.VirtualMatrix -import kscience.kmath.linear.inverseWithLUP +import kscience.kmath.linear.* import kscience.kmath.misc.UnstableKMathAPI -import kscience.kmath.operations.invoke -import kscience.kmath.operations.sum import kscience.kmath.structures.Buffer import kscience.kmath.structures.RealBuffer import kscience.kmath.structures.asIterable @@ -122,8 +116,7 @@ public fun RealMatrix.extractColumn(columnIndex: Int): RealMatrix = extractColumns(columnIndex..columnIndex) public fun RealMatrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j -> - val column = columns[j] - elementContext { sum(column.asIterable()) } + columns[j].asIterable().sum() } public fun RealMatrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> From 2012d2c3f1a5a81e1a54fa98baef501c48f8321d Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 7 Jan 2021 22:40:30 +0700 Subject: [PATCH 2/8] Fix #172, add constant folding for unary operations from numeric nodes --- .../kotlin/kscience/kmath/ast/MST.kt | 24 +++++++------- .../kotlin/kscience/kmath/estree/estree.kt | 30 ++++++++++-------- .../jvmMain/kotlin/kscience/kmath/asm/asm.kt | 31 +++++++++++-------- .../kscience/kmath/asm/internal/AsmBuilder.kt | 2 +- 4 files changed, 48 insertions(+), 39 deletions(-) diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt index 6cf746722..212fd0d0b 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MST.kt @@ -2,10 +2,9 @@ package kscience.kmath.ast import kscience.kmath.operations.Algebra import kscience.kmath.operations.NumericAlgebra -import kscience.kmath.operations.RealField /** - * A Mathematical Syntax Tree node for mathematical expressions. + * A Mathematical Syntax Tree (MST) node for mathematical expressions. * * @author Alexander Nozik */ @@ -57,21 +56,22 @@ public fun Algebra.evaluate(node: MST): T = when (node) { ?: error("Numeric nodes are not supported by $this") is MST.Symbolic -> symbol(node.value) - is MST.Unary -> unaryOperationFunction(node.operation)(evaluate(node.value)) + + is MST.Unary -> when { + this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value)) + else -> unaryOperationFunction(node.operation)(evaluate(node.value)) + } is MST.Binary -> when { - this !is NumericAlgebra -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right)) + this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric -> + binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value)) - node.left is MST.Numeric && node.right is MST.Numeric -> { - val number = RealField - .binaryOperationFunction(node.operation) - .invoke(node.left.value.toDouble(), node.right.value.toDouble()) + this is NumericAlgebra && node.left is MST.Numeric -> + leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right)) - number(number) - } + this is NumericAlgebra && node.right is MST.Numeric -> + rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value) - node.left is MST.Numeric -> leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right)) - node.right is MST.Numeric -> rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value) else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right)) } } diff --git a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt index 159c5d5ec..5c08ada31 100644 --- a/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt +++ b/kmath-ast/src/jsMain/kotlin/kscience/kmath/estree/estree.kt @@ -1,18 +1,18 @@ package kscience.kmath.estree import kscience.kmath.ast.MST +import kscience.kmath.ast.MST.* import kscience.kmath.ast.MstExpression import kscience.kmath.estree.internal.ESTreeBuilder import kscience.kmath.estree.internal.estree.BaseExpression import kscience.kmath.expressions.Expression import kscience.kmath.operations.Algebra import kscience.kmath.operations.NumericAlgebra -import kscience.kmath.operations.RealField @PublishedApi internal fun MST.compileWith(algebra: Algebra): Expression { fun ESTreeBuilder.visit(node: MST): BaseExpression = when (node) { - is MST.Symbolic -> { + is Symbolic -> { val symbol = try { algebra.symbol(node.value) } catch (ignored: IllegalStateException) { @@ -25,25 +25,29 @@ internal fun MST.compileWith(algebra: Algebra): Expression { variable(node.value) } - is MST.Numeric -> constant(node.value) - is MST.Unary -> call(algebra.unaryOperationFunction(node.operation), visit(node.value)) + is Numeric -> constant(node.value) - is MST.Binary -> when { - algebra is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric -> constant( - algebra.number( - RealField - .binaryOperationFunction(node.operation) - .invoke(node.left.value.toDouble(), node.right.value.toDouble()) - ) + is Unary -> when { + algebra is NumericAlgebra && node.value is Numeric -> constant( + algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value))) + + else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value)) + } + + is Binary -> when { + algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant( + algebra + .binaryOperationFunction(node.operation) + .invoke(algebra.number(node.left.value), algebra.number(node.right.value)) ) - algebra is NumericAlgebra && node.left is MST.Numeric -> call( + algebra is NumericAlgebra && node.left is Numeric -> call( algebra.leftSideNumberOperationFunction(node.operation), visit(node.left), visit(node.right), ) - algebra is NumericAlgebra && node.right is MST.Numeric -> call( + algebra is NumericAlgebra && node.right is Numeric -> call( algebra.rightSideNumberOperationFunction(node.operation), visit(node.left), visit(node.right), diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt index b98c0bfce..55cdec243 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/asm.kt @@ -3,11 +3,11 @@ package kscience.kmath.asm import kscience.kmath.asm.internal.AsmBuilder import kscience.kmath.asm.internal.buildName import kscience.kmath.ast.MST +import kscience.kmath.ast.MST.* import kscience.kmath.ast.MstExpression import kscience.kmath.expressions.Expression import kscience.kmath.operations.Algebra import kscience.kmath.operations.NumericAlgebra -import kscience.kmath.operations.RealField /** * Compiles given MST to an Expression using AST compiler. @@ -20,7 +20,7 @@ import kscience.kmath.operations.RealField @PublishedApi internal fun MST.compileWith(type: Class, algebra: Algebra): Expression { fun AsmBuilder.visit(node: MST): Unit = when (node) { - is MST.Symbolic -> { + is Symbolic -> { val symbol = try { algebra.symbol(node.value) } catch (ignored: IllegalStateException) { @@ -33,24 +33,29 @@ internal fun MST.compileWith(type: Class, algebra: Algebra): Exp loadVariable(node.value) } - is MST.Numeric -> loadNumberConstant(node.value) - is MST.Unary -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) } + is Numeric -> loadNumberConstant(node.value) - is MST.Binary -> when { - algebra is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant( - algebra.number( - RealField - .binaryOperationFunction(node.operation) - .invoke(node.left.value.toDouble(), node.right.value.toDouble()) - ) + is Unary -> when { + algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant( + algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value))) + + else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) } + } + + is Binary -> when { + algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant( + algebra.binaryOperationFunction(node.operation) + .invoke(algebra.number(node.left.value), algebra.number(node.right.value)) ) - algebra is NumericAlgebra && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperationFunction(node.operation)) { + algebra is NumericAlgebra && node.left is Numeric -> buildCall( + algebra.leftSideNumberOperationFunction(node.operation)) { visit(node.left) visit(node.right) } - algebra is NumericAlgebra && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperationFunction(node.operation)) { + algebra is NumericAlgebra && node.right is Numeric -> buildCall( + algebra.rightSideNumberOperationFunction(node.operation)) { visit(node.left) visit(node.right) } diff --git a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt index 1edbed28d..93d8d1143 100644 --- a/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/kscience/kmath/asm/internal/AsmBuilder.kt @@ -191,7 +191,7 @@ internal class AsmBuilder( } val cls = classLoader.defineClass(className, classWriter.toByteArray()) - java.io.File("dump.class").writeBytes(classWriter.toByteArray()) + // java.io.File("dump.class").writeBytes(classWriter.toByteArray()) val l = MethodHandles.publicLookup() if (hasConstants) From 0ef7e7ca52ab6bedc1f33f6b5e2e1572c61b83f6 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Sun, 10 Jan 2021 17:53:15 +0700 Subject: [PATCH 3/8] Update Gradle again --- gradle/wrapper/gradle-wrapper.properties | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 4d9ca1649..da9702f9e 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.7.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.8-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists From 7fdd001a77e14aa745d629d528971b979a2f019a Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Sat, 16 Jan 2021 15:51:36 +0700 Subject: [PATCH 4/8] Update KDoc comments for Matrix classes, improve MatrixFeature API, implement new features with EJML matrix, delete inversion API from EJML in favor of InverseMatrixFeature, override point by EJML matrix --- .../kscience/kmath/linear/FeaturedMatrix.kt | 10 +- .../kscience/kmath/linear/LUPDecomposition.kt | 14 ++- .../kscience/kmath/linear/MatrixFeatures.kt | 114 ++++++++++++++++-- .../kscience/kmath/structures/Structure2D.kt | 45 +++++-- .../kotlin/kscience/kmath/ejml/EjmlMatrix.kt | 83 ++++++++----- .../kscience/kmath/ejml/EjmlMatrixContext.kt | 15 +-- .../kscience/kmath/ejml/EjmlMatrixTest.kt | 4 +- 7 files changed, 217 insertions(+), 68 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt index 68272203c..119f5d844 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt @@ -7,10 +7,16 @@ import kscience.kmath.structures.asBuffer import kotlin.math.sqrt /** - * A 2d structure plus optional matrix-specific features + * A [Matrix] that holds [MatrixFeature] objects. + * + * @param T the type of items. */ public interface FeaturedMatrix : Matrix { - override val shape: IntArray get() = intArrayOf(rowNum, colNum) + public override val shape: IntArray get() = intArrayOf(rowNum, colNum) + + /** + * The set of features this matrix possesses. + */ public val features: Set /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt index bf2a9f59e..75acaf8a1 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt @@ -4,16 +4,15 @@ import kscience.kmath.operations.* import kscience.kmath.structures.* /** - * Common implementation of [LUPDecompositionFeature] + * Common implementation of [LupDecompositionFeature]. */ public class LUPDecomposition( public val context: MatrixContext>, public val elementContext: Field, - public val lu: Structure2D, + public val lu: Matrix, public val pivot: IntArray, private val even: Boolean, -) : LUPDecompositionFeature, DeterminantFeature { - +) : LupDecompositionFeature, DeterminantFeature { /** * Returns the matrix L of the decomposition. * @@ -151,7 +150,10 @@ public inline fun , F : Field> GenericMatrixContext public fun MatrixContext>.lup(matrix: Matrix): LUPDecomposition = lup(Buffer.Companion::real, RealField, matrix) { it < 1e-11 } -public fun LUPDecomposition.solveWithLUP(factory: MutableBufferFactory, matrix: Matrix): FeaturedMatrix { +public fun LUPDecomposition.solveWithLUP( + factory: MutableBufferFactory, + matrix: Matrix +): FeaturedMatrix { require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" } BufferAccessor2D(matrix.rowNum, matrix.colNum, factory).run { @@ -217,4 +219,4 @@ public inline fun , F : Field> GenericMatrixContext matrix: Matrix, noinline bufferFactory: MutableBufferFactory = MutableBuffer.Companion::auto, noinline checkSingular: (T) -> Boolean, -): FeaturedMatrix = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) \ No newline at end of file +): FeaturedMatrix = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt index a82032e50..767b58eba 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt @@ -1,62 +1,154 @@ package kscience.kmath.linear /** - * A marker interface representing some matrix feature like diagonal, sparse, zero, etc. Features used to optimize matrix - * operations performance in some cases. + * A marker interface representing some properties of matrices or additional transformations of them. Features are used + * to optimize matrix operations performance in some cases or retrieve the APIs. */ public interface MatrixFeature /** - * The matrix with this feature is considered to have only diagonal non-null elements + * Matrices with this feature are considered to have only diagonal non-null elements. */ public object DiagonalFeature : MatrixFeature /** - * Matrix with this feature has all zero elements + * Matrices with this feature have all zero elements. */ public object ZeroFeature : MatrixFeature /** - * Matrix with this feature have unit elements on diagonal and zero elements in all other places + * Matrices with this feature have unit elements on diagonal and zero elements in all other places. */ public object UnitFeature : MatrixFeature /** - * Inverted matrix feature + * Matrices with this feature can be inverted: [inverse] = `a`-1 where `a` is the owning matrix. + * + * @param T the type of matrices' items. */ public interface InverseMatrixFeature : MatrixFeature { + /** + * The inverse matrix of the matrix that owns this feature. + */ public val inverse: FeaturedMatrix } /** - * A determinant container + * Matrices with this feature can compute their determinant. */ public interface DeterminantFeature : MatrixFeature { + /** + * The determinant of the matrix that owns this feature. + */ public val determinant: T } +/** + * Produces a [DeterminantFeature] where the [DeterminantFeature.determinant] is [determinant]. + * + * @param determinant the value of determinant. + * @return a new [DeterminantFeature]. + */ @Suppress("FunctionName") public fun DeterminantFeature(determinant: T): DeterminantFeature = object : DeterminantFeature { override val determinant: T = determinant } /** - * Lower triangular matrix + * Matrices with this feature are lower triangular ones. */ public object LFeature : MatrixFeature /** - * Upper triangular feature + * Matrices with this feature are upper triangular ones. */ public object UFeature : MatrixFeature /** - * TODO add documentation + * Matrices with this feature support LU factorization with partial pivoting: *[p] · a = [l] · [u]* where + * *a* is the owning matrix. + * + * @param T the type of matrices' items. */ -public interface LUPDecompositionFeature : MatrixFeature { +public interface LupDecompositionFeature : MatrixFeature { + /** + * The lower triangular matrix in this decomposition. It may have [LFeature]. + */ public val l: FeaturedMatrix + + /** + * The upper triangular matrix in this decomposition. It may have [UFeature]. + */ public val u: FeaturedMatrix + + /** + * The permutation matrix in this decomposition. + */ public val p: FeaturedMatrix } +/** + * Matrices with this feature are orthogonal ones: *a · aT = u* where *a* is the owning matrix, *u* + * is the unit matrix ([UnitFeature]). + */ +public object OrthogonalFeature : MatrixFeature + +/** + * Matrices with this feature support QR factorization: *a = [q] · [r]* where *a* is the owning matrix. + * + * @param T the type of matrices' items. + */ +public interface QRDecompositionFeature : MatrixFeature { + /** + * The orthogonal matrix in this decomposition. It may have [OrthogonalFeature]. + */ + public val q: FeaturedMatrix + + /** + * The upper triangular matrix in this decomposition. It may have [UFeature]. + */ + public val r: FeaturedMatrix +} + +/** + * Matrices with this feature support Cholesky factorization: *a = [l] · [l]H* where *a* is the + * owning matrix. + * + * @param T the type of matrices' items. + */ +public interface CholeskyDecompositionFeature : MatrixFeature { + /** + * The triangular matrix in this decomposition. It may have either [UFeature] or [LFeature]. + */ + public val l: FeaturedMatrix +} + +/** + * Matrices with this feature support SVD: *a = [u] · [s] · [v]H* where *a* is the owning + * matrix. + * + * @param T the type of matrices' items. + */ +public interface SingularValueDecompositionFeature : MatrixFeature { + /** + * The matrix in this decomposition. It is unitary, and it consists from left singular vectors. + */ + public val u: FeaturedMatrix + + /** + * The matrix in this decomposition. Its main diagonal elements are singular values. + */ + public val s: FeaturedMatrix + + /** + * The matrix in this decomposition. It is unitary, and it consists from right singular vectors. + */ + public val v: FeaturedMatrix + + /** + * The buffer of singular values of this SVD. + */ + public val singularValues: Point +} + //TODO add sparse matrix feature diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt index 25fdf3f3d..bac7d3389 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt @@ -1,12 +1,40 @@ package kscience.kmath.structures /** - * A structure that is guaranteed to be two-dimensional + * A structure that is guaranteed to be two-dimensional. + * + * @param T the type of items. */ public interface Structure2D : NDStructure { + /** + * The number of rows in this structure. + */ public val rowNum: Int get() = shape[0] + + /** + * The number of columns in this structure. + */ public val colNum: Int get() = shape[1] + /** + * The buffer of rows of this structure. It gets elements from the structure dynamically. + */ + public val rows: Buffer> + get() = VirtualBuffer(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } } + + /** + * The buffer of columns of this structure. It gets elements from the structure dynamically. + */ + public val columns: Buffer> + get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } } + + /** + * Retrieves an element from the structure by two indices. + * + * @param i the first index. + * @param j the second index. + * @return an element. + */ public operator fun get(i: Int, j: Int): T override operator fun get(index: IntArray): T { @@ -14,15 +42,9 @@ public interface Structure2D : NDStructure { return get(index[0], index[1]) } - public val rows: Buffer> - get() = VirtualBuffer(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } } - - public val columns: Buffer> - get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } } - override fun elements(): Sequence> = sequence { - for (i in (0 until rowNum)) - for (j in (0 until colNum)) yield(intArrayOf(i, j) to get(i, j)) + for (i in 0 until rowNum) + for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j)) } public companion object @@ -47,4 +69,9 @@ public fun NDStructure.as2D(): Structure2D = if (shape.size == 2) else error("Can't create 2d-structure from ${shape.size}d-structure") +/** + * Alias for [Structure2D] with more familiar name. + * + * @param T the type of items. + */ public typealias Matrix = Structure2D diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt index ed6b1571e..5b7d0a01b 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt @@ -1,12 +1,10 @@ package kscience.kmath.ejml +import kscience.kmath.linear.* +import kscience.kmath.structures.NDStructure +import kscience.kmath.structures.RealBuffer import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.simple.SimpleMatrix -import kscience.kmath.linear.DeterminantFeature -import kscience.kmath.linear.FeaturedMatrix -import kscience.kmath.linear.LUPDecompositionFeature -import kscience.kmath.linear.MatrixFeature -import kscience.kmath.structures.NDStructure /** * Represents featured matrix over EJML [SimpleMatrix]. @@ -14,42 +12,71 @@ import kscience.kmath.structures.NDStructure * @property origin the underlying [SimpleMatrix]. * @author Iaroslav Postovalov */ -public class EjmlMatrix(public val origin: SimpleMatrix, features: Set? = null) : FeaturedMatrix { +public class EjmlMatrix(public val origin: SimpleMatrix, features: Set = emptySet()) : + FeaturedMatrix { public override val rowNum: Int get() = origin.numRows() public override val colNum: Int get() = origin.numCols() - public override val shape: IntArray - get() = intArrayOf(origin.numRows(), origin.numCols()) + public override val shape: IntArray by lazy { intArrayOf(rowNum, colNum) } - public override val features: Set = setOf( - object : LUPDecompositionFeature, DeterminantFeature { - override val determinant: Double - get() = origin.determinant() + public override val features: Set = hashSetOf( + object : InverseMatrixFeature { + override val inverse: FeaturedMatrix by lazy { EjmlMatrix(origin.invert()) } + }, - private val lup by lazy { - val ludecompositionF64 = DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()) - .also { it.decompose(origin.ddrm.copy()) } + object : DeterminantFeature { + override val determinant: Double by lazy(origin::determinant) + }, - Triple( - EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))), - EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))), - EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))), - ) + object : SingularValueDecompositionFeature { + private val svd by lazy { + DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false) + .apply { decompose(origin.ddrm.copy()) } } - override val l: FeaturedMatrix - get() = lup.second + override val u: FeaturedMatrix by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) } + override val s: FeaturedMatrix by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) } + override val v: FeaturedMatrix by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) } + override val singularValues: Point by lazy { RealBuffer(svd.singularValues) } + }, - override val u: FeaturedMatrix - get() = lup.third + object : QRDecompositionFeature { + private val qr by lazy { + DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) } + } - override val p: FeaturedMatrix - get() = lup.first - } - ) union features.orEmpty() + override val q: FeaturedMatrix by lazy { EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) } + override val r: FeaturedMatrix by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) } + }, + + object : CholeskyDecompositionFeature { + override val l: FeaturedMatrix by lazy { + val cholesky = + DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) } + + EjmlMatrix(SimpleMatrix(cholesky.getT(null)), setOf(LFeature)) + } + }, + + object : LupDecompositionFeature { + private val lup by lazy { + DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) } + } + + override val l: FeaturedMatrix by lazy { + EjmlMatrix(SimpleMatrix(lup.getLower(null)), setOf(LFeature)) + } + + override val u: FeaturedMatrix by lazy { + EjmlMatrix(SimpleMatrix(lup.getUpper(null)), setOf(UFeature)) + } + + override val p: FeaturedMatrix by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) } + }, + ) union features public override fun suggestFeature(vararg features: MatrixFeature): EjmlMatrix = EjmlMatrix(origin, this.features + features) diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt index 31792e39c..f8791b72c 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt @@ -17,7 +17,6 @@ public fun Matrix.toEjml(): EjmlMatrix = * @author Iaroslav Postovalov */ public object EjmlMatrixContext : MatrixContext { - /** * Converts this vector to EJML one. */ @@ -33,6 +32,11 @@ public object EjmlMatrixContext : MatrixContext { } }) + override fun point(size: Int, initializer: (Int) -> Double): Point = + EjmlVector(SimpleMatrix(size, 1).also { + (0 until it.numRows()).forEach { row -> it[row, 0] = initializer(row) } + }) + public override fun Matrix.dot(other: Matrix): EjmlMatrix = EjmlMatrix(toEjml().origin.mult(other.toEjml().origin)) @@ -73,12 +77,3 @@ public fun EjmlMatrixContext.solve(a: Matrix, b: Matrix): EjmlMa */ public fun EjmlMatrixContext.solve(a: Matrix, b: Point): EjmlVector = EjmlVector(a.toEjml().origin.solve(b.toEjml().origin)) - -/** - * Returns the inverse of given matrix: b = a^(-1). - * - * @param a the matrix. - * @return the inverse of this matrix. - * @author Iaroslav Postovalov - */ -public fun EjmlMatrixContext.inverse(a: Matrix): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert()) diff --git a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt index e0f15be83..70b82a3cb 100644 --- a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt +++ b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt @@ -1,7 +1,7 @@ package kscience.kmath.ejml import kscience.kmath.linear.DeterminantFeature -import kscience.kmath.linear.LUPDecompositionFeature +import kscience.kmath.linear.LupDecompositionFeature import kscience.kmath.linear.MatrixFeature import kscience.kmath.linear.getFeature import org.ejml.dense.row.factory.DecompositionFactory_DDRM @@ -44,7 +44,7 @@ internal class EjmlMatrixTest { val w = EjmlMatrix(m) val det = w.getFeature>() ?: fail() assertEquals(m.determinant(), det.determinant) - val lup = w.getFeature>() ?: fail() + val lup = w.getFeature>() ?: fail() val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows(), m.numCols()) .also { it.decompose(m.ddrm.copy()) } From 4635080317826427189bec6802e18922b60972eb Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Mon, 18 Jan 2021 21:33:53 +0300 Subject: [PATCH 5/8] Optimize RealMatrix dot operation --- examples/src/main/kotlin/kscience/kmath/structures/NDField.kt | 4 ++-- .../kotlin/kscience/kmath/linear/RealMatrixContext.kt | 2 +- .../kotlin/kscience/kmath/structures/NDStructure.kt | 4 ++++ settings.gradle.kts | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt b/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt index e53af0dee..778d811fd 100644 --- a/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt @@ -45,14 +45,14 @@ fun main() { measureAndPrint("Specialized addition") { specializedField { var res: NDBuffer = one - repeat(n) { res += 1.0 } + repeat(n) { res += one } } } measureAndPrint("Nd4j specialized addition") { nd4jField { var res = one - repeat(n) { res += 1.0 as Number } + repeat(n) { res += one } } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt index 772b20f3b..90e251c3a 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt @@ -61,7 +61,7 @@ public object RealMatrixContext : MatrixContext> { override fun multiply(a: Matrix, k: Number): BufferMatrix = - produce(a.rowNum, a.colNum) { i, j -> a.get(i, j) * k.toDouble() } + produce(a.rowNum, a.colNum) { i, j -> a[i, j] * k.toDouble() } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt index 08160adf4..5c5d28882 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt @@ -38,6 +38,7 @@ public interface NDStructure { */ public fun elements(): Sequence> + //force override equality and hash code public override fun equals(other: Any?): Boolean public override fun hashCode(): Int @@ -133,6 +134,9 @@ public interface MutableNDStructure : NDStructure { public operator fun set(index: IntArray, value: T) } +/** + * Transform a structure element-by element in place. + */ public inline fun MutableNDStructure.mapInPlace(action: (IntArray, T) -> T): Unit = elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) } diff --git a/settings.gradle.kts b/settings.gradle.kts index da33fea59..025e4a3c6 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -8,7 +8,7 @@ pluginManagement { maven("https://dl.bintray.com/kotlin/kotlinx") } - val toolsVersion = "0.7.1" + val toolsVersion = "0.7.2-dev-2" val kotlinVersion = "1.4.21" plugins { From ad822271b3f552ff2e57d44031e9bb1fb4f00dc0 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Tue, 19 Jan 2021 20:25:26 +0700 Subject: [PATCH 6/8] Update changelog --- CHANGELOG.md | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e542d210c..0a2ae4109 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,27 +4,28 @@ ### Added - `fun` annotation for SAM interfaces in library - Explicit `public` visibility for all public APIs -- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140). +- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140) - Automatic README generation for features (#139) - Native support for `memory`, `core` and `dimensions` -- `kmath-ejml` to supply EJML SimpleMatrix wrapper (https://github.com/mipt-npm/kmath/pull/136). +- `kmath-ejml` to supply EJML SimpleMatrix wrapper (https://github.com/mipt-npm/kmath/pull/136) - A separate `Symbol` entity, which is used for global unbound symbol. - A `Symbol` indexing scope. - Basic optimization API for Commons-math. - Chi squared optimization for array-like data in CM - `Fitting` utility object in prob/stat -- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`. -- Coroutine-deterministic Monte-Carlo scope with a random number generator. -- Some minor utilities to `kmath-for-real`. +- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray` +- Coroutine-deterministic Monte-Carlo scope with a random number generator +- Some minor utilities to `kmath-for-real` - Generic operation result parameter to `MatrixContext` +- New `MatrixFeature` interfaces for matrix decompositions ### Changed -- Package changed from `scientifik` to `kscience.kmath`. -- Gradle version: 6.6 -> 6.7.1 +- Package changed from `scientifik` to `kscience.kmath` +- Gradle version: 6.6 -> 6.8 - Minor exceptions refactor (throwing `IllegalArgumentException` by argument checks instead of `IllegalStateException`) -- `Polynomial` secondary constructor made function. -- Kotlin version: 1.3.72 -> 1.4.20 -- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library. +- `Polynomial` secondary constructor made function +- Kotlin version: 1.3.72 -> 1.4.21 +- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library - Full autodiff refactoring based on `Symbol` - `kmath-prob` renamed to `kmath-stat` - Grid generators moved to `kmath-for-real` From ab32cd95616adfb409c28cc0fb13754f69031081 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 19 Jan 2021 17:16:43 +0300 Subject: [PATCH 7/8] Numeric operations are decoupled from Ring --- CHANGELOG.md | 1 + build.gradle.kts | 2 +- .../kmath/benchmarks/LargeNDBenchmark.kt | 25 ++++ .../kmath/stat/DistributionBenchmark.kt | 3 +- .../kscience/kmath/structures/ComplexND.kt | 2 +- .../kscience/kmath/structures/NDField.kt | 8 +- .../kotlin/kscience/kmath/ast/MstAlgebra.kt | 31 +++-- .../DerivativeStructureExpression.kt | 7 +- .../FunctionalExpressionAlgebra.kt | 32 +++-- .../kmath/expressions/SimpleAutoDiff.kt | 4 +- .../kscience/kmath/linear/MatrixFeatures.kt | 4 +- .../kscience/kmath/operations/Algebra.kt | 117 +--------------- .../kscience/kmath/operations/BigInt.kt | 4 +- .../kscience/kmath/operations/Complex.kt | 4 +- .../kmath/operations/NumericAlgebra.kt | 125 ++++++++++++++++++ .../{NumberAlgebra.kt => numbers.kt} | 15 ++- .../kmath/structures/ComplexNDField.kt | 23 ++-- .../kmath/structures/RealBufferField.kt | 2 + .../kscience/kmath/structures/RealNDField.kt | 16 ++- .../kscience/kmath/structures/NDFieldTest.kt | 3 +- .../kscience/kmath/operations/BigNumbers.kt | 8 +- .../kotlin/kscience/kmath/ejml/EjmlMatrix.kt | 14 +- .../kscience/kmath/ejml/EjmlMatrixContext.kt | 6 + .../kscience.kmath.nd4j/Nd4jArrayAlgebra.kt | 46 ++++--- settings.gradle.kts | 4 +- 25 files changed, 300 insertions(+), 206 deletions(-) create mode 100644 examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LargeNDBenchmark.kt create mode 100644 kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumericAlgebra.kt rename kmath-core/src/commonMain/kotlin/kscience/kmath/operations/{NumberAlgebra.kt => numbers.kt} (97%) diff --git a/CHANGELOG.md b/CHANGELOG.md index e542d210c..840733d92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ - Optimized dot product for buffer matrices moved to `kmath-for-real` - EjmlMatrix context is an object - Matrix LUP `inverse` renamed to `inverseWithLUP` +- `NumericAlgebra` moved outside of regular algebra chain (`Ring` no longer implements it). ### Deprecated diff --git a/build.gradle.kts b/build.gradle.kts index 90a39c531..d171bd608 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -4,7 +4,7 @@ plugins { id("ru.mipt.npm.project") } -internal val kmathVersion: String by extra("0.2.0-dev-4") +internal val kmathVersion: String by extra("0.2.0-dev-5") internal val bintrayRepo: String by extra("kscience") internal val githubProject: String by extra("kmath") diff --git a/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LargeNDBenchmark.kt b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LargeNDBenchmark.kt new file mode 100644 index 000000000..395fde619 --- /dev/null +++ b/examples/src/benchmarks/kotlin/kscience/kmath/benchmarks/LargeNDBenchmark.kt @@ -0,0 +1,25 @@ +package kscience.kmath.benchmarks + +import kscience.kmath.structures.NDField +import org.openjdk.jmh.annotations.Benchmark +import org.openjdk.jmh.annotations.Scope +import org.openjdk.jmh.annotations.State +import org.openjdk.jmh.infra.Blackhole +import kotlin.random.Random + +@State(Scope.Benchmark) +class LargeNDBenchmark { + val arraySize = 10000 + val RANDOM = Random(222) + val src1 = DoubleArray(arraySize) { RANDOM.nextDouble() } + val src2 = DoubleArray(arraySize) { RANDOM.nextDouble() } + val field = NDField.real(arraySize) + val kmathArray1 = field.produce { (a) -> src1[a] } + val kmathArray2 = field.produce { (a) -> src2[a] } + + @Benchmark + fun test10000(bh: Blackhole) { + bh.consume(field.add(kmathArray1, kmathArray2)) + } + +} \ No newline at end of file diff --git a/examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt b/examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt index ef554aeff..99d3cd504 100644 --- a/examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt +++ b/examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt @@ -3,7 +3,6 @@ package kscience.kmath.commons.prob import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.runBlocking -import kscience.kmath.chains.BlockingRealChain import kscience.kmath.stat.* import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler import org.apache.commons.rng.simple.RandomSource @@ -13,7 +12,7 @@ import java.time.Instant private fun runChain(): Duration { val generator = RandomGenerator.fromSource(RandomSource.MT, 123L) val normal = Distribution.normal(NormalSamplerMethod.Ziggurat) - val chain = normal.sample(generator) as BlockingRealChain + val chain = normal.sample(generator) val startTime = Instant.now() var sum = 0.0 diff --git a/examples/src/main/kotlin/kscience/kmath/structures/ComplexND.kt b/examples/src/main/kotlin/kscience/kmath/structures/ComplexND.kt index aa4b10ef2..b69590473 100644 --- a/examples/src/main/kotlin/kscience/kmath/structures/ComplexND.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/ComplexND.kt @@ -11,7 +11,7 @@ fun main() { val n = 1000 val realField = NDField.real(dim, dim) - val complexField = NDField.complex(dim, dim) + val complexField: ComplexNDField = NDField.complex(dim, dim) val realTime = measureTimeMillis { realField { diff --git a/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt b/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt index 778d811fd..b5130c92b 100644 --- a/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt +++ b/examples/src/main/kotlin/kscience/kmath/structures/NDField.kt @@ -33,7 +33,7 @@ fun main() { measureAndPrint("Automatic field addition") { autoField { var res: NDBuffer = one - repeat(n) { res += number(1.0) } + repeat(n) { res += 1.0 } } } @@ -45,14 +45,14 @@ fun main() { measureAndPrint("Specialized addition") { specializedField { var res: NDBuffer = one - repeat(n) { res += one } + repeat(n) { res += 1.0 } } } measureAndPrint("Nd4j specialized addition") { nd4jField { var res = one - repeat(n) { res += one } + repeat(n) { res += 1.0 } } } @@ -73,7 +73,7 @@ fun main() { genericField { var res: NDBuffer = one repeat(n) { - res += one // couldn't avoid using `one` due to resolution ambiguity } + res += 1.0 // couldn't avoid using `one` due to resolution ambiguity } } } } diff --git a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt index 80b164a7c..eadbc85ee 100644 --- a/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt +++ b/kmath-ast/src/commonMain/kotlin/kscience/kmath/ast/MstAlgebra.kt @@ -1,5 +1,6 @@ package kscience.kmath.ast +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.* /** @@ -25,8 +26,11 @@ public object MstSpace : Space, NumericAlgebra { public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value) public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value) public override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(SpaceOperations.PLUS_OPERATION)(a, b) - public override operator fun MST.unaryPlus(): MST.Unary = unaryOperationFunction(SpaceOperations.PLUS_OPERATION)(this) - public override operator fun MST.unaryMinus(): MST.Unary = unaryOperationFunction(SpaceOperations.MINUS_OPERATION)(this) + public override operator fun MST.unaryPlus(): MST.Unary = + unaryOperationFunction(SpaceOperations.PLUS_OPERATION)(this) + + public override operator fun MST.unaryMinus(): MST.Unary = + unaryOperationFunction(SpaceOperations.MINUS_OPERATION)(this) public override operator fun MST.minus(b: MST): MST.Binary = binaryOperationFunction(SpaceOperations.MINUS_OPERATION)(this, b) @@ -44,7 +48,8 @@ public object MstSpace : Space, NumericAlgebra { /** * [Ring] over [MST] nodes. */ -public object MstRing : Ring, NumericAlgebra { +@OptIn(UnstableKMathAPI::class) +public object MstRing : Ring, RingWithNumbers { public override val zero: MST.Numeric get() = MstSpace.zero @@ -54,7 +59,9 @@ public object MstRing : Ring, NumericAlgebra { public override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value) public override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b) public override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k) - public override fun multiply(a: MST, b: MST): MST.Binary = binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) + public override fun multiply(a: MST, b: MST): MST.Binary = + binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) + public override operator fun MST.unaryPlus(): MST.Unary = MstSpace { +this@unaryPlus } public override operator fun MST.unaryMinus(): MST.Unary = MstSpace { -this@unaryMinus } public override operator fun MST.minus(b: MST): MST.Binary = MstSpace { this@minus - b } @@ -69,7 +76,8 @@ public object MstRing : Ring, NumericAlgebra { /** * [Field] over [MST] nodes. */ -public object MstField : Field { +@OptIn(UnstableKMathAPI::class) +public object MstField : Field, RingWithNumbers { public override val zero: MST.Numeric get() = MstRing.zero @@ -81,7 +89,9 @@ public object MstField : Field { public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k) public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) - public override fun divide(a: MST, b: MST): MST.Binary = binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) + public override fun divide(a: MST, b: MST): MST.Binary = + binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) + public override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus } public override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus } public override operator fun MST.minus(b: MST): MST.Binary = MstRing { this@minus - b } @@ -89,13 +99,14 @@ public object MstField : Field { public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = MstRing.binaryOperationFunction(operation) - public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = MstRing.unaryOperationFunction(operation) + public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + MstRing.unaryOperationFunction(operation) } /** * [ExtendedField] over [MST] nodes. */ -public object MstExtendedField : ExtendedField { +public object MstExtendedField : ExtendedField, NumericAlgebra { public override val zero: MST.Numeric get() = MstField.zero @@ -103,6 +114,7 @@ public object MstExtendedField : ExtendedField { get() = MstField.one public override fun symbol(value: String): MST.Symbolic = MstField.symbol(value) + public override fun number(value: Number): MST.Numeric = MstRing.number(value) public override fun sin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg) public override fun cos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg) public override fun tan(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.TAN_OPERATION)(arg) @@ -132,5 +144,6 @@ public object MstExtendedField : ExtendedField { public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = MstField.binaryOperationFunction(operation) - public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = MstField.unaryOperationFunction(operation) + public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = + MstField.unaryOperationFunction(operation) } diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt index 345babe8b..2912ddc4c 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/expressions/DerivativeStructureExpression.kt @@ -1,7 +1,9 @@ package kscience.kmath.commons.expressions import kscience.kmath.expressions.* +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.ExtendedField +import kscience.kmath.operations.RingWithNumbers import org.apache.commons.math3.analysis.differentiation.DerivativeStructure /** @@ -10,15 +12,18 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure * @property order The derivation order. * @property bindings The map of bindings values. All bindings are considered free parameters */ +@OptIn(UnstableKMathAPI::class) public class DerivativeStructureField( public val order: Int, bindings: Map, -) : ExtendedField, ExpressionAlgebra { +) : ExtendedField, ExpressionAlgebra, RingWithNumbers { public val numberOfVariables: Int = bindings.size public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) } + override fun number(value: Number): DerivativeStructure = const(value.toDouble()) + /** * A class that implements both [DerivativeStructure] and a [Symbol] */ diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index afbaf16e1..880a4e34c 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -7,8 +7,9 @@ import kscience.kmath.operations.* * * @param algebra The algebra to provide for Expressions built. */ -public abstract class FunctionalExpressionAlgebra>(public val algebra: A) : - ExpressionAlgebra> { +public abstract class FunctionalExpressionAlgebra>( + public val algebra: A, +) : ExpressionAlgebra> { /** * Builds an Expression of constant expression which does not depend on arguments. */ @@ -42,8 +43,9 @@ public abstract class FunctionalExpressionAlgebra>(public val /** * A context class for [Expression] construction for [Space] algebras. */ -public open class FunctionalExpressionSpace>(algebra: A) : - FunctionalExpressionAlgebra(algebra), Space> { +public open class FunctionalExpressionSpace>( + algebra: A, +) : FunctionalExpressionAlgebra(algebra), Space> { public override val zero: Expression get() = const(algebra.zero) /** @@ -71,8 +73,9 @@ public open class FunctionalExpressionSpace>(algebra: A) : super.binaryOperationFunction(operation) } -public open class FunctionalExpressionRing(algebra: A) : FunctionalExpressionSpace(algebra), - Ring> where A : Ring, A : NumericAlgebra { +public open class FunctionalExpressionRing>( + algebra: A, +) : FunctionalExpressionSpace(algebra), Ring> { public override val one: Expression get() = const(algebra.one) @@ -92,9 +95,8 @@ public open class FunctionalExpressionRing(algebra: A) : FunctionalExpress super.binaryOperationFunction(operation) } -public open class FunctionalExpressionField(algebra: A) : - FunctionalExpressionRing(algebra), Field> - where A : Field, A : NumericAlgebra { +public open class FunctionalExpressionField>(algebra: A) : + FunctionalExpressionRing(algebra), Field> { /** * Builds an Expression of division an expression by another one. */ @@ -111,9 +113,12 @@ public open class FunctionalExpressionField(algebra: A) : super.binaryOperationFunction(operation) } -public open class FunctionalExpressionExtendedField(algebra: A) : - FunctionalExpressionField(algebra), - ExtendedField> where A : ExtendedField, A : NumericAlgebra { +public open class FunctionalExpressionExtendedField>( + algebra: A, +) : FunctionalExpressionField(algebra), ExtendedField> { + + override fun number(value: Number): Expression = const(algebra.number(value)) + public override fun sin(arg: Expression): Expression = unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg) @@ -135,7 +140,8 @@ public open class FunctionalExpressionExtendedField(algebra: A) : public override fun exp(arg: Expression): Expression = unaryOperationFunction(ExponentialOperations.EXP_OPERATION)(arg) - public override fun ln(arg: Expression): Expression = unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg) + public override fun ln(arg: Expression): Expression = + unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg) public override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = super.unaryOperationFunction(operation) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt index e8a894d23..0621e82bd 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -1,6 +1,7 @@ package kscience.kmath.expressions import kscience.kmath.linear.Point +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.* import kscience.kmath.structures.asBuffer import kotlin.contracts.InvocationKind @@ -79,10 +80,11 @@ public fun > F.simpleAutoDiff( /** * Represents field in context of which functions can be derived. */ +@OptIn(UnstableKMathAPI::class) public open class SimpleAutoDiffField>( public val context: F, bindings: Map, -) : Field>, ExpressionAlgebra> { +) : Field>, ExpressionAlgebra>, RingWithNumbers> { public override val zero: AutoDiffValue get() = const(context.zero) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt index 767b58eba..1f93309a6 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt @@ -1,5 +1,7 @@ package kscience.kmath.linear +import kscience.kmath.structures.Matrix + /** * A marker interface representing some properties of matrices or additional transformations of them. Features are used * to optimize matrix operations performance in some cases or retrieve the APIs. @@ -30,7 +32,7 @@ public interface InverseMatrixFeature : MatrixFeature { /** * The inverse matrix of the matrix that owns this feature. */ - public val inverse: FeaturedMatrix + public val inverse: Matrix } /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt index 24b417860..2bafd377e 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Algebra.kt @@ -88,85 +88,6 @@ public interface Algebra { public fun binaryOperation(operation: String, left: T, right: T): T = binaryOperationFunction(operation)(left, right) } -/** - * An algebraic structure where elements can have numeric representation. - * - * @param T the type of element of this structure. - */ -public interface NumericAlgebra : Algebra { - /** - * Wraps a number to [T] object. - * - * @param value the number to wrap. - * @return an object. - */ - public fun number(value: Number): T - - /** - * Dynamically dispatches a binary operation with the certain name with numeric first argument. - * - * This function must follow two properties: - * - * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. - * 2. This function is symmetric with the other [leftSideNumberOperation] overload: - * i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b)`. - * - * @param operation the name of operation. - * @return an operation. - */ - public fun leftSideNumberOperationFunction(operation: String): (left: Number, right: T) -> T = - { l, r -> binaryOperationFunction(operation)(number(l), r) } - - /** - * Dynamically invokes a binary operation with the certain name with numeric first argument. - * - * This function must follow two properties: - * - * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. - * 2. This function is symmetric with second [leftSideNumberOperation] overload: - * i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`. - * - * @param operation the name of operation. - * @param left the first argument of operation. - * @param right the second argument of operation. - * @return a result of operation. - */ - public fun leftSideNumberOperation(operation: String, left: Number, right: T): T = - leftSideNumberOperationFunction(operation)(left, right) - - /** - * Dynamically dispatches a binary operation with the certain name with numeric first argument. - * - * This function must follow two properties: - * - * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. - * 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload: - * i.e. `rightSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`. - * - * @param operation the name of operation. - * @return an operation. - */ - public fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T = - { l, r -> binaryOperationFunction(operation)(l, number(r)) } - - /** - * Dynamically invokes a binary operation with the certain name with numeric second argument. - * - * This function must follow two properties: - * - * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. - * 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload: - * i.e. `rightSideNumberOperationFunction(a)(b, c) == rightSideNumberOperation(a, b, c)`. - * - * @param operation the name of operation. - * @param left the first argument of operation. - * @param right the second argument of operation. - * @return a result of operation. - */ - public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = - rightSideNumberOperationFunction(operation)(left, right) -} - /** * Call a block with an [Algebra] as receiver. */ @@ -341,47 +262,11 @@ public interface RingOperations : SpaceOperations { * * @param T the type of element of this ring. */ -public interface Ring : Space, RingOperations, NumericAlgebra { +public interface Ring : Space, RingOperations { /** * neutral operation for multiplication */ public val one: T - - public override fun number(value: Number): T = one * value.toDouble() - - /** - * Addition of element and scalar. - * - * @receiver the addend. - * @param b the augend. - */ - public operator fun T.plus(b: Number): T = this + number(b) - - /** - * Addition of scalar and element. - * - * @receiver the addend. - * @param b the augend. - */ - public operator fun Number.plus(b: T): T = b + this - - /** - * Subtraction of element from number. - * - * @receiver the minuend. - * @param b the subtrahend. - * @receiver the difference. - */ - public operator fun T.minus(b: Number): T = this - number(b) - - /** - * Subtraction of number from element. - * - * @receiver the minuend. - * @param b the subtrahend. - * @receiver the difference. - */ - public operator fun Number.minus(b: T): T = -b + this } /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt index 20f289596..0be72e80c 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt @@ -1,5 +1,6 @@ package kscience.kmath.operations +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.BigInt.Companion.BASE import kscience.kmath.operations.BigInt.Companion.BASE_SIZE import kscience.kmath.structures.* @@ -16,7 +17,8 @@ public typealias TBase = ULong * * @author Robert Drynkin (https://github.com/robdrynkin) and Peter Klimai (https://github.com/pklimai) */ -public object BigIntField : Field { +@OptIn(UnstableKMathAPI::class) +public object BigIntField : Field, RingWithNumbers { override val zero: BigInt = BigInt.ZERO override val one: BigInt = BigInt.ONE diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt index 703931c7c..5695e6696 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt @@ -41,7 +41,7 @@ private val PI_DIV_2 = Complex(PI / 2, 0) /** * A field of [Complex]. */ -public object ComplexField : ExtendedField, Norm { +public object ComplexField : ExtendedField, Norm, RingWithNumbers { override val zero: Complex = 0.0.toComplex() override val one: Complex = 1.0.toComplex() @@ -156,7 +156,7 @@ public object ComplexField : ExtendedField, Norm { override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg) - override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) + override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) } /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumericAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumericAlgebra.kt new file mode 100644 index 000000000..26f93fae8 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumericAlgebra.kt @@ -0,0 +1,125 @@ +package kscience.kmath.operations + +import kscience.kmath.misc.UnstableKMathAPI + +/** + * An algebraic structure where elements can have numeric representation. + * + * @param T the type of element of this structure. + */ +public interface NumericAlgebra : Algebra { + /** + * Wraps a number to [T] object. + * + * @param value the number to wrap. + * @return an object. + */ + public fun number(value: Number): T + + /** + * Dynamically dispatches a binary operation with the certain name with numeric first argument. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with the other [leftSideNumberOperation] overload: + * i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b)`. + * + * @param operation the name of operation. + * @return an operation. + */ + public fun leftSideNumberOperationFunction(operation: String): (left: Number, right: T) -> T = + { l, r -> binaryOperationFunction(operation)(number(l), r) } + + /** + * Dynamically invokes a binary operation with the certain name with numeric first argument. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with second [leftSideNumberOperation] overload: + * i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`. + * + * @param operation the name of operation. + * @param left the first argument of operation. + * @param right the second argument of operation. + * @return a result of operation. + */ + public fun leftSideNumberOperation(operation: String, left: Number, right: T): T = + leftSideNumberOperationFunction(operation)(left, right) + + /** + * Dynamically dispatches a binary operation with the certain name with numeric first argument. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload: + * i.e. `rightSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`. + * + * @param operation the name of operation. + * @return an operation. + */ + public fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T = + { l, r -> binaryOperationFunction(operation)(l, number(r)) } + + /** + * Dynamically invokes a binary operation with the certain name with numeric second argument. + * + * This function must follow two properties: + * + * 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException]. + * 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload: + * i.e. `rightSideNumberOperationFunction(a)(b, c) == rightSideNumberOperation(a, b, c)`. + * + * @param operation the name of operation. + * @param left the first argument of operation. + * @param right the second argument of operation. + * @return a result of operation. + */ + public fun rightSideNumberOperation(operation: String, left: T, right: Number): T = + rightSideNumberOperationFunction(operation)(left, right) +} + +/** + * A combination of [NumericAlgebra] and [Ring] that adds intrinsic simple operations on numbers like `T+1` + * TODO to be removed and replaced by extensions after multiple receivers are there + */ +@UnstableKMathAPI +public interface RingWithNumbers: Ring, NumericAlgebra{ + public override fun number(value: Number): T = one * value + + /** + * Addition of element and scalar. + * + * @receiver the addend. + * @param b the augend. + */ + public operator fun T.plus(b: Number): T = this + number(b) + + /** + * Addition of scalar and element. + * + * @receiver the addend. + * @param b the augend. + */ + public operator fun Number.plus(b: T): T = b + this + + /** + * Subtraction of element from number. + * + * @receiver the minuend. + * @param b the subtrahend. + * @receiver the difference. + */ + public operator fun T.minus(b: Number): T = this - number(b) + + /** + * Subtraction of number from element. + * + * @receiver the minuend. + * @param b the subtrahend. + * @receiver the difference. + */ + public operator fun Number.minus(b: T): T = -b + this +} \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt similarity index 97% rename from kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt index a7b73d5e0..de3818aa6 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt @@ -37,7 +37,7 @@ public interface ExtendedFieldOperations : /** * Advanced Number-like field that implements basic operations. */ -public interface ExtendedField : ExtendedFieldOperations, Field { +public interface ExtendedField : ExtendedFieldOperations, Field, NumericAlgebra { public override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2 public override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2 public override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) @@ -80,6 +80,8 @@ public object RealField : ExtendedField, Norm { public override val one: Double get() = 1.0 + override fun number(value: Number): Double = value.toDouble() + public override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double = when (operation) { PowerOperations.POW_OPERATION -> ::power @@ -131,10 +133,13 @@ public object FloatField : ExtendedField, Norm { public override val one: Float get() = 1.0f - public override fun binaryOperationFunction(operation: String): (left: Float, right: Float) -> Float = when (operation) { - PowerOperations.POW_OPERATION -> ::power - else -> super.binaryOperationFunction(operation) - } + override fun number(value: Number): Float = value.toFloat() + + public override fun binaryOperationFunction(operation: String): (left: Float, right: Float) -> Float = + when (operation) { + PowerOperations.POW_OPERATION -> ::power + else -> super.binaryOperationFunction(operation) + } public override inline fun add(a: Float, b: Float): Float = a + b public override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat() diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ComplexNDField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ComplexNDField.kt index f1f1074e5..6de69cabe 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ComplexNDField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/ComplexNDField.kt @@ -1,9 +1,7 @@ package kscience.kmath.structures -import kscience.kmath.operations.Complex -import kscience.kmath.operations.ComplexField -import kscience.kmath.operations.FieldElement -import kscience.kmath.operations.complex +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.operations.* import kotlin.contracts.InvocationKind import kotlin.contracts.contract @@ -12,15 +10,22 @@ public typealias ComplexNDElement = BufferedNDFieldElement, - ExtendedNDField> { + ExtendedNDField>, + RingWithNumbers>{ override val strides: Strides = DefaultStrides(shape) override val elementContext: ComplexField get() = ComplexField override val zero: ComplexNDElement by lazy { produce { zero } } override val one: ComplexNDElement by lazy { produce { one } } + override fun number(value: Number): NDBuffer { + val c = value.toComplex() + return produce { c } + } + public inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer = Buffer.complex(size) { initializer(it) } @@ -29,7 +34,7 @@ public class ComplexNDField(override val shape: IntArray) : */ override fun map( arg: NDBuffer, - transform: ComplexField.(Complex) -> Complex + transform: ComplexField.(Complex) -> Complex, ): ComplexNDElement { check(arg) val array = buildBuffer(arg.strides.linearSize) { offset -> ComplexField.transform(arg.buffer[offset]) } @@ -43,7 +48,7 @@ public class ComplexNDField(override val shape: IntArray) : override fun mapIndexed( arg: NDBuffer, - transform: ComplexField.(index: IntArray, Complex) -> Complex + transform: ComplexField.(index: IntArray, Complex) -> Complex, ): ComplexNDElement { check(arg) @@ -60,7 +65,7 @@ public class ComplexNDField(override val shape: IntArray) : override fun combine( a: NDBuffer, b: NDBuffer, - transform: ComplexField.(Complex, Complex) -> Complex + transform: ComplexField.(Complex, Complex) -> Complex, ): ComplexNDElement { check(a, b) @@ -141,7 +146,7 @@ public fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = Comple public fun NDElement.Companion.complex( vararg shape: Int, - initializer: ComplexField.(IntArray) -> Complex + initializer: ComplexField.(IntArray) -> Complex, ): ComplexNDElement = NDField.complex(*shape).produce(initializer) /** diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBufferField.kt index f7b2ee31e..3f4d15c4d 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealBufferField.kt @@ -150,6 +150,8 @@ public class RealBufferField(public val size: Int) : ExtendedField by lazy { RealBuffer(size) { 0.0 } } public override val one: Buffer by lazy { RealBuffer(size) { 1.0 } } + override fun number(value: Number): Buffer = RealBuffer(size) { value.toDouble() } + public override fun add(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.add(a, b) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt index ed28fb9f2..3eb1dc4ca 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/RealNDField.kt @@ -1,13 +1,17 @@ package kscience.kmath.structures +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.FieldElement import kscience.kmath.operations.RealField +import kscience.kmath.operations.RingWithNumbers public typealias RealNDElement = BufferedNDFieldElement +@OptIn(UnstableKMathAPI::class) public class RealNDField(override val shape: IntArray) : BufferedNDField, - ExtendedNDField> { + ExtendedNDField>, + RingWithNumbers>{ override val strides: Strides = DefaultStrides(shape) @@ -15,7 +19,12 @@ public class RealNDField(override val shape: IntArray) : override val zero: RealNDElement by lazy { produce { zero } } override val one: RealNDElement by lazy { produce { one } } - public inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer = + override fun number(value: Number): NDBuffer { + val d = value.toDouble() + return produce { d } + } + + private inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer = RealBuffer(DoubleArray(size) { initializer(it) }) /** @@ -59,7 +68,8 @@ public class RealNDField(override val shape: IntArray) : check(a, b) return BufferedNDFieldElement( this, - buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) + buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) } + ) } override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> = diff --git a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt index b763ec7de..1129a8a36 100644 --- a/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt +++ b/kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt @@ -1,14 +1,13 @@ package kscience.kmath.structures import kscience.kmath.operations.internal.FieldVerifier -import kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals internal class NDFieldTest { @Test fun verify() { - (NDField.real(12, 32)) { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) } + NDField.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) } } @Test diff --git a/kmath-core/src/jvmMain/kotlin/kscience/kmath/operations/BigNumbers.kt b/kmath-core/src/jvmMain/kotlin/kscience/kmath/operations/BigNumbers.kt index 2f0978237..9bd6a9fc4 100644 --- a/kmath-core/src/jvmMain/kotlin/kscience/kmath/operations/BigNumbers.kt +++ b/kmath-core/src/jvmMain/kotlin/kscience/kmath/operations/BigNumbers.kt @@ -7,7 +7,7 @@ import java.math.MathContext /** * A field over [BigInteger]. */ -public object JBigIntegerField : Field { +public object JBigIntegerField : Field, NumericAlgebra { public override val zero: BigInteger get() = BigInteger.ZERO @@ -28,9 +28,9 @@ public object JBigIntegerField : Field { * * @property mathContext the [MathContext] to use. */ -public abstract class JBigDecimalFieldBase internal constructor(public val mathContext: MathContext = MathContext.DECIMAL64) : - Field, - PowerOperations { +public abstract class JBigDecimalFieldBase internal constructor( + private val mathContext: MathContext = MathContext.DECIMAL64, +) : Field, PowerOperations, NumericAlgebra { public override val zero: BigDecimal get() = BigDecimal.ZERO diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt index 5b7d0a01b..a7d571b58 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt @@ -12,8 +12,10 @@ import org.ejml.simple.SimpleMatrix * @property origin the underlying [SimpleMatrix]. * @author Iaroslav Postovalov */ -public class EjmlMatrix(public val origin: SimpleMatrix, features: Set = emptySet()) : - FeaturedMatrix { +public class EjmlMatrix( + public val origin: SimpleMatrix, + features: Set = emptySet() +) : FeaturedMatrix { public override val rowNum: Int get() = origin.numRows() @@ -88,11 +90,7 @@ public class EjmlMatrix(public val origin: SimpleMatrix, features: Set ?: return false) } - public override fun hashCode(): Int { - var result = origin.hashCode() - result = 31 * result + features.hashCode() - return result - } + public override fun hashCode(): Int = origin.hashCode() - public override fun toString(): String = "EjmlMatrix(origin=$origin, features=$features)" + public override fun toString(): String = "EjmlMatrix($origin)" } diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt index f8791b72c..c8789a4a7 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt @@ -1,7 +1,9 @@ package kscience.kmath.ejml +import kscience.kmath.linear.InverseMatrixFeature import kscience.kmath.linear.MatrixContext import kscience.kmath.linear.Point +import kscience.kmath.linear.getFeature import kscience.kmath.structures.Matrix import org.ejml.simple.SimpleMatrix @@ -77,3 +79,7 @@ public fun EjmlMatrixContext.solve(a: Matrix, b: Matrix): EjmlMa */ public fun EjmlMatrixContext.solve(a: Matrix, b: Point): EjmlVector = EjmlVector(a.toEjml().origin.solve(b.toEjml().origin)) + +public fun EjmlMatrix.inverted(): EjmlMatrix = getFeature>()!!.inverse as EjmlMatrix + +public fun EjmlMatrixContext.inverse(matrix: Matrix): Matrix = matrix.toEjml().inverted() \ No newline at end of file diff --git a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt index a8c874fc3..db2a44861 100644 --- a/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/kscience.kmath.nd4j/Nd4jArrayAlgebra.kt @@ -1,5 +1,6 @@ package kscience.kmath.nd4j +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.* import kscience.kmath.structures.NDAlgebra import kscience.kmath.structures.NDField @@ -35,7 +36,7 @@ public interface Nd4jArrayAlgebra : NDAlgebra> public override fun mapIndexed( arg: Nd4jArrayStructure, - transform: C.(index: IntArray, T) -> T + transform: C.(index: IntArray, T) -> T, ): Nd4jArrayStructure { check(arg) val new = Nd4j.create(*shape).wrap() @@ -46,7 +47,7 @@ public interface Nd4jArrayAlgebra : NDAlgebra> public override fun combine( a: Nd4jArrayStructure, b: Nd4jArrayStructure, - transform: C.(T, T) -> T + transform: C.(T, T) -> T, ): Nd4jArrayStructure { check(a, b) val new = Nd4j.create(*shape).wrap() @@ -61,8 +62,8 @@ public interface Nd4jArrayAlgebra : NDAlgebra> * @param T the type of the element contained in ND structure. * @param S the type of space of structure elements. */ -public interface Nd4jArraySpace : NDSpace>, - Nd4jArrayAlgebra where S : Space { +public interface Nd4jArraySpace> : NDSpace>, Nd4jArrayAlgebra { + public override val zero: Nd4jArrayStructure get() = Nd4j.zeros(*shape).wrap() @@ -103,7 +104,9 @@ public interface Nd4jArraySpace : NDSpace>, * @param T the type of the element contained in ND structure. * @param R the type of ring of structure elements. */ -public interface Nd4jArrayRing : NDRing>, Nd4jArraySpace where R : Ring { +@OptIn(UnstableKMathAPI::class) +public interface Nd4jArrayRing> : NDRing>, Nd4jArraySpace { + public override val one: Nd4jArrayStructure get() = Nd4j.ones(*shape).wrap() @@ -111,21 +114,21 @@ public interface Nd4jArrayRing : NDRing>, Nd4j check(a, b) return a.ndArray.mul(b.ndArray).wrap() } - - public override operator fun Nd4jArrayStructure.minus(b: Number): Nd4jArrayStructure { - check(this) - return ndArray.sub(b).wrap() - } - - public override operator fun Nd4jArrayStructure.plus(b: Number): Nd4jArrayStructure { - check(this) - return ndArray.add(b).wrap() - } - - public override operator fun Number.minus(b: Nd4jArrayStructure): Nd4jArrayStructure { - check(b) - return b.ndArray.rsub(this).wrap() - } +// +// public override operator fun Nd4jArrayStructure.minus(b: Number): Nd4jArrayStructure { +// check(this) +// return ndArray.sub(b).wrap() +// } +// +// public override operator fun Nd4jArrayStructure.plus(b: Number): Nd4jArrayStructure { +// check(this) +// return ndArray.add(b).wrap() +// } +// +// public override operator fun Number.minus(b: Nd4jArrayStructure): Nd4jArrayStructure { +// check(b) +// return b.ndArray.rsub(this).wrap() +// } public companion object { private val intNd4jArrayRingCache: ThreadLocal> = @@ -165,7 +168,8 @@ public interface Nd4jArrayRing : NDRing>, Nd4j * @param N the type of ND structure. * @param F the type field of structure elements. */ -public interface Nd4jArrayField : NDField>, Nd4jArrayRing where F : Field { +public interface Nd4jArrayField> : NDField>, Nd4jArrayRing { + public override fun divide(a: Nd4jArrayStructure, b: Nd4jArrayStructure): Nd4jArrayStructure { check(a, b) return a.ndArray.div(b.ndArray).wrap() diff --git a/settings.gradle.kts b/settings.gradle.kts index 025e4a3c6..a1ea40148 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -8,8 +8,8 @@ pluginManagement { maven("https://dl.bintray.com/kotlin/kotlinx") } - val toolsVersion = "0.7.2-dev-2" - val kotlinVersion = "1.4.21" + val toolsVersion = "0.7.3-1.4.30-RC" + val kotlinVersion = "1.4.30-RC" plugins { id("kotlinx.benchmark") version "0.2.0-dev-20" From 4c256a9f14f4db251cb2797f8ddcfb1b88ab8dac Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Tue, 19 Jan 2021 19:32:13 +0300 Subject: [PATCH 8/8] Features refactoring. --- CHANGELOG.md | 1 + .../kscience/kmath/commons/linear/CMMatrix.kt | 50 +++++----- .../kscience/kmath/linear/BufferMatrix.kt | 26 +++--- ...UPDecomposition.kt => LupDecomposition.kt} | 58 +++++++----- .../kscience/kmath/linear/MatrixBuilder.kt | 15 ++- .../kscience/kmath/linear/MatrixContext.kt | 2 - .../kscience/kmath/linear/MatrixFeatures.kt | 26 +++--- .../{FeaturedMatrix.kt => MatrixWrapper.kt} | 78 +++++++++------- .../kmath/linear/RealMatrixContext.kt | 20 +--- .../kscience/kmath/linear/VirtualMatrix.kt | 31 +------ .../kscience/kmath/operations/numbers.kt | 16 +++- .../kscience/kmath/structures/NDStructure.kt | 11 +++ .../kscience/kmath/structures/Structure2D.kt | 9 +- .../kscience/kmath/dimensions/Wrappers.kt | 9 +- .../kotlin/kscience/kmath/ejml/EjmlMatrix.kt | 91 ++++++++----------- .../kscience/kmath/ejml/EjmlMatrixContext.kt | 21 +++-- .../kscience/kmath/ejml/EjmlMatrixTest.kt | 5 +- .../kotlin/kscience/kmath/real/RealMatrix.kt | 8 +- .../kaceince/kmath/real/RealMatrixTest.kt | 3 +- 19 files changed, 236 insertions(+), 244 deletions(-) rename kmath-core/src/commonMain/kotlin/kscience/kmath/linear/{LUPDecomposition.kt => LupDecomposition.kt} (78%) rename kmath-core/src/commonMain/kotlin/kscience/kmath/linear/{FeaturedMatrix.kt => MatrixWrapper.kt} (50%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 840733d92..27d4dd1aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ - EjmlMatrix context is an object - Matrix LUP `inverse` renamed to `inverseWithLUP` - `NumericAlgebra` moved outside of regular algebra chain (`Ring` no longer implements it). +- Features moved to NDStructure and became transparent. ### Deprecated diff --git a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt index 49888f8d6..48b6e0ef1 100644 --- a/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt +++ b/kmath-commons/src/main/kotlin/kscience/kmath/commons/linear/CMMatrix.kt @@ -1,42 +1,28 @@ package kscience.kmath.commons.linear -import kscience.kmath.linear.* +import kscience.kmath.linear.DiagonalFeature +import kscience.kmath.linear.MatrixContext +import kscience.kmath.linear.MatrixWrapper +import kscience.kmath.linear.Point +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Matrix -import kscience.kmath.structures.NDStructure import org.apache.commons.math3.linear.* +import kotlin.reflect.KClass +import kotlin.reflect.cast -public class CMMatrix(public val origin: RealMatrix, features: Set? = null) : FeaturedMatrix { +public inline class CMMatrix(public val origin: RealMatrix) : Matrix { public override val rowNum: Int get() = origin.rowDimension public override val colNum: Int get() = origin.columnDimension - public override val features: Set = features ?: sequence { - if (origin is DiagonalMatrix) yield(DiagonalFeature) - }.toHashSet() - - public override fun suggestFeature(vararg features: MatrixFeature): CMMatrix = - CMMatrix(origin, this.features + features) + @UnstableKMathAPI + override fun getFeature(type: KClass): T? = when (type) { + DiagonalFeature::class -> if (origin is DiagonalMatrix) DiagonalFeature else null + else -> null + }?.let { type.cast(it) } public override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j) - - public override fun equals(other: Any?): Boolean { - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) - } - - public override fun hashCode(): Int { - var result = origin.hashCode() - result = 31 * result + features.hashCode() - return result - } } -//TODO move inside context -public fun Matrix.toCM(): CMMatrix = if (this is CMMatrix) { - this -} else { - //TODO add feature analysis - val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } } - CMMatrix(Array2DRowRealMatrix(array)) -} public fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this) @@ -61,6 +47,16 @@ public object CMMatrixContext : MatrixContext { return CMMatrix(Array2DRowRealMatrix(array)) } + public fun Matrix.toCM(): CMMatrix = when { + this is CMMatrix -> this + this is MatrixWrapper && matrix is CMMatrix -> matrix as CMMatrix + else -> { + //TODO add feature analysis + val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } } + CMMatrix(Array2DRowRealMatrix(array)) + } + } + public override fun Matrix.dot(other: Matrix): CMMatrix = CMMatrix(toCM().origin.multiply(other.toCM().origin)) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt index 402161207..80460baca 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/BufferMatrix.kt @@ -1,10 +1,7 @@ package kscience.kmath.linear import kscience.kmath.operations.Ring -import kscience.kmath.structures.Buffer -import kscience.kmath.structures.BufferFactory -import kscience.kmath.structures.NDStructure -import kscience.kmath.structures.asSequence +import kscience.kmath.structures.* /** * Basic implementation of Matrix space based on [NDStructure] @@ -27,8 +24,7 @@ public class BufferMatrix( public override val rowNum: Int, public override val colNum: Int, public val buffer: Buffer, - public override val features: Set = emptySet(), -) : FeaturedMatrix { +) : Matrix { init { require(buffer.size == rowNum * colNum) { "Dimension mismatch for matrix structure" } @@ -36,9 +32,6 @@ public class BufferMatrix( override val shape: IntArray get() = intArrayOf(rowNum, colNum) - public override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix = - BufferMatrix(rowNum, colNum, buffer, this.features + features) - public override operator fun get(index: IntArray): T = get(index[0], index[1]) public override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j] @@ -50,23 +43,26 @@ public class BufferMatrix( if (this === other) return true return when (other) { - is NDStructure<*> -> return NDStructure.equals(this, other) + is NDStructure<*> -> NDStructure.equals(this, other) else -> false } } - public override fun hashCode(): Int { - var result = buffer.hashCode() - result = 31 * result + features.hashCode() + override fun hashCode(): Int { + var result = rowNum + result = 31 * result + colNum + result = 31 * result + buffer.hashCode() return result } public override fun toString(): String { return if (rowNum <= 5 && colNum <= 5) - "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" + + "Matrix(rowsNum = $rowNum, colNum = $colNum)\n" + rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer -> buffer.asSequence().joinToString(separator = "\t") { it.toString() } } - else "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)" + else "Matrix(rowsNum = $rowNum, colNum = $colNum)" } + + } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt similarity index 78% rename from kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt index 75acaf8a1..f4f998da2 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LUPDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/LupDecomposition.kt @@ -1,13 +1,14 @@ package kscience.kmath.linear +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.* import kscience.kmath.structures.* /** * Common implementation of [LupDecompositionFeature]. */ -public class LUPDecomposition( - public val context: MatrixContext>, +public class LupDecomposition( + public val context: MatrixContext>, public val elementContext: Field, public val lu: Matrix, public val pivot: IntArray, @@ -18,13 +19,13 @@ public class LUPDecomposition( * * L is a lower-triangular matrix with [Ring.one] in diagonal */ - override val l: FeaturedMatrix = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j -> + override val l: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> when { j < i -> lu[i, j] j == i -> elementContext.one else -> elementContext.zero } - } + } + LFeature /** @@ -32,9 +33,9 @@ public class LUPDecomposition( * * U is an upper-triangular matrix including the diagonal */ - override val u: FeaturedMatrix = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j -> + override val u: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> if (j >= i) lu[i, j] else elementContext.zero - } + } + UFeature /** * Returns the P rows permutation matrix. @@ -42,7 +43,7 @@ public class LUPDecomposition( * P is a sparse matrix with exactly one element set to [Ring.one] in * each row and each column, all other elements being set to [Ring.zero]. */ - override val p: FeaturedMatrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> + override val p: Matrix = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j -> if (j == pivot[i]) elementContext.one else elementContext.zero } @@ -63,12 +64,12 @@ internal fun , F : Field> GenericMatrixContext.abs /** * Create a lup decomposition of generic matrix. */ -public fun > MatrixContext>.lup( +public fun > MatrixContext>.lup( factory: MutableBufferFactory, elementContext: Field, matrix: Matrix, checkSingular: (T) -> Boolean, -): LUPDecomposition { +): LupDecomposition { require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" } val m = matrix.colNum val pivot = IntArray(matrix.rowNum) @@ -137,23 +138,23 @@ public fun > MatrixContext>.lup( for (row in col + 1 until m) lu[row, col] /= luDiag } - return LUPDecomposition(this@lup, elementContext, lu.collect(), pivot, even) + return LupDecomposition(this@lup, elementContext, lu.collect(), pivot, even) } } } -public inline fun , F : Field> GenericMatrixContext>.lup( +public inline fun , F : Field> GenericMatrixContext>.lup( matrix: Matrix, noinline checkSingular: (T) -> Boolean, -): LUPDecomposition = lup(MutableBuffer.Companion::auto, elementContext, matrix, checkSingular) +): LupDecomposition = lup(MutableBuffer.Companion::auto, elementContext, matrix, checkSingular) -public fun MatrixContext>.lup(matrix: Matrix): LUPDecomposition = +public fun MatrixContext>.lup(matrix: Matrix): LupDecomposition = lup(Buffer.Companion::real, RealField, matrix) { it < 1e-11 } -public fun LUPDecomposition.solveWithLUP( +public fun LupDecomposition.solveWithLUP( factory: MutableBufferFactory, - matrix: Matrix -): FeaturedMatrix { + matrix: Matrix, +): Matrix { require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" } BufferAccessor2D(matrix.rowNum, matrix.colNum, factory).run { @@ -198,25 +199,40 @@ public fun LUPDecomposition.solveWithLUP( } } -public inline fun LUPDecomposition.solveWithLUP(matrix: Matrix): Matrix = +public inline fun LupDecomposition.solveWithLUP(matrix: Matrix): Matrix = solveWithLUP(MutableBuffer.Companion::auto, matrix) /** * Solve a linear equation **a*x = b** using LUP decomposition */ -public inline fun , F : Field> GenericMatrixContext>.solveWithLUP( +@OptIn(UnstableKMathAPI::class) +public inline fun , F : Field> GenericMatrixContext>.solveWithLUP( a: Matrix, b: Matrix, noinline bufferFactory: MutableBufferFactory = MutableBuffer.Companion::auto, noinline checkSingular: (T) -> Boolean, -): FeaturedMatrix { +): Matrix { // Use existing decomposition if it is provided by matrix val decomposition = a.getFeature() ?: lup(bufferFactory, elementContext, a, checkSingular) return decomposition.solveWithLUP(bufferFactory, b) } -public inline fun , F : Field> GenericMatrixContext>.inverseWithLUP( +public inline fun , F : Field> GenericMatrixContext>.inverseWithLUP( matrix: Matrix, noinline bufferFactory: MutableBufferFactory = MutableBuffer.Companion::auto, noinline checkSingular: (T) -> Boolean, -): FeaturedMatrix = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) +): Matrix = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular) + + +public fun RealMatrixContext.solveWithLUP(a: Matrix, b: Matrix): Matrix { + // Use existing decomposition if it is provided by matrix + val bufferFactory: MutableBufferFactory = MutableBuffer.Companion::real + val decomposition: LupDecomposition = a.getFeature() ?: lup(bufferFactory, RealField, a) { it < 1e-11 } + return decomposition.solveWithLUP(bufferFactory, b) +} + +/** + * Inverses a square matrix using LUP decomposition. Non square matrix will throw a error. + */ +public fun RealMatrixContext.inverseWithLUP(matrix: Matrix): Matrix = + solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum)) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixBuilder.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixBuilder.kt index 91c1ec824..c0c209248 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixBuilder.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixBuilder.kt @@ -1,12 +1,9 @@ package kscience.kmath.linear -import kscience.kmath.structures.Buffer -import kscience.kmath.structures.BufferFactory -import kscience.kmath.structures.Structure2D -import kscience.kmath.structures.asBuffer +import kscience.kmath.structures.* public class MatrixBuilder(public val rows: Int, public val columns: Int) { - public operator fun invoke(vararg elements: T): FeaturedMatrix { + public operator fun invoke(vararg elements: T): Matrix { require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" } val buffer = elements.asBuffer() return BufferMatrix(rows, columns, buffer) @@ -17,7 +14,7 @@ public class MatrixBuilder(public val rows: Int, public val columns: Int) { public fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns) -public fun Structure2D.Companion.row(vararg values: T): FeaturedMatrix { +public fun Structure2D.Companion.row(vararg values: T): Matrix { val buffer = values.asBuffer() return BufferMatrix(1, values.size, buffer) } @@ -26,12 +23,12 @@ public inline fun Structure2D.Companion.row( size: Int, factory: BufferFactory = Buffer.Companion::auto, noinline builder: (Int) -> T -): FeaturedMatrix { +): Matrix { val buffer = factory(size, builder) return BufferMatrix(1, size, buffer) } -public fun Structure2D.Companion.column(vararg values: T): FeaturedMatrix { +public fun Structure2D.Companion.column(vararg values: T): Matrix { val buffer = values.asBuffer() return BufferMatrix(values.size, 1, buffer) } @@ -40,7 +37,7 @@ public inline fun Structure2D.Companion.column( size: Int, factory: BufferFactory = Buffer.Companion::auto, noinline builder: (Int) -> T -): FeaturedMatrix { +): Matrix { val buffer = factory(size, builder) return BufferMatrix(size, 1, buffer) } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt index 8c28a240f..59a41f840 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt @@ -133,8 +133,6 @@ public interface GenericMatrixContext, out M : Matrix> : public override fun multiply(a: Matrix, k: Number): M = produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } } - public operator fun Number.times(matrix: FeaturedMatrix): M = multiply(matrix, this) - public override operator fun Matrix.times(value: T): M = produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt index 1f93309a6..e61feec6c 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixFeatures.kt @@ -11,17 +11,19 @@ public interface MatrixFeature /** * Matrices with this feature are considered to have only diagonal non-null elements. */ -public object DiagonalFeature : MatrixFeature +public interface DiagonalFeature : MatrixFeature{ + public companion object: DiagonalFeature +} /** * Matrices with this feature have all zero elements. */ -public object ZeroFeature : MatrixFeature +public object ZeroFeature : DiagonalFeature /** * Matrices with this feature have unit elements on diagonal and zero elements in all other places. */ -public object UnitFeature : MatrixFeature +public object UnitFeature : DiagonalFeature /** * Matrices with this feature can be inverted: [inverse] = `a`-1 where `a` is the owning matrix. @@ -76,17 +78,17 @@ public interface LupDecompositionFeature : MatrixFeature { /** * The lower triangular matrix in this decomposition. It may have [LFeature]. */ - public val l: FeaturedMatrix + public val l: Matrix /** * The upper triangular matrix in this decomposition. It may have [UFeature]. */ - public val u: FeaturedMatrix + public val u: Matrix /** * The permutation matrix in this decomposition. */ - public val p: FeaturedMatrix + public val p: Matrix } /** @@ -104,12 +106,12 @@ public interface QRDecompositionFeature : MatrixFeature { /** * The orthogonal matrix in this decomposition. It may have [OrthogonalFeature]. */ - public val q: FeaturedMatrix + public val q: Matrix /** * The upper triangular matrix in this decomposition. It may have [UFeature]. */ - public val r: FeaturedMatrix + public val r: Matrix } /** @@ -122,7 +124,7 @@ public interface CholeskyDecompositionFeature : MatrixFeature { /** * The triangular matrix in this decomposition. It may have either [UFeature] or [LFeature]. */ - public val l: FeaturedMatrix + public val l: Matrix } /** @@ -135,17 +137,17 @@ public interface SingularValueDecompositionFeature : MatrixFeature { /** * The matrix in this decomposition. It is unitary, and it consists from left singular vectors. */ - public val u: FeaturedMatrix + public val u: Matrix /** * The matrix in this decomposition. Its main diagonal elements are singular values. */ - public val s: FeaturedMatrix + public val s: Matrix /** * The matrix in this decomposition. It is unitary, and it consists from right singular vectors. */ - public val v: FeaturedMatrix + public val v: Matrix /** * The buffer of singular values of this SVD. diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt similarity index 50% rename from kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt rename to kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt index 119f5d844..bbe9c1195 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/FeaturedMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixWrapper.kt @@ -1,35 +1,57 @@ package kscience.kmath.linear +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.operations.Ring import kscience.kmath.structures.Matrix import kscience.kmath.structures.Structure2D import kscience.kmath.structures.asBuffer +import kscience.kmath.structures.getFeature import kotlin.math.sqrt +import kotlin.reflect.KClass +import kotlin.reflect.safeCast /** * A [Matrix] that holds [MatrixFeature] objects. * * @param T the type of items. */ -public interface FeaturedMatrix : Matrix { - public override val shape: IntArray get() = intArrayOf(rowNum, colNum) +public class MatrixWrapper( + public val matrix: Matrix, + public val features: Set, +) : Matrix by matrix { /** - * The set of features this matrix possesses. + * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria */ - public val features: Set + @UnstableKMathAPI + override fun getFeature(type: KClass): T? = type.safeCast(features.find { type.isInstance(it) }) - /** - * Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure. - * - * The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to - * add only those features that are valid. - */ - public fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix - - public companion object + override fun equals(other: Any?): Boolean = matrix == other + override fun hashCode(): Int = matrix.hashCode() + override fun toString(): String { + return "MatrixWrapper(matrix=$matrix, features=$features)" + } } +/** + * Add a single feature to a [Matrix] + */ +public operator fun Matrix.plus(newFeature: MatrixFeature): MatrixWrapper = if (this is MatrixWrapper) { + MatrixWrapper(matrix, features + newFeature) +} else { + MatrixWrapper(this, setOf(newFeature)) +} + +/** + * Add a collection of features to a [Matrix] + */ +public operator fun Matrix.plus(newFeatures: Collection): MatrixWrapper = + if (this is MatrixWrapper) { + MatrixWrapper(matrix, features + newFeatures) + } else { + MatrixWrapper(this, newFeatures.toSet()) + } + public inline fun Structure2D.Companion.real( rows: Int, columns: Int, @@ -39,51 +61,37 @@ public inline fun Structure2D.Companion.real( /** * Build a square matrix from given elements. */ -public fun Structure2D.Companion.square(vararg elements: T): FeaturedMatrix { +public fun Structure2D.Companion.square(vararg elements: T): Matrix { val size: Int = sqrt(elements.size.toDouble()).toInt() require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" } val buffer = elements.asBuffer() return BufferMatrix(size, size, buffer) } -public val Matrix<*>.features: Set get() = (this as? FeaturedMatrix)?.features ?: emptySet() - -/** - * Check if matrix has the given feature class - */ -public inline fun Matrix<*>.hasFeature(): Boolean = - features.find { it is T } != null - -/** - * Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria - */ -public inline fun Matrix<*>.getFeature(): T? = - features.filterIsInstance().firstOrNull() - /** * Diagonal matrix of ones. The matrix is virtual no actual matrix is created */ -public fun > GenericMatrixContext.one(rows: Int, columns: Int): FeaturedMatrix = - VirtualMatrix(rows, columns, DiagonalFeature) { i, j -> +public fun > GenericMatrixContext.one(rows: Int, columns: Int): Matrix = + VirtualMatrix(rows, columns) { i, j -> if (i == j) elementContext.one else elementContext.zero - } + } + UnitFeature /** * A virtual matrix of zeroes */ -public fun > GenericMatrixContext.zero(rows: Int, columns: Int): FeaturedMatrix = - VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } +public fun > GenericMatrixContext.zero(rows: Int, columns: Int): Matrix = + VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } + ZeroFeature public class TransposedFeature(public val original: Matrix) : MatrixFeature /** * Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A` */ +@OptIn(UnstableKMathAPI::class) public fun Matrix.transpose(): Matrix { return getFeature>()?.original ?: VirtualMatrix( colNum, rowNum, - setOf(TransposedFeature(this)) - ) { i, j -> get(j, i) } + ) { i, j -> get(j, i) } + TransposedFeature(this) } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt index 90e251c3a..8e197672f 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/RealMatrixContext.kt @@ -1,9 +1,6 @@ package kscience.kmath.linear -import kscience.kmath.operations.RealField import kscience.kmath.structures.Matrix -import kscience.kmath.structures.MutableBuffer -import kscience.kmath.structures.MutableBufferFactory import kscience.kmath.structures.RealBuffer @Suppress("OVERRIDE_BY_INLINE") @@ -22,9 +19,9 @@ public object RealMatrixContext : MatrixContext> { produce(rowNum, colNum) { i, j -> get(i, j) } } - public fun one(rows: Int, columns: Int): FeaturedMatrix = VirtualMatrix(rows, columns, DiagonalFeature) { i, j -> + public fun one(rows: Int, columns: Int): Matrix = VirtualMatrix(rows, columns) { i, j -> if (i == j) 1.0 else 0.0 - } + } + DiagonalFeature public override infix fun Matrix.dot(other: Matrix): BufferMatrix { require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" } @@ -69,16 +66,3 @@ public object RealMatrixContext : MatrixContext> { * Partially optimized real-valued matrix */ public val MatrixContext.Companion.real: RealMatrixContext get() = RealMatrixContext - -public fun RealMatrixContext.solveWithLUP(a: Matrix, b: Matrix): FeaturedMatrix { - // Use existing decomposition if it is provided by matrix - val bufferFactory: MutableBufferFactory = MutableBuffer.Companion::real - val decomposition = a.getFeature() ?: lup(bufferFactory, RealField, a) { it < 1e-11 } - return decomposition.solveWithLUP(bufferFactory, b) -} - -/** - * Inverses a square matrix using LUP decomposition. Non square matrix will throw a error. - */ -public fun RealMatrixContext.inverseWithLUP(matrix: Matrix): FeaturedMatrix = - solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum)) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VirtualMatrix.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VirtualMatrix.kt index e0a1d0026..0269a64d1 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VirtualMatrix.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/VirtualMatrix.kt @@ -5,31 +5,16 @@ import kscience.kmath.structures.Matrix public class VirtualMatrix( override val rowNum: Int, override val colNum: Int, - override val features: Set = emptySet(), public val generator: (i: Int, j: Int) -> T -) : FeaturedMatrix { - public constructor( - rowNum: Int, - colNum: Int, - vararg features: MatrixFeature, - generator: (i: Int, j: Int) -> T - ) : this( - rowNum, - colNum, - setOf(*features), - generator - ) +) : Matrix { override val shape: IntArray get() = intArrayOf(rowNum, colNum) override operator fun get(i: Int, j: Int): T = generator(i, j) - override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix = - VirtualMatrix(rowNum, colNum, this.features + features, generator) - override fun equals(other: Any?): Boolean { if (this === other) return true - if (other !is FeaturedMatrix<*>) return false + if (other !is Matrix<*>) return false if (rowNum != other.rowNum) return false if (colNum != other.colNum) return false @@ -40,21 +25,9 @@ public class VirtualMatrix( override fun hashCode(): Int { var result = rowNum result = 31 * result + colNum - result = 31 * result + features.hashCode() result = 31 * result + generator.hashCode() return result } - public companion object { - /** - * Wrap a matrix adding additional features to it - */ - public fun wrap(matrix: Matrix, vararg features: MatrixFeature): FeaturedMatrix { - return if (matrix is VirtualMatrix) - VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator) - else - VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features) { i, j -> matrix[i, j] } - } - } } diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt index de3818aa6..0440d74e8 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/numbers.kt @@ -179,13 +179,15 @@ public object FloatField : ExtendedField, Norm { * A field for [Int] without boxing. Does not produce corresponding ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -public object IntRing : Ring, Norm { +public object IntRing : Ring, Norm, NumericAlgebra { public override val zero: Int get() = 0 public override val one: Int get() = 1 + override fun number(value: Number): Int = value.toInt() + public override inline fun add(a: Int, b: Int): Int = a + b public override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a @@ -203,13 +205,15 @@ public object IntRing : Ring, Norm { * A field for [Short] without boxing. Does not produce appropriate ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -public object ShortRing : Ring, Norm { +public object ShortRing : Ring, Norm, NumericAlgebra { public override val zero: Short get() = 0 public override val one: Short get() = 1 + override fun number(value: Number): Short = value.toShort() + public override inline fun add(a: Short, b: Short): Short = (a + b).toShort() public override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort() @@ -227,13 +231,15 @@ public object ShortRing : Ring, Norm { * A field for [Byte] without boxing. Does not produce appropriate ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -public object ByteRing : Ring, Norm { +public object ByteRing : Ring, Norm, NumericAlgebra { public override val zero: Byte get() = 0 public override val one: Byte get() = 1 + override fun number(value: Number): Byte = value.toByte() + public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() public override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte() @@ -251,13 +257,15 @@ public object ByteRing : Ring, Norm { * A field for [Double] without boxing. Does not produce appropriate ring element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") -public object LongRing : Ring, Norm { +public object LongRing : Ring, Norm, NumericAlgebra { public override val zero: Long get() = 0L public override val one: Long get() = 1L + override fun number(value: Number): Long = value.toLong() + public override inline fun add(a: Long, b: Long): Long = a + b public override inline fun multiply(a: Long, k: Number): Long = a * k.toLong() diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt index 5c5d28882..64e723581 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/NDStructure.kt @@ -1,5 +1,6 @@ package kscience.kmath.structures +import kscience.kmath.misc.UnstableKMathAPI import kotlin.jvm.JvmName import kotlin.native.concurrent.ThreadLocal import kotlin.reflect.KClass @@ -42,6 +43,13 @@ public interface NDStructure { public override fun equals(other: Any?): Boolean public override fun hashCode(): Int + /** + * Feature is additional property or hint that does not directly affect the structure, but could in some cases help + * optimize operations and performance. If the feature is not present, null is defined. + */ + @UnstableKMathAPI + public fun getFeature(type: KClass): T? = null + public companion object { /** * Indicates whether some [NDStructure] is equal to another one. @@ -121,6 +129,9 @@ public interface NDStructure { */ public operator fun NDStructure.get(vararg index: Int): T = get(index) +@UnstableKMathAPI +public inline fun NDStructure<*>.getFeature(): T? = getFeature(T::class) + /** * Represents mutable [NDStructure]. */ diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt index bac7d3389..d20e9e53b 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/structures/Structure2D.kt @@ -9,12 +9,14 @@ public interface Structure2D : NDStructure { /** * The number of rows in this structure. */ - public val rowNum: Int get() = shape[0] + public val rowNum: Int /** * The number of columns in this structure. */ - public val colNum: Int get() = shape[1] + public val colNum: Int + + public override val shape: IntArray get() = intArrayOf(rowNum, colNum) /** * The buffer of rows of this structure. It gets elements from the structure dynamically. @@ -56,6 +58,9 @@ public interface Structure2D : NDStructure { private inline class Structure2DWrapper(val structure: NDStructure) : Structure2D { override val shape: IntArray get() = structure.shape + override val rowNum: Int get() = shape[0] + override val colNum: Int get() = shape[1] + override operator fun get(i: Int, j: Int): T = structure[i, j] override fun elements(): Sequence> = structure.elements() diff --git a/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt b/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt index 0422d11b2..0244eae7f 100644 --- a/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt +++ b/kmath-dimensions/src/commonMain/kotlin/kscience/kmath/dimensions/Wrappers.kt @@ -40,6 +40,8 @@ public inline class DMatrixWrapper( private val structure: Structure2D, ) : DMatrix { override val shape: IntArray get() = structure.shape + override val rowNum: Int get() = shape[0] + override val colNum: Int get() = shape[1] override operator fun get(i: Int, j: Int): T = structure[i, j] } @@ -147,6 +149,7 @@ public inline fun DMatrixContext.one(): DMatrix< if (i == j) 1.0 else 0.0 } -public inline fun DMatrixContext.zero(): DMatrix = produce { _, _ -> - 0.0 -} \ No newline at end of file +public inline fun DMatrixContext.zero(): DMatrix = + produce { _, _ -> + 0.0 + } \ No newline at end of file diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt index a7d571b58..1c2ded447 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrix.kt @@ -1,10 +1,13 @@ package kscience.kmath.ejml import kscience.kmath.linear.* -import kscience.kmath.structures.NDStructure +import kscience.kmath.misc.UnstableKMathAPI +import kscience.kmath.structures.Matrix import kscience.kmath.structures.RealBuffer import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.simple.SimpleMatrix +import kotlin.reflect.KClass +import kotlin.reflect.cast /** * Represents featured matrix over EJML [SimpleMatrix]. @@ -12,85 +15,65 @@ import org.ejml.simple.SimpleMatrix * @property origin the underlying [SimpleMatrix]. * @author Iaroslav Postovalov */ -public class EjmlMatrix( +public inline class EjmlMatrix( public val origin: SimpleMatrix, - features: Set = emptySet() -) : FeaturedMatrix { - public override val rowNum: Int - get() = origin.numRows() +) : Matrix { + public override val rowNum: Int get() = origin.numRows() - public override val colNum: Int - get() = origin.numCols() + public override val colNum: Int get() = origin.numCols() - public override val shape: IntArray by lazy { intArrayOf(rowNum, colNum) } - - public override val features: Set = hashSetOf( - object : InverseMatrixFeature { - override val inverse: FeaturedMatrix by lazy { EjmlMatrix(origin.invert()) } - }, - - object : DeterminantFeature { + @UnstableKMathAPI + override fun getFeature(type: KClass): T? = when (type) { + InverseMatrixFeature::class -> object : InverseMatrixFeature { + override val inverse: Matrix by lazy { EjmlMatrix(origin.invert()) } + } + DeterminantFeature::class -> object : DeterminantFeature { override val determinant: Double by lazy(origin::determinant) - }, - - object : SingularValueDecompositionFeature { + } + SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature { private val svd by lazy { DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false) .apply { decompose(origin.ddrm.copy()) } } - override val u: FeaturedMatrix by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) } - override val s: FeaturedMatrix by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) } - override val v: FeaturedMatrix by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) } + override val u: Matrix by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) } + override val s: Matrix by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) } + override val v: Matrix by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) } override val singularValues: Point by lazy { RealBuffer(svd.singularValues) } - }, - - object : QRDecompositionFeature { + } + QRDecompositionFeature::class -> object : QRDecompositionFeature { private val qr by lazy { DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) } } - override val q: FeaturedMatrix by lazy { EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) } - override val r: FeaturedMatrix by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) } - }, - - object : CholeskyDecompositionFeature { - override val l: FeaturedMatrix by lazy { + override val q: Matrix by lazy { EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) } + override val r: Matrix by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) } + } + CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature { + override val l: Matrix by lazy { val cholesky = DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) } - EjmlMatrix(SimpleMatrix(cholesky.getT(null)), setOf(LFeature)) + EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature } - }, - - object : LupDecompositionFeature { + } + LupDecompositionFeature::class -> object : LupDecompositionFeature { private val lup by lazy { DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) } } - override val l: FeaturedMatrix by lazy { - EjmlMatrix(SimpleMatrix(lup.getLower(null)), setOf(LFeature)) + override val l: Matrix by lazy { + EjmlMatrix(SimpleMatrix(lup.getLower(null))) + LFeature } - override val u: FeaturedMatrix by lazy { - EjmlMatrix(SimpleMatrix(lup.getUpper(null)), setOf(UFeature)) + override val u: Matrix by lazy { + EjmlMatrix(SimpleMatrix(lup.getUpper(null))) + UFeature } - override val p: FeaturedMatrix by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) } - }, - ) union features - - public override fun suggestFeature(vararg features: MatrixFeature): EjmlMatrix = - EjmlMatrix(origin, this.features + features) + override val p: Matrix by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) } + } + else -> null + }?.let{type.cast(it)} public override operator fun get(i: Int, j: Int): Double = origin[i, j] - - public override fun equals(other: Any?): Boolean { - if (other is EjmlMatrix) return origin.isIdentical(other.origin, 0.0) - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) - } - - public override fun hashCode(): Int = origin.hashCode() - - public override fun toString(): String = "EjmlMatrix($origin)" } diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt index c8789a4a7..7198bbd0d 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt @@ -2,23 +2,29 @@ package kscience.kmath.ejml import kscience.kmath.linear.InverseMatrixFeature import kscience.kmath.linear.MatrixContext +import kscience.kmath.linear.MatrixWrapper import kscience.kmath.linear.Point -import kscience.kmath.linear.getFeature +import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Matrix +import kscience.kmath.structures.getFeature import org.ejml.simple.SimpleMatrix -/** - * Converts this matrix to EJML one. - */ -public fun Matrix.toEjml(): EjmlMatrix = - if (this is EjmlMatrix) this else EjmlMatrixContext.produce(rowNum, colNum) { i, j -> get(i, j) } - /** * Represents context of basic operations operating with [EjmlMatrix]. * * @author Iaroslav Postovalov */ public object EjmlMatrixContext : MatrixContext { + + /** + * Converts this matrix to EJML one. + */ + public fun Matrix.toEjml(): EjmlMatrix = when { + this is EjmlMatrix -> this + this is MatrixWrapper && matrix is EjmlMatrix -> matrix as EjmlMatrix + else -> produce(rowNum, colNum) { i, j -> get(i, j) } + } + /** * Converts this vector to EJML one. */ @@ -80,6 +86,7 @@ public fun EjmlMatrixContext.solve(a: Matrix, b: Matrix): EjmlMa public fun EjmlMatrixContext.solve(a: Matrix, b: Point): EjmlVector = EjmlVector(a.toEjml().origin.solve(b.toEjml().origin)) +@OptIn(UnstableKMathAPI::class) public fun EjmlMatrix.inverted(): EjmlMatrix = getFeature>()!!.inverse as EjmlMatrix public fun EjmlMatrixContext.inverse(matrix: Matrix): Matrix = matrix.toEjml().inverted() \ No newline at end of file diff --git a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt index 70b82a3cb..30d146779 100644 --- a/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt +++ b/kmath-ejml/src/test/kotlin/kscience/kmath/ejml/EjmlMatrixTest.kt @@ -3,7 +3,8 @@ package kscience.kmath.ejml import kscience.kmath.linear.DeterminantFeature import kscience.kmath.linear.LupDecompositionFeature import kscience.kmath.linear.MatrixFeature -import kscience.kmath.linear.getFeature +import kscience.kmath.linear.plus +import kscience.kmath.structures.getFeature import org.ejml.dense.row.factory.DecompositionFactory_DDRM import org.ejml.simple.SimpleMatrix import kotlin.random.Random @@ -58,7 +59,7 @@ internal class EjmlMatrixTest { @Test fun suggestFeature() { - assertNotNull(EjmlMatrix(randomMatrix).suggestFeature(SomeFeature).getFeature()) + assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature()) } @Test diff --git a/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt index 772abfbed..274030aff 100644 --- a/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt +++ b/kmath-for-real/src/commonMain/kotlin/kscience/kmath/real/RealMatrix.kt @@ -1,8 +1,12 @@ package kscience.kmath.real -import kscience.kmath.linear.* +import kscience.kmath.linear.MatrixContext +import kscience.kmath.linear.VirtualMatrix +import kscience.kmath.linear.inverseWithLUP +import kscience.kmath.linear.real import kscience.kmath.misc.UnstableKMathAPI import kscience.kmath.structures.Buffer +import kscience.kmath.structures.Matrix import kscience.kmath.structures.RealBuffer import kscience.kmath.structures.asIterable import kotlin.math.pow @@ -19,7 +23,7 @@ import kotlin.math.pow * Functions that help create a real (Double) matrix */ -public typealias RealMatrix = FeaturedMatrix +public typealias RealMatrix = Matrix public fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix = MatrixContext.real.produce(rowNum, colNum, initializer) diff --git a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt index 5c33b76a9..a89f99b3c 100644 --- a/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt +++ b/kmath-for-real/src/commonTest/kotlin/kaceince/kmath/real/RealMatrixTest.kt @@ -1,6 +1,5 @@ package kaceince.kmath.real -import kscience.kmath.linear.VirtualMatrix import kscience.kmath.linear.build import kscience.kmath.real.* import kscience.kmath.structures.Matrix @@ -42,7 +41,7 @@ internal class RealMatrixTest { 1.0, 0.0, 0.0, 0.0, 1.0, 2.0 ) - assertEquals(VirtualMatrix.wrap(matrix2), matrix1.repeatStackVertical(3)) + assertEquals(matrix2, matrix1.repeatStackVertical(3)) } @Test