From 5c82a5e1fa2dced56af94a12cd5ca8f4e0a0ff51 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Sat, 18 Nov 2023 22:29:59 +0300 Subject: [PATCH] 0.4 WIP --- CHANGELOG.md | 3 +- .../space/kscience/attributes/Attribute.kt | 11 --- .../kscience/attributes/AttributeContainer.kt | 10 +- .../space/kscience/attributes/Attributes.kt | 42 ++++---- .../kscience/attributes/AttributesBuilder.kt | 33 ++++--- .../attributes/PolymorphicAttribute.kt | 31 ++++++ .../kmath/ejml/codegen/ejmlCodegen.kt | 12 ++- .../space/kscience/kmath/fit/chiSquared.kt | 4 +- .../kotlin/space/kscience/kmath/fit/qowFit.kt | 4 +- .../space/kscience/kmath/ast/TypedMst.kt | 36 +++++-- .../kscience/kmath/ast/evaluateConstants.kt | 9 +- .../space/kscience/kmath/estree/estree.kt | 4 +- .../kmath/estree/internal/ESTreeBuilder.kt | 15 ++- .../kotlin/space/kscience/kmath/asm/asm.kt | 2 +- .../commons/expressions/CmDsExpression.kt | 19 +++- .../integration/CMGaussRuleIntegrator.kt | 4 +- .../kmath/commons/integration/CMIntegrator.kt | 17 ++-- .../kscience/kmath/commons/linear/CMMatrix.kt | 38 ++++---- .../kmath/commons/optimization/CMOptimizer.kt | 35 +++---- .../commons/optimization/OptimizeTest.kt | 4 +- .../kscience/kmath/expressions/DSAlgebra.kt | 11 ++- .../kscience/kmath/expressions/Expression.kt | 29 +++++- .../expressions/ExpressionWithDefault.kt | 8 ++ .../FunctionalExpressionAlgebra.kt | 19 ++-- .../space/kscience/kmath/expressions/MST.kt | 2 +- .../kmath/expressions/SimpleAutoDiff.kt | 5 +- .../kscience/kmath/linear/LinearSolver.kt | 2 +- .../kscience/kmath/linear/LinearSpace.kt | 21 ++-- .../kscience/kmath/linear/LupDecomposition.kt | 41 ++++---- .../kscience/kmath/linear/MatrixWrapper.kt | 10 +- .../kscience/kmath/linear/matrixAttributes.kt | 44 +++++---- .../kscience/kmath/ejml/EjmlLinearSpace.kt | 2 +- .../space/kscience/kmath/ejml/_generated.kt | 39 +++++++- .../kmath/integration/GaussIntegrator.kt | 4 +- .../kscience/kmath/integration/Integrand.kt | 2 +- .../integration/MultivariateIntegrand.kt | 4 +- .../kmath/integration/UnivariateIntegrand.kt | 10 +- kmath-jafama/build.gradle.kts | 2 +- .../kscience/kmath/jafama/KMathJafama.kt | 12 ++- .../kmath/multik/MultikDoubleAlgebra.kt | 2 +- .../kmath/multik/MultikFloatAlgebra.kt | 2 +- .../kscience/kmath/multik/MultikIntAlgebra.kt | 2 +- .../kmath/multik/MultikLongAlgebra.kt | 2 +- .../kmath/multik/MultikShortAlgebra.kt | 2 +- .../kscience/kmath/multik/MultikTensor.kt | 20 ++++ .../kmath/multik/MultikTensorAlgebra.kt | 14 +-- .../optimization/FunctionOptimization.kt | 53 ++++++---- .../kmath/optimization/OptimizationBuilder.kt | 96 ------------------- .../kmath/optimization/OptimizationProblem.kt | 61 +++++------- .../kmath/optimization/QowOptimizer.kt | 27 +++--- .../kscience/kmath/optimization/XYFit.kt | 64 ++++++++----- .../kmath/optimization/logLikelihood.kt | 19 +++- .../kmath/tensors/core/DoubleTensor.kt | 4 + .../kscience/kmath/tensors/core/IntTensor.kt | 6 ++ 54 files changed, 541 insertions(+), 433 deletions(-) create mode 100644 attributes-kt/src/commonMain/kotlin/space/kscience/attributes/PolymorphicAttribute.kt delete mode 100644 kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationBuilder.kt diff --git a/CHANGELOG.md b/CHANGELOG.md index ef5e0490c..58df6f1bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ ## Unreleased ### Added -- Explicit `SafeType` for algebras and buffers. +- Reification. Explicit `SafeType` for algebras and buffers. - Integer division algebras. - Float32 geometries. - New Attributes-kt module that could be used as stand-alone. It declares. type-safe attributes containers. @@ -16,6 +16,7 @@ - kmath-geometry is split into `euclidean2d` and `euclidean3d` - Features replaced with Attributes. - Transposed refactored. +- Kmath-memory is moved on top of core. ### Deprecated diff --git a/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/Attribute.kt b/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/Attribute.kt index a507cd698..dda7c6ed5 100644 --- a/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/Attribute.kt +++ b/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/Attribute.kt @@ -24,14 +24,3 @@ public interface AttributeWithDefault : Attribute { */ public interface SetAttribute : Attribute> -/** - * An attribute that has a type parameter for value - * @param type parameter-type - */ -public abstract class PolymorphicAttribute(public val type: SafeType) : Attribute { - override fun equals(other: Any?): Boolean = other != null && - (this::class == other::class) && - (other as? PolymorphicAttribute<*>)?.type == this.type - - override fun hashCode(): Int = this::class.hashCode() + type.hashCode() -} diff --git a/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/AttributeContainer.kt b/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/AttributeContainer.kt index 69b050649..19e5c224a 100644 --- a/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/AttributeContainer.kt +++ b/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/AttributeContainer.kt @@ -6,8 +6,14 @@ package space.kscience.attributes /** - * A container for attributes. [attributes] could be made mutable by implementation + * A container for [Attributes] */ public interface AttributeContainer { public val attributes: Attributes -} \ No newline at end of file +} + +/** + * A scope, where attribute keys could be resolved + */ +public interface AttributeScope + diff --git a/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/Attributes.kt b/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/Attributes.kt index b50436dd2..6c8dabc50 100644 --- a/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/Attributes.kt +++ b/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/Attributes.kt @@ -7,21 +7,27 @@ package space.kscience.attributes import kotlin.jvm.JvmInline -@JvmInline -public value class Attributes internal constructor(public val content: Map, Any?>) { +/** + * A set of attributes. The implementation must guarantee that [content] keys correspond to its value types. + */ +public interface Attributes { + public val content: Map, Any?> public val keys: Set> get() = content.keys @Suppress("UNCHECKED_CAST") public operator fun get(attribute: Attribute): T? = content[attribute] as? T - override fun toString(): String = "Attributes(value=${content.entries})" - public companion object { - public val EMPTY: Attributes = Attributes(emptyMap()) + public val EMPTY: Attributes = AttributesImpl(emptyMap()) } } +@JvmInline +internal value class AttributesImpl(override val content: Map, Any?>) : Attributes { + override fun toString(): String = "Attributes(value=${content.entries})" +} + public fun Attributes.isEmpty(): Boolean = content.isEmpty() /** @@ -33,19 +39,19 @@ public fun Attributes.getOrDefault(attribute: AttributeWithDefault): T = * Check if there is an attribute that matches given key by type and adheres to [predicate]. */ @Suppress("UNCHECKED_CAST") -public inline fun > Attributes.any(predicate: (value: T) -> Boolean): Boolean = +public inline fun > Attributes.hasAny(predicate: (value: T) -> Boolean): Boolean = content.any { (mapKey, mapValue) -> mapKey is A && predicate(mapValue as T) } /** * Check if there is an attribute of given type (subtypes included) */ -public inline fun > Attributes.any(): Boolean = +public inline fun > Attributes.hasAny(): Boolean = content.any { (mapKey, _) -> mapKey is A } /** * Check if [Attributes] contains a flag. Multiple keys that are instances of a flag could be present */ -public inline fun Attributes.has(): Boolean = +public inline fun Attributes.hasFlag(): Boolean = content.keys.any { it is A } /** @@ -54,7 +60,7 @@ public inline fun Attributes.has(): Boolean = public fun > Attributes.withAttribute( attribute: A, attrValue: T, -): Attributes = Attributes(content + (attribute to attrValue)) +): Attributes = AttributesImpl(content + (attribute to attrValue)) public fun > Attributes.withAttribute(attribute: A): Attributes = withAttribute(attribute, Unit) @@ -62,7 +68,7 @@ public fun > Attributes.withAttribute(attribute: A): Attribu /** * Create a new [Attributes] by modifying the current one */ -public fun Attributes.modify(block: AttributesBuilder.() -> Unit): Attributes = Attributes { +public fun Attributes.modify(block: AttributesBuilder.() -> Unit): Attributes = Attributes { from(this@modify) block() } @@ -70,7 +76,7 @@ public fun Attributes.modify(block: AttributesBuilder.() -> Unit): Attributes = /** * Create new [Attributes] by removing [attribute] key */ -public fun Attributes.withoutAttribute(attribute: Attribute<*>): Attributes = Attributes(content.minus(attribute)) +public fun Attributes.withoutAttribute(attribute: Attribute<*>): Attributes = AttributesImpl(content.minus(attribute)) /** * Add an element to a [SetAttribute] @@ -80,7 +86,7 @@ public fun > Attributes.withAttributeElement( attrValue: T, ): Attributes { val currentSet: Set = get(attribute) ?: emptySet() - return Attributes( + return AttributesImpl( content + (attribute to (currentSet + attrValue)) ) } @@ -93,9 +99,7 @@ public fun > Attributes.withoutAttributeElement( attrValue: T, ): Attributes { val currentSet: Set = get(attribute) ?: emptySet() - return Attributes( - content + (attribute to (currentSet - attrValue)) - ) + return AttributesImpl(content + (attribute to (currentSet - attrValue))) } /** @@ -104,13 +108,13 @@ public fun > Attributes.withoutAttributeElement( public fun > Attributes( attribute: A, attrValue: T, -): Attributes = Attributes(mapOf(attribute to attrValue)) +): Attributes = AttributesImpl(mapOf(attribute to attrValue)) /** * Create Attributes with a single [Unit] valued attribute */ public fun > Attributes( - attribute: A -): Attributes = Attributes(mapOf(attribute to Unit)) + attribute: A, +): Attributes = AttributesImpl(mapOf(attribute to Unit)) -public operator fun Attributes.plus(other: Attributes): Attributes = Attributes(content + other.content) \ No newline at end of file +public operator fun Attributes.plus(other: Attributes): Attributes = AttributesImpl(content + other.content) \ No newline at end of file diff --git a/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/AttributesBuilder.kt b/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/AttributesBuilder.kt index 6d74b90c1..0acf4e004 100644 --- a/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/AttributesBuilder.kt +++ b/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/AttributesBuilder.kt @@ -10,19 +10,24 @@ package space.kscience.attributes * * @param O type marker of an owner object, for which these attributes are made */ -public class TypedAttributesBuilder internal constructor(private val map: MutableMap, Any?>) { +public class AttributesBuilder internal constructor( + private val map: MutableMap, Any?>, +) : Attributes { public constructor() : this(mutableMapOf()) - @Suppress("UNCHECKED_CAST") - public operator fun get(attribute: Attribute): T? = map[attribute] as? T + override val content: Map, Any?> get() = map + + public operator fun set(attribute: Attribute, value: T?) { + if (value == null) { + map.remove(attribute) + } else { + map[attribute] = value + } + } public operator fun Attribute.invoke(value: V?) { - if (value == null) { - map.remove(this) - } else { - map[this] = value - } + set(this, value) } public fun from(attributes: Attributes) { @@ -46,14 +51,8 @@ public class TypedAttributesBuilder internal constructor(private val map: map[this] = currentSet - attrValue } - public fun build(): Attributes = Attributes(map) + public fun build(): Attributes = AttributesImpl(map) } -public typealias AttributesBuilder = TypedAttributesBuilder - -public fun AttributesBuilder( - attributes: Attributes, -): AttributesBuilder = AttributesBuilder(attributes.content.toMutableMap()) - -public inline fun Attributes(builder: AttributesBuilder.() -> Unit): Attributes = - AttributesBuilder().apply(builder).build() \ No newline at end of file +public inline fun Attributes(builder: AttributesBuilder.() -> Unit): Attributes = + AttributesBuilder().apply(builder).build() \ No newline at end of file diff --git a/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/PolymorphicAttribute.kt b/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/PolymorphicAttribute.kt new file mode 100644 index 000000000..b61d4c477 --- /dev/null +++ b/attributes-kt/src/commonMain/kotlin/space/kscience/attributes/PolymorphicAttribute.kt @@ -0,0 +1,31 @@ +/* + * Copyright 2018-2023 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.attributes + +/** + * An attribute that has a type parameter for value + * @param type parameter-type + */ +public abstract class PolymorphicAttribute(public val type: SafeType) : Attribute { + override fun equals(other: Any?): Boolean = other != null && + (this::class == other::class) && + (other as? PolymorphicAttribute<*>)?.type == this.type + + override fun hashCode(): Int = this::class.hashCode() + type.hashCode() +} + + +/** + * Get a polymorphic attribute using attribute factory + */ +public operator fun Attributes.get(attributeKeyBuilder: () -> PolymorphicAttribute): T? = get(attributeKeyBuilder()) + +/** + * Set a polymorphic attribute using its factory + */ +public operator fun AttributesBuilder.set(attributeKeyBuilder: () -> PolymorphicAttribute, value: T) { + set(attributeKeyBuilder(), value) +} diff --git a/buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt b/buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt index b8c895196..bb46a085b 100644 --- a/buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt +++ b/buildSrc/src/main/kotlin/space/kscience/kmath/ejml/codegen/ejmlCodegen.kt @@ -19,6 +19,8 @@ public class Ejml${type}Vector(override val origin: M) require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } } + override val type: SafeType<${type}> get() = safeTypeOf() + override operator fun get(index: Int): $type = origin[0, index] }""" appendLine(text) @@ -30,6 +32,8 @@ private fun Appendable.appendEjmlMatrix(type: String, ejmlMatrixType: String) { * [EjmlMatrix] specialization for [$type]. */ public class Ejml${type}Matrix(override val origin: M) : EjmlMatrix<$type, M>(origin) { + override val type: SafeType<${type}> get() = safeTypeOf() + override operator fun get(i: Int, j: Int): $type = origin[i, j] }""" appendLine(text) @@ -46,7 +50,9 @@ private fun Appendable.appendEjmlLinearSpace( denseOps: String, isDense: Boolean, ) { - @Language("kotlin") val text = """/** + @Language("kotlin") val text = """ + +/** * [EjmlLinearSpace] implementation based on [CommonOps_$ops], [DecompositionFactory_${ops}] operations and * [${ejmlMatrixType}] matrices. */ @@ -56,7 +62,7 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra}, */ override val elementAlgebra: $kmathAlgebra get() = $kmathAlgebra - override val elementType: KType get() = typeOf<$type>() + override val type: SafeType<${type}> get() = safeTypeOf() @Suppress("UNCHECKED_CAST") override fun Matrix<${type}>.toEjml(): Ejml${type}Matrix<${ejmlMatrixType}> = when { @@ -385,6 +391,8 @@ import org.ejml.sparse.csc.factory.DecompositionFactory_DSCC import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import space.kscience.kmath.linear.* import space.kscience.kmath.linear.Matrix import space.kscience.kmath.UnstableKMathAPI diff --git a/examples/src/main/kotlin/space/kscience/kmath/fit/chiSquared.kt b/examples/src/main/kotlin/space/kscience/kmath/fit/chiSquared.kt index 258ed0c84..f2c0c7cf4 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/fit/chiSquared.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/fit/chiSquared.kt @@ -15,7 +15,7 @@ import space.kscience.kmath.operations.asIterable import space.kscience.kmath.operations.toList import space.kscience.kmath.optimization.FunctionOptimizationTarget import space.kscience.kmath.optimization.optimizeWith -import space.kscience.kmath.optimization.resultPoint +import space.kscience.kmath.optimization.result import space.kscience.kmath.optimization.resultValue import space.kscience.kmath.random.RandomGenerator import space.kscience.kmath.real.DoubleVector @@ -98,7 +98,7 @@ suspend fun main() { scatter { mode = ScatterMode.lines x(x) - y(x.map { result.resultPoint[a]!! * it.pow(2) + result.resultPoint[b]!! * it + 1 }) + y(x.map { result.result[a]!! * it.pow(2) + result.result[b]!! * it + 1 }) name = "fit" } } diff --git a/examples/src/main/kotlin/space/kscience/kmath/fit/qowFit.kt b/examples/src/main/kotlin/space/kscience/kmath/fit/qowFit.kt index fe7f48b72..a092b8870 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/fit/qowFit.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/fit/qowFit.kt @@ -94,13 +94,13 @@ suspend fun main() { scatter { mode = ScatterMode.lines x(x) - y(x.map { result.model(result.startPoint + result.resultPoint + (Symbol.x to it)) }) + y(x.map { result.model(result.startPoint + result.result + (Symbol.x to it)) }) name = "fit" } } br() h3 { - +"Fit result: ${result.resultPoint}" + +"Fit result: ${result.result}" } h3 { +"Chi2/dof = ${result.chiSquaredOrNull!! / result.dof}" diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt index e82f7a3ab..d824a652e 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/TypedMst.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.ast +import space.kscience.attributes.SafeType +import space.kscience.attributes.WithType import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.operations.Algebra @@ -15,7 +17,7 @@ import space.kscience.kmath.operations.NumericAlgebra * * @param T the type. */ -public sealed interface TypedMst { +public sealed interface TypedMst : WithType { /** * A node containing a unary operation. * @@ -24,8 +26,13 @@ public sealed interface TypedMst { * @property function The function implementing this operation. * @property value The argument of this operation. */ - public class Unary(public val operation: String, public val function: (T) -> T, public val value: TypedMst) : - TypedMst { + public class Unary( + public val operation: String, + public val function: (T) -> T, + public val value: TypedMst, + ) : TypedMst { + override val type: SafeType get() = value.type + override fun equals(other: Any?): Boolean { if (this === other) return true if (other == null || this::class != other::class) return false @@ -59,6 +66,13 @@ public sealed interface TypedMst { public val left: TypedMst, public val right: TypedMst, ) : TypedMst { + + init { + require(left.type==right.type){"Left and right expressions must be of the same type"} + } + + override val type: SafeType get() = left.type + override fun equals(other: Any?): Boolean { if (this === other) return true if (other == null || this::class != other::class) return false @@ -89,7 +103,12 @@ public sealed interface TypedMst { * @property value The held value. * @property number The number this value corresponds. */ - public class Constant(public val value: T, public val number: Number?) : TypedMst { + public class Constant( + override val type: SafeType, + public val value: T, + public val number: Number?, + ) : TypedMst { + override fun equals(other: Any?): Boolean { if (this === other) return true if (other == null || this::class != other::class) return false @@ -114,7 +133,7 @@ public sealed interface TypedMst { * @param T the type. * @property symbol The symbol of the variable. */ - public class Variable(public val symbol: Symbol) : TypedMst { + public class Variable(override val type: SafeType, public val symbol: Symbol) : TypedMst { override fun equals(other: Any?): Boolean { if (this === other) return true if (other == null || this::class != other::class) return false @@ -167,6 +186,7 @@ public fun TypedMst.interpret(algebra: Algebra, vararg arguments: Pair /** * Interpret this [TypedMst] node as expression. */ -public fun TypedMst.toExpression(algebra: Algebra): Expression = Expression { arguments -> - interpret(algebra, arguments) -} +public fun TypedMst.toExpression(algebra: Algebra): Expression = + Expression(algebra.type) { arguments -> + interpret(algebra, arguments) + } diff --git a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt index 8fc5a6aaf..4298a9788 100644 --- a/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt +++ b/kmath-ast/src/commonMain/kotlin/space/kscience/kmath/ast/evaluateConstants.kt @@ -16,6 +16,7 @@ import space.kscience.kmath.operations.bindSymbolOrNull */ public fun MST.evaluateConstants(algebra: Algebra): TypedMst = when (this) { is MST.Numeric -> TypedMst.Constant( + algebra.type, (algebra as? NumericAlgebra)?.number(value) ?: error("Numeric nodes are not supported by $algebra"), value, ) @@ -27,7 +28,7 @@ public fun MST.evaluateConstants(algebra: Algebra): TypedMst = when (t arg.value, ) - TypedMst.Constant(value, if (value is Number) value else null) + TypedMst.Constant(algebra.type, value, if (value is Number) value else null) } else -> TypedMst.Unary(operation, algebra.unaryOperationFunction(operation), arg) @@ -59,7 +60,7 @@ public fun MST.evaluateConstants(algebra: Algebra): TypedMst = when (t ) } - TypedMst.Constant(value, if (value is Number) value else null) + TypedMst.Constant(algebra.type, value, if (value is Number) value else null) } algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> TypedMst.Binary( @@ -84,8 +85,8 @@ public fun MST.evaluateConstants(algebra: Algebra): TypedMst = when (t val boundSymbol = algebra.bindSymbolOrNull(this) if (boundSymbol != null) - TypedMst.Constant(boundSymbol, if (boundSymbol is Number) boundSymbol else null) + TypedMst.Constant(algebra.type, boundSymbol, if (boundSymbol is Number) boundSymbol else null) else - TypedMst.Variable(this) + TypedMst.Variable(algebra.type, this) } } diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt index 87c2df2d2..33626eaa1 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/estree.kt @@ -22,7 +22,7 @@ import space.kscience.kmath.operations.Algebra @OptIn(UnstableKMathAPI::class) public fun MST.compileToExpression(algebra: Algebra): Expression { val typed = evaluateConstants(algebra) - if (typed is TypedMst.Constant) return Expression { typed.value } + if (typed is TypedMst.Constant) return Expression(algebra.type) { typed.value } fun ESTreeBuilder.visit(node: TypedMst): BaseExpression = when (node) { is TypedMst.Constant -> constant(node.value) @@ -36,7 +36,7 @@ public fun MST.compileToExpression(algebra: Algebra): Expression ) } - return ESTreeBuilder { visit(typed) }.instance + return ESTreeBuilder(algebra.type) { visit(typed) }.instance } /** diff --git a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt index 1517cdef2..ed2b62336 100644 --- a/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt +++ b/kmath-ast/src/jsMain/kotlin/space/kscience/kmath/estree/internal/ESTreeBuilder.kt @@ -5,13 +5,22 @@ package space.kscience.kmath.estree.internal +import space.kscience.attributes.SafeType +import space.kscience.attributes.WithType import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.internal.astring.generate import space.kscience.kmath.internal.estree.* -internal class ESTreeBuilder(val bodyCallback: ESTreeBuilder.() -> BaseExpression) { - private class GeneratedExpression(val executable: dynamic, val constants: Array) : Expression { +internal class ESTreeBuilder( + override val type: SafeType, + val bodyCallback: ESTreeBuilder.() -> BaseExpression, +) : WithType { + private class GeneratedExpression( + override val type: SafeType, + val executable: dynamic, + val constants: Array, + ) : Expression { @Suppress("UNUSED_VARIABLE") override fun invoke(arguments: Map): T { val e = executable @@ -30,7 +39,7 @@ internal class ESTreeBuilder(val bodyCallback: ESTreeBuilder.() -> BaseExp ) val code = generate(node) - GeneratedExpression(js("new Function('constants', 'arguments_0', code)"), constants.toTypedArray()) + GeneratedExpression(type, js("new Function('constants', 'arguments_0', code)"), constants.toTypedArray()) } private val constants = mutableListOf() diff --git a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt index 97fe91ee4..50e6a2001 100644 --- a/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt +++ b/kmath-ast/src/jvmMain/kotlin/space/kscience/kmath/asm/asm.kt @@ -29,7 +29,7 @@ import space.kscience.kmath.operations.Int64Ring @PublishedApi internal fun MST.compileWith(type: Class, algebra: Algebra): Expression { val typed = evaluateConstants(algebra) - if (typed is TypedMst.Constant) return Expression { typed.value } + if (typed is TypedMst.Constant) return Expression(algebra.type) { typed.value } fun GenericAsmBuilder.variablesVisitor(node: TypedMst): Unit = when (node) { is TypedMst.Unary -> variablesVisitor(node.value) diff --git a/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/expressions/CmDsExpression.kt b/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/expressions/CmDsExpression.kt index 38eaf8868..65ae8dd2d 100644 --- a/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/expressions/CmDsExpression.kt +++ b/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/expressions/CmDsExpression.kt @@ -8,10 +8,13 @@ package space.kscience.kmath.commons.expressions import org.apache.commons.math3.analysis.differentiation.DerivativeStructure +import space.kscience.attributes.SafeType import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.expressions.* +import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.ExtendedField import space.kscience.kmath.operations.NumbersAddOps +import space.kscience.kmath.structures.MutableBufferFactory /** * A field over commons-math [DerivativeStructure]. @@ -26,6 +29,9 @@ public class CmDsField( bindings: Map, ) : ExtendedField, ExpressionAlgebra, NumbersAddOps { + + override val bufferFactory: MutableBufferFactory = MutableBufferFactory() + public val numberOfVariables: Int = bindings.size override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } @@ -77,7 +83,9 @@ public class CmDsField( override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value) - override fun multiply(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.multiply(right) + override fun multiply(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = + left.multiply(right) + override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.divide(right) override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() @@ -113,8 +121,8 @@ public class CmDsField( */ @Deprecated("Use generic DSAlgebra from the core") public object CmDsProcessor : AutoDiffProcessor { - override fun differentiate( - function: CmDsField.() -> DerivativeStructure, + override fun differentiate( + function: CmDsField.() -> DerivativeStructure, ): CmDsExpression = CmDsExpression(function) } @@ -125,13 +133,16 @@ public object CmDsProcessor : AutoDiffProcessor DerivativeStructure, ) : DifferentiableExpression { + + override val type: SafeType get() = DoubleField.type + override operator fun invoke(arguments: Map): Double = CmDsField(0, arguments).function().value /** * Get the derivative expression with given orders */ - override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> + override fun derivativeOrNull(symbols: List): Expression = Expression(type) { arguments -> with(CmDsField(symbols.size, arguments)) { function().derivative(symbols) } } } diff --git a/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/integration/CMGaussRuleIntegrator.kt b/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/integration/CMGaussRuleIntegrator.kt index a3fc49d32..f087987bf 100644 --- a/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/integration/CMGaussRuleIntegrator.kt +++ b/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/integration/CMGaussRuleIntegrator.kt @@ -16,7 +16,7 @@ public class CMGaussRuleIntegrator( private var type: GaussRule = GaussRule.LEGENDRE, ) : UnivariateIntegrator { - override fun process(integrand: UnivariateIntegrand): UnivariateIntegrand { + override fun integrate(integrand: UnivariateIntegrand): UnivariateIntegrand { val range = integrand[IntegrationRange] ?: error("Integration range is not provided") val integrator: GaussIntegrator = getIntegrator(range) @@ -79,7 +79,7 @@ public class CMGaussRuleIntegrator( numPoints: Int = 100, type: GaussRule = GaussRule.LEGENDRE, function: (Double) -> Double, - ): Double = CMGaussRuleIntegrator(numPoints, type).process( + ): Double = CMGaussRuleIntegrator(numPoints, type).integrate( UnivariateIntegrand({IntegrationRange(range)},function) ).value } diff --git a/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/integration/CMIntegrator.kt b/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/integration/CMIntegrator.kt index 2cc60fb77..ce3dabd08 100644 --- a/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/integration/CMIntegrator.kt +++ b/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/integration/CMIntegrator.kt @@ -7,6 +7,7 @@ package space.kscience.kmath.commons.integration import org.apache.commons.math3.analysis.integration.IterativeLegendreGaussIntegrator import org.apache.commons.math3.analysis.integration.SimpsonIntegrator +import space.kscience.attributes.Attributes import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.integration.* import org.apache.commons.math3.analysis.integration.UnivariateIntegrator as CMUnivariateIntegrator @@ -19,7 +20,7 @@ public class CMIntegrator( public val integratorBuilder: (Integrand) -> CMUnivariateIntegrator, ) : UnivariateIntegrator { - override fun process(integrand: UnivariateIntegrand): UnivariateIntegrand { + override fun integrate(integrand: UnivariateIntegrand): UnivariateIntegrand { val integrator = integratorBuilder(integrand) val maxCalls = integrand[IntegrandMaxCalls] ?: defaultMaxCalls val remainingCalls = maxCalls - integrand.calls @@ -73,15 +74,9 @@ public class CMIntegrator( } @UnstableKMathAPI -public var MutableList.targetAbsoluteAccuracy: Double? - get() = filterIsInstance().lastOrNull()?.accuracy - set(value) { - value?.let { add(IntegrandAbsoluteAccuracy(value)) } - } +public val Attributes.targetAbsoluteAccuracy: Double? + get() = get(IntegrandAbsoluteAccuracy) @UnstableKMathAPI -public var MutableList.targetRelativeAccuracy: Double? - get() = filterIsInstance().lastOrNull()?.accuracy - set(value) { - value?.let { add(IntegrandRelativeAccuracy(value)) } - } \ No newline at end of file +public val Attributes.targetRelativeAccuracy: Double? + get() = get(IntegrandRelativeAccuracy) diff --git a/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/linear/CMMatrix.kt b/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/linear/CMMatrix.kt index d29650e3f..625e7292a 100644 --- a/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/linear/CMMatrix.kt +++ b/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/linear/CMMatrix.kt @@ -6,18 +6,20 @@ package space.kscience.kmath.commons.linear import org.apache.commons.math3.linear.* +import org.apache.commons.math3.linear.LUDecomposition +import space.kscience.attributes.SafeType import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.linear.* +import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.nd.StructureAttribute +import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.Float64Field import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Float64Buffer -import kotlin.reflect.KClass -import kotlin.reflect.KType import kotlin.reflect.cast -import kotlin.reflect.typeOf public class CMMatrix(public val origin: RealMatrix) : Matrix { + override val type: SafeType get() = DoubleField.type override val rowNum: Int get() = origin.rowDimension override val colNum: Int get() = origin.columnDimension @@ -26,6 +28,7 @@ public class CMMatrix(public val origin: RealMatrix) : Matrix { @JvmInline public value class CMVector(public val origin: RealVector) : Point { + override val type: SafeType get() = DoubleField.type override val size: Int get() = origin.dimension override operator fun get(index: Int): Double = origin.getEntry(index) @@ -40,7 +43,7 @@ public fun RealVector.toPoint(): CMVector = CMVector(this) public object CMLinearSpace : LinearSpace { override val elementAlgebra: Float64Field get() = Float64Field - override val elementType: KType = typeOf() + override val type: SafeType get() = DoubleField.type override fun buildMatrix( rows: Int, @@ -102,19 +105,14 @@ public object CMLinearSpace : LinearSpace { override fun Double.times(v: Point): CMVector = v * this - @UnstableKMathAPI - override fun computeFeature(structure: Matrix, type: KClass): F? { - //Return the feature if it is intrinsic to the structure - structure.getFeature(type)?.let { return it } + override fun > computeAttribute(structure: Structure2D, attribute: A): V? { val origin = structure.toCM().origin - return when (type) { - IsDiagonal::class -> if (origin is DiagonalMatrix) IsDiagonal else null - - Determinant::class, LupDecompositionAttribute::class -> object : - Determinant, - LupDecompositionAttribute { + return when (attribute) { + IsDiagonal -> if (origin is DiagonalMatrix) IsDiagonal else null + Determinant -> LUDecomposition(origin).determinant + LUP -> GenericLupDecomposition { private val lup by lazy { LUDecomposition(origin) } override val determinant: Double by lazy { lup.determinant } override val l: Matrix by lazy> { CMMatrix(lup.l).withAttribute(LowerTriangular) } @@ -122,20 +120,24 @@ public object CMLinearSpace : LinearSpace { override val p: Matrix by lazy { CMMatrix(lup.p) } } - CholeskyDecompositionAttribute::class -> object : CholeskyDecompositionAttribute { + CholeskyDecompositionAttribute -> object : CholeskyDecompositionAttribute { override val l: Matrix by lazy> { val cholesky = CholeskyDecomposition(origin) CMMatrix(cholesky.l).withAttribute(LowerTriangular) } } - QRDecompositionAttribute::class -> object : QRDecompositionAttribute { + QRDecompositionAttribute -> object : QRDecompositionAttribute { private val qr by lazy { QRDecomposition(origin) } - override val q: Matrix by lazy> { CMMatrix(qr.q).withAttribute(OrthogonalAttribute) } + override val q: Matrix by lazy> { + CMMatrix(qr.q).withAttribute( + OrthogonalAttribute + ) + } override val r: Matrix by lazy> { CMMatrix(qr.r).withAttribute(UpperTriangular) } } - SVDAttribute::class -> object : SVDAttribute { + SVDAttribute -> object : SVDAttribute { private val sv by lazy { SingularValueDecomposition(origin) } override val u: Matrix by lazy { CMMatrix(sv.u) } override val s: Matrix by lazy { CMMatrix(sv.s) } diff --git a/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/optimization/CMOptimizer.kt b/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/optimization/CMOptimizer.kt index 6e0000721..e834d404b 100644 --- a/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/optimization/CMOptimizer.kt +++ b/kmath-commons/src/jvmMain/kotlin/space/kscience/kmath/commons/optimization/CMOptimizer.kt @@ -13,6 +13,8 @@ import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer +import space.kscience.attributes.AttributesBuilder +import space.kscience.attributes.SetAttribute import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.SymbolIndexer @@ -26,34 +28,25 @@ import kotlin.reflect.KClass public operator fun PointValuePair.component1(): DoubleArray = point public operator fun PointValuePair.component2(): Double = value -public class CMOptimizerEngine(public val optimizerBuilder: () -> MultivariateOptimizer) : OptimizationFeature { - override fun toString(): String = "CMOptimizer($optimizerBuilder)" -} +public object CMOptimizerEngine: OptimizationAttribute<() -> MultivariateOptimizer> /** * Specify a Commons-maths optimization engine */ -public fun FunctionOptimizationBuilder.cmEngine(optimizerBuilder: () -> MultivariateOptimizer) { - addFeature(CMOptimizerEngine(optimizerBuilder)) +public fun AttributesBuilder>.cmEngine(optimizerBuilder: () -> MultivariateOptimizer) { + set(CMOptimizerEngine, optimizerBuilder) } -public class CMOptimizerData(public val data: List OptimizationData>) : OptimizationFeature { - public constructor(vararg data: (SymbolIndexer.() -> OptimizationData)) : this(data.toList()) - - override fun toString(): String = "CMOptimizerData($data)" -} +public object CMOptimizerData: SetAttribute OptimizationData> /** * Specify Commons-maths optimization data. */ -public fun FunctionOptimizationBuilder.cmOptimizationData(data: SymbolIndexer.() -> OptimizationData) { - updateFeature { - val newData = (it?.data ?: emptyList()) + data - CMOptimizerData(newData) - } +public fun AttributesBuilder>.cmOptimizationData(data: SymbolIndexer.() -> OptimizationData) { + CMOptimizerData.add(data) } -public fun FunctionOptimizationBuilder.simplexSteps(vararg steps: Pair) { +public fun AttributesBuilder>.simplexSteps(vararg steps: Pair) { //TODO use convergence checker from features cmEngine { SimplexOptimizer(CMOptimizer.defaultConvergenceChecker) } cmOptimizationData { NelderMeadSimplex(mapOf(*steps).toDoubleArray()) } @@ -78,8 +71,8 @@ public object CMOptimizer : Optimizer> { ): FunctionOptimization { val startPoint = problem.startPoint - val parameters = problem.getFeature()?.symbols - ?: problem.getFeature>()?.point?.keys + val parameters = problem.attributes[OptimizationParameters] + ?: problem.attributes[OptimizationStartPoint()]?.keys ?: startPoint.keys @@ -90,7 +83,7 @@ public object CMOptimizer : Optimizer> { DEFAULT_MAX_ITER ) - val cmOptimizer: MultivariateOptimizer = problem.getFeature()?.optimizerBuilder?.invoke() + val cmOptimizer: MultivariateOptimizer = problem.attributes[CMOptimizerEngine]?.invoke() ?: NonLinearConjugateGradientOptimizer( NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES, convergenceChecker @@ -123,7 +116,7 @@ public object CMOptimizer : Optimizer> { } addOptimizationData(gradientFunction) - val logger = problem.getFeature() + val logger = problem.attributes[OptimizationLog] for (feature in problem.attributes) { when (feature) { @@ -139,7 +132,7 @@ public object CMOptimizer : Optimizer> { } val (point, value) = cmOptimizer.optimize(*optimizationData.values.toTypedArray()) - return problem.withFeatures(OptimizationResult(point.toMap()), OptimizationValue(value)) + return problem.withAttributes(OptimizationResult(point.toMap()), OptimizationValue(value)) } } } diff --git a/kmath-commons/src/jvmTest/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt b/kmath-commons/src/jvmTest/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt index 5933d0d36..5b8c7868a 100644 --- a/kmath-commons/src/jvmTest/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt +++ b/kmath-commons/src/jvmTest/kotlin/space/kscience/kmath/commons/optimization/OptimizeTest.kt @@ -31,7 +31,7 @@ internal class OptimizeTest { @Test fun testGradientOptimization() = runBlocking { val result = normal.optimizeWith(CMOptimizer, x to 1.0, y to 1.0) - println(result.resultPoint) + println(result.result) println(result.resultValue) } @@ -42,7 +42,7 @@ internal class OptimizeTest { //this sets simplex optimizer } - println(result.resultPoint) + println(result.result) println(result.resultValue) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt index 8ef751859..3e9f8707e 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/DSAlgebra.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.expressions +import space.kscience.attributes.SafeType import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.operations.* import space.kscience.kmath.structures.Buffer @@ -331,10 +332,13 @@ public class DerivativeStructureRingExpression( public val elementBufferFactory: MutableBufferFactory = algebra.bufferFactory, public val function: DSRing.() -> DS, ) : DifferentiableExpression where A : Ring, A : ScaleOperations, A : NumericAlgebra { + + override val type: SafeType get() = elementBufferFactory.type + override operator fun invoke(arguments: Map): T = DSRing(algebra, 0, arguments).function().value - override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> + override fun derivativeOrNull(symbols: List): Expression = Expression(type) { arguments -> with( DSRing( algebra, @@ -443,10 +447,13 @@ public class DSFieldExpression>( public val algebra: A, public val function: DSField.() -> DS, ) : DifferentiableExpression { + + override val type: SafeType get() = algebra.type + override operator fun invoke(arguments: Map): T = DSField(algebra, 0, arguments).function().value - override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> + override fun derivativeOrNull(symbols: List): Expression = Expression(type) { arguments -> DSField( algebra, symbols.size, diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt index f350303bc..81ceaae8a 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/Expression.kt @@ -5,8 +5,13 @@ package space.kscience.kmath.expressions +import space.kscience.attributes.SafeType +import space.kscience.attributes.WithType import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.operations.Algebra +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.IntRing +import space.kscience.kmath.operations.LongRing import kotlin.jvm.JvmName import kotlin.properties.ReadOnlyProperty @@ -15,7 +20,7 @@ import kotlin.properties.ReadOnlyProperty * * @param T the type this expression takes as argument and returns. */ -public fun interface Expression { +public interface Expression : WithType { /** * Calls this expression from arguments. * @@ -25,11 +30,20 @@ public fun interface Expression { public operator fun invoke(arguments: Map): T } +public fun Expression(type: SafeType, block: (Map) -> T): Expression = object : Expression { + override fun invoke(arguments: Map): T = block(arguments) + + override val type: SafeType = type +} + /** * Specialization of [Expression] for [Double] allowing better performance because of using array. */ @UnstableKMathAPI public interface DoubleExpression : Expression { + + override val type: SafeType get() = DoubleField.type + /** * The indexer of this expression's arguments that should be used to build array for [invoke]. * @@ -49,7 +63,7 @@ public interface DoubleExpression : Expression { */ public operator fun invoke(arguments: DoubleArray): Double - public companion object{ + public companion object { internal val EMPTY_DOUBLE_ARRAY = DoubleArray(0) } } @@ -59,6 +73,9 @@ public interface DoubleExpression : Expression { */ @UnstableKMathAPI public interface IntExpression : Expression { + + override val type: SafeType get() = IntRing.type + /** * The indexer of this expression's arguments that should be used to build array for [invoke]. * @@ -78,7 +95,7 @@ public interface IntExpression : Expression { */ public operator fun invoke(arguments: IntArray): Int - public companion object{ + public companion object { internal val EMPTY_INT_ARRAY = IntArray(0) } } @@ -88,6 +105,9 @@ public interface IntExpression : Expression { */ @UnstableKMathAPI public interface LongExpression : Expression { + + override val type: SafeType get() = LongRing.type + /** * The indexer of this expression's arguments that should be used to build array for [invoke]. * @@ -107,7 +127,7 @@ public interface LongExpression : Expression { */ public operator fun invoke(arguments: LongArray): Long - public companion object{ + public companion object { internal val EMPTY_LONG_ARRAY = LongArray(0) } } @@ -158,7 +178,6 @@ public operator fun Expression.invoke(vararg pairs: Pair): T = ) - /** * Calls this expression without providing any arguments. * diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/ExpressionWithDefault.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/ExpressionWithDefault.kt index c802fe04c..8045817e6 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/ExpressionWithDefault.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/ExpressionWithDefault.kt @@ -5,10 +5,15 @@ package space.kscience.kmath.expressions +import space.kscience.attributes.SafeType + public class ExpressionWithDefault( private val origin: Expression, private val defaultArgs: Map, ) : Expression { + override val type: SafeType + get() = origin.type + override fun invoke(arguments: Map): T = origin.invoke(defaultArgs + arguments) } @@ -21,6 +26,9 @@ public class DiffExpressionWithDefault( private val defaultArgs: Map, ) : DifferentiableExpression { + override val type: SafeType + get() = origin.type + override fun invoke(arguments: Map): T = origin.invoke(defaultArgs + arguments) override fun derivativeOrNull(symbols: List): Expression? = diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt index 5b4dcd638..5ff408732 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/FunctionalExpressionAlgebra.kt @@ -23,26 +23,27 @@ public abstract class FunctionalExpressionAlgebra>( /** * Builds an Expression of constant expression that does not depend on arguments. */ - override fun const(value: T): Expression = Expression { value } + override fun const(value: T): Expression = Expression(algebra.type) { value } /** * Builds an Expression to access a variable. */ - override fun bindSymbolOrNull(value: String): Expression? = Expression { arguments -> + override fun bindSymbolOrNull(value: String): Expression? = Expression(algebra.type) { arguments -> algebra.bindSymbolOrNull(value) ?: arguments[StringSymbol(value)] ?: error("Symbol '$value' is not supported in $this") } - override fun binaryOperationFunction(operation: String): (left: Expression, right: Expression) -> Expression = - { left, right -> - Expression { arguments -> - algebra.binaryOperationFunction(operation)(left(arguments), right(arguments)) - } + override fun binaryOperationFunction( + operation: String, + ): (left: Expression, right: Expression) -> Expression = { left, right -> + Expression(algebra.type) { arguments -> + algebra.binaryOperationFunction(operation)(left(arguments), right(arguments)) } + } override fun unaryOperationFunction(operation: String): (arg: Expression) -> Expression = { arg -> - Expression { arguments -> algebra.unaryOperation(operation, arg(arguments)) } + Expression(algebra.type) { arguments -> algebra.unaryOperation(operation, arg(arguments)) } } } @@ -124,7 +125,7 @@ public open class FunctionalExpressionField>( super.binaryOperationFunction(operation) override fun scale(a: Expression, value: Double): Expression = algebra { - Expression { args -> a(args) * value } + Expression(algebra.type) { args -> a(args) * value } } override fun bindSymbolOrNull(value: String): Expression? = diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt index 9705a3f03..7ceb23b37 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/MST.kt @@ -108,4 +108,4 @@ public fun MST.interpret(algebra: Algebra, vararg arguments: Pair MST.toExpression(algebra: Algebra): Expression = - Expression { arguments -> interpret(algebra, arguments) } + Expression(algebra.type) { arguments -> interpret(algebra, arguments) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt index fd7bf9fdc..7a67e12cc 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/expressions/SimpleAutoDiff.kt @@ -243,12 +243,15 @@ public class SimpleAutoDiffExpression>( public val field: F, public val function: SimpleAutoDiffField.() -> AutoDiffValue, ) : FirstDerivativeExpression() { + + override val type: SafeType get() = this.field.type + override operator fun invoke(arguments: Map): T { //val bindings = arguments.entries.map { it.key.bind(it.value) } return SimpleAutoDiffField(field, arguments).function().value } - override fun derivativeOrNull(symbol: Symbol): Expression = Expression { arguments -> + override fun derivativeOrNull(symbol: Symbol): Expression = Expression(type) { arguments -> //val bindings = arguments.entries.map { it.key.bind(it.value) } val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function) derivationResult.derivative(symbol) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSolver.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSolver.kt index af9ebb463..93d84025f 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSolver.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSolver.kt @@ -11,7 +11,7 @@ package space.kscience.kmath.linear * * @param T the type of items. */ -public interface LinearSolver { +public interface LinearSolver { /** * Solve a dot x = b matrix equation and return x */ diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt index 6b547e2c5..789164e02 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LinearSpace.kt @@ -5,10 +5,7 @@ package space.kscience.kmath.linear -import space.kscience.attributes.Attributes -import space.kscience.attributes.SafeType -import space.kscience.attributes.WithType -import space.kscience.attributes.withAttribute +import space.kscience.attributes.* import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.nd.* import space.kscience.kmath.operations.BufferRingOps @@ -31,19 +28,13 @@ public typealias MutableMatrix = MutableStructure2D */ public typealias Point = Buffer -/** - * A marker interface for algebras that operate on matrices - * @param T type of matrix element - */ -public interface MatrixOperations : WithType - /** * Basic operations on matrices and vectors. * * @param T the type of items in the matrices. * @param A the type of ring over [T]. */ -public interface LinearSpace> : MatrixOperations { +public interface LinearSpace> : MatrixScope { public val elementAlgebra: A override val type: SafeType get() = elementAlgebra.type @@ -177,10 +168,10 @@ public interface LinearSpace> : MatrixOperations { /** * Compute an [attribute] value for given [structure]. Return null if the attribute could not be computed. */ - public fun > computeAttribute(structure: StructureND<*>, attribute: A): V? = null + public fun > computeAttribute(structure: Structure2D, attribute: A): V? = null @UnstableKMathAPI - public fun > StructureND<*>.getOrComputeAttribute(attribute: A): V? { + public fun > Structure2D.getOrComputeAttribute(attribute: A): V? { return attributes[attribute] ?: computeAttribute(this, attribute) } @@ -225,7 +216,7 @@ public inline operator fun , R> LS.invoke(block: LS.() -> /** * Convert matrix to vector if it is possible. */ -public fun Matrix.asVector(): Point = +public fun Matrix.asVector(): Point = if (this.colNum == 1) as1D() else error("Can't convert matrix with more than one column to vector") @@ -236,4 +227,4 @@ public fun Matrix.asVector(): Point = * @receiver a buffer. * @return the new matrix. */ -public fun Point.asMatrix(): VirtualMatrix = VirtualMatrix(type, size, 1) { i, _ -> get(i) } \ No newline at end of file +public fun Point.asMatrix(): VirtualMatrix = VirtualMatrix(type, size, 1) { i, _ -> get(i) } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt index fab4ef3db..bf595d332 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/LupDecomposition.kt @@ -9,12 +9,20 @@ package space.kscience.kmath.linear import space.kscience.attributes.Attributes import space.kscience.attributes.PolymorphicAttribute -import space.kscience.attributes.SafeType import space.kscience.attributes.safeTypeOf import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.operations.* import space.kscience.kmath.structures.* +public interface LupDecomposition { + public val linearSpace: LinearSpace> + public val elementAlgebra: Field get() = linearSpace.elementAlgebra + + public val pivot: IntBuffer + public val l: Matrix + public val u: Matrix +} + /** * Matrices with this feature support LU factorization with partial pivoting: *[p] · a = [l] · [u]* where * *a* is the owning matrix. @@ -22,15 +30,14 @@ import space.kscience.kmath.structures.* * @param T the type of matrices' items. * @param lu combined L and U matrix */ -public class LupDecomposition( - public val linearSpace: LinearSpace>, +public class GenericLupDecomposition( + override val linearSpace: LinearSpace>, private val lu: Matrix, - public val pivot: IntBuffer, + override val pivot: IntBuffer, private val even: Boolean, -) { - public val elementAlgebra: Ring get() = linearSpace.elementAlgebra +) : LupDecomposition { - public val l: Matrix + override val l: Matrix get() = VirtualMatrix(lu.type, lu.rowNum, lu.colNum, attributes = Attributes(LowerTriangular)) { i, j -> when { j < i -> lu[i, j] @@ -39,7 +46,7 @@ public class LupDecomposition( } } - public val u: Matrix + override val u: Matrix get() = VirtualMatrix(lu.type, lu.rowNum, lu.colNum, attributes = Attributes(UpperTriangular)) { i, j -> if (j >= i) lu[i, j] else elementAlgebra.zero } @@ -55,13 +62,12 @@ public class LupDecomposition( } - -public class LupDecompositionAttribute(type: SafeType>) : - PolymorphicAttribute>(type), +public class LupDecompositionAttribute : + PolymorphicAttribute>(safeTypeOf()), MatrixAttribute> -public val MatrixOperations.LUP: LupDecompositionAttribute - get() = LupDecompositionAttribute(safeTypeOf()) +public val MatrixScope.LUP: LupDecompositionAttribute + get() = LupDecompositionAttribute() @PublishedApi internal fun > LinearSpace>.abs(value: T): T = @@ -79,7 +85,7 @@ public fun > LinearSpace>.lup( val pivot = IntArray(matrix.rowNum) //TODO just waits for multi-receivers - with(BufferAccessor2D(matrix.rowNum, matrix.colNum, elementAlgebra.bufferFactory)){ + with(BufferAccessor2D(matrix.rowNum, matrix.colNum, elementAlgebra.bufferFactory)) { val lu = create(matrix) @@ -142,18 +148,17 @@ public fun > LinearSpace>.lup( for (row in col + 1 until m) lu[row, col] /= luDiag } - return LupDecomposition(this@lup, lu.toStructure2D(), pivot.asBuffer(), even) + return GenericLupDecomposition(this@lup, lu.toStructure2D(), pivot.asBuffer(), even) } } - public fun LinearSpace.lup( matrix: Matrix, singularityThreshold: Double = 1e-11, ): LupDecomposition = lup(matrix) { it < singularityThreshold } -internal fun > LinearSpace.solve( +internal fun LinearSpace>.solve( lup: LupDecomposition, matrix: Matrix, ): Matrix { @@ -205,7 +210,7 @@ internal fun > LinearSpace.solve( * Produce a generic solver based on LUP decomposition */ @OptIn(UnstableKMathAPI::class) -public fun , F : Field> LinearSpace.lupSolver( +public fun > LinearSpace>.lupSolver( singularityCheck: (T) -> Boolean, ): LinearSolver = object : LinearSolver { override fun solve(a: Matrix, b: Matrix): Matrix { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt index eb14de3c7..4707f6cfe 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/MatrixWrapper.kt @@ -35,7 +35,7 @@ public val Matrix.origin: Matrix /** * Add a single feature to a [Matrix] */ -public fun > Matrix.withAttribute( +public fun > Matrix.withAttribute( attribute: A, attrValue: T, ): MatrixWrapper = if (this is MatrixWrapper) { @@ -44,7 +44,7 @@ public fun > Matrix.withAttribute( MatrixWrapper(this, Attributes(attribute, attrValue)) } -public fun > Matrix.withAttribute( +public fun > Matrix.withAttribute( attribute: A, ): MatrixWrapper = if (this is MatrixWrapper) { MatrixWrapper(origin, attributes.withAttribute(attribute)) @@ -55,7 +55,7 @@ public fun > Matrix.withAttribute( /** * Modify matrix attributes */ -public fun Matrix.modifyAttributes(modifier: (Attributes) -> Attributes): MatrixWrapper = +public fun Matrix.modifyAttributes(modifier: (Attributes) -> Attributes): MatrixWrapper = if (this is MatrixWrapper) { MatrixWrapper(origin, modifier(attributes)) } else { @@ -65,7 +65,7 @@ public fun Matrix.modifyAttributes(modifier: (Attributes) -> Attrib /** * Diagonal matrix of ones. The matrix is virtual, no actual matrix is created. */ -public fun LinearSpace>.one( +public fun LinearSpace>.one( rows: Int, columns: Int, ): MatrixWrapper = VirtualMatrix(type, rows, columns) { i, j -> @@ -76,7 +76,7 @@ public fun LinearSpace>.one( /** * A virtual matrix of zeroes */ -public fun LinearSpace>.zero( +public fun LinearSpace>.zero( rows: Int, columns: Int, ): MatrixWrapper = VirtualMatrix(type, rows, columns) { _, _ -> diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/matrixAttributes.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/matrixAttributes.kt index 8787d0e09..dadb2f3d7 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/matrixAttributes.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/linear/matrixAttributes.kt @@ -10,6 +10,13 @@ package space.kscience.kmath.linear import space.kscience.attributes.* import space.kscience.kmath.nd.StructureAttribute + +/** + * A marker interface for algebras that operate on matrices + * @param T type of matrix element + */ +public interface MatrixScope : AttributeScope>, WithType + /** * A marker interface representing some properties of matrices or additional transformations of them. Features are used * to optimize matrix operations performance in some cases or retrieve the APIs. @@ -38,11 +45,11 @@ public object IsUnit : IsDiagonal * * @param T the type of matrices' items. */ -public class Inverted(type: SafeType>) : - PolymorphicAttribute>(type), +public class Inverted() : + PolymorphicAttribute>(safeTypeOf()), MatrixAttribute> -public val MatrixOperations.Inverted: Inverted get() = Inverted(safeTypeOf()) +public val MatrixScope.Inverted: Inverted get() = Inverted() /** * Matrices with this feature can compute their determinant. @@ -53,7 +60,7 @@ public class Determinant(type: SafeType) : PolymorphicAttribute(type), MatrixAttribute -public val MatrixOperations.Determinant: Determinant get() = Determinant(type) +public val MatrixScope.Determinant: Determinant get() = Determinant(type) /** * Matrices with this feature are lower triangular ones. @@ -77,11 +84,11 @@ public data class LUDecomposition(val l: Matrix, val u: Matrix) * * @param T the type of matrices' items. */ -public class LuDecompositionAttribute(type: SafeType>) : - PolymorphicAttribute>(type), +public class LuDecompositionAttribute : + PolymorphicAttribute>(safeTypeOf()), MatrixAttribute> -public val MatrixOperations.LU: LuDecompositionAttribute get() = LuDecompositionAttribute(safeTypeOf()) +public val MatrixScope.LU: LuDecompositionAttribute get() = LuDecompositionAttribute() /** @@ -108,12 +115,12 @@ public interface QRDecomposition { * * @param T the type of matrices' items. */ -public class QRDecompositionAttribute(type: SafeType>) : - PolymorphicAttribute>(type), +public class QRDecompositionAttribute() : + PolymorphicAttribute>(safeTypeOf()), MatrixAttribute> -public val MatrixOperations.QR: QRDecompositionAttribute - get() = QRDecompositionAttribute(safeTypeOf()) +public val MatrixScope.QR: QRDecompositionAttribute + get() = QRDecompositionAttribute() public interface CholeskyDecomposition { /** @@ -128,12 +135,12 @@ public interface CholeskyDecomposition { * * @param T the type of matrices' items. */ -public class CholeskyDecompositionAttribute(type: SafeType>) : - PolymorphicAttribute>(type), +public class CholeskyDecompositionAttribute : + PolymorphicAttribute>(safeTypeOf()), MatrixAttribute> -public val MatrixOperations.Cholesky: CholeskyDecompositionAttribute - get() = CholeskyDecompositionAttribute(safeTypeOf()) +public val MatrixScope.Cholesky: CholeskyDecompositionAttribute + get() = CholeskyDecompositionAttribute() public interface SingularValueDecomposition { /** @@ -163,12 +170,11 @@ public interface SingularValueDecomposition { * * @param T the type of matrices' items. */ -public class SVDAttribute(type: SafeType>) : - PolymorphicAttribute>(type), +public class SVDAttribute() : + PolymorphicAttribute>(safeTypeOf()), MatrixAttribute> -public val MatrixOperations.SVD: SVDAttribute - get() = SVDAttribute(safeTypeOf()) +public val MatrixScope.SVD: SVDAttribute get() = SVDAttribute() //TODO add sparse matrix feature diff --git a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt index 37b14f8d1..47a4b686f 100644 --- a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt +++ b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/EjmlLinearSpace.kt @@ -39,5 +39,5 @@ public abstract class EjmlLinearSpace, out M : org.ejml @UnstableKMathAPI public fun EjmlMatrix.inverted(): Matrix = - attributeForOrNull(this, Float64Field.linearSpace.Inverted) + computeAttribute(this, Float64Field.linearSpace.Inverted)!! } diff --git a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt index 881649d01..882df9536 100644 --- a/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt +++ b/kmath-ejml/src/main/kotlin/space/kscience/kmath/ejml/_generated.kt @@ -19,9 +19,13 @@ import org.ejml.sparse.csc.factory.DecompositionFactory_DSCC import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.linear.* import space.kscience.kmath.linear.Matrix +import space.kscience.kmath.nd.Structure2D +import space.kscience.kmath.nd.StructureAttribute import space.kscience.kmath.nd.StructureFeature import space.kscience.kmath.operations.Float32Field import space.kscience.kmath.operations.Float64Field @@ -39,6 +43,8 @@ public class EjmlDoubleVector(override val origin: M) : EjmlVec require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } } + override val type: SafeType get() = safeTypeOf() + override operator fun get(index: Int): Double = origin[0, index] } @@ -50,6 +56,8 @@ public class EjmlFloatVector(override val origin: M) : EjmlVect require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } } + override val type: SafeType get() = safeTypeOf() + override operator fun get(index: Int): Float = origin[0, index] } @@ -57,6 +65,8 @@ public class EjmlFloatVector(override val origin: M) : EjmlVect * [EjmlMatrix] specialization for [Double]. */ public class EjmlDoubleMatrix(override val origin: M) : EjmlMatrix(origin) { + override val type: SafeType get() = safeTypeOf() + override operator fun get(i: Int, j: Int): Double = origin[i, j] } @@ -64,9 +74,12 @@ public class EjmlDoubleMatrix(override val origin: M) : EjmlMat * [EjmlMatrix] specialization for [Float]. */ public class EjmlFloatMatrix(override val origin: M) : EjmlMatrix(origin) { + override val type: SafeType get() = safeTypeOf() + override operator fun get(i: Int, j: Int): Float = origin[i, j] } + /** * [EjmlLinearSpace] implementation based on [CommonOps_DDRM], [DecompositionFactory_DDRM] operations and * [DMatrixRMaj] matrices. @@ -77,7 +90,7 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace() + override val type: SafeType get() = safeTypeOf() @Suppress("UNCHECKED_CAST") override fun Matrix.toEjml(): EjmlDoubleMatrix = when { @@ -205,6 +218,18 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace): EjmlDoubleVector = v * this + override fun > computeAttribute(structure: Structure2D, attribute: A): V? { + val origin = structure.toEjml().origin + return when(attribute){ + Inverted -> { + val res = origin.copy() + CommonOps_DDRM.invert(res) + res.wrapMatrix() + } + else-> + } + } + @UnstableKMathAPI override fun computeFeature(structure: Matrix, type: KClass): F? { structure.getFeature(type)?.let { return it } @@ -305,6 +330,8 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace() + override val type: SafeType get() = safeTypeOf() @Suppress("UNCHECKED_CAST") override fun Matrix.toEjml(): EjmlFloatMatrix = when { @@ -543,6 +570,8 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace() + override val type: SafeType get() = safeTypeOf() @Suppress("UNCHECKED_CAST") override fun Matrix.toEjml(): EjmlDoubleMatrix = when { @@ -776,6 +805,8 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace() + override val type: SafeType get() = safeTypeOf() @Suppress("UNCHECKED_CAST") override fun Matrix.toEjml(): EjmlFloatMatrix = when { diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt index 7d37711cd..16b7a90f8 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt @@ -4,7 +4,7 @@ */ package space.kscience.kmath.integration -import space.kscience.attributes.TypedAttributesBuilder +import space.kscience.attributes.AttributesBuilder import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.operations.Field import space.kscience.kmath.structures.Buffer @@ -92,7 +92,7 @@ public inline fun GaussIntegrator.integrate( range: ClosedRange, order: Int = 10, intervals: Int = 10, - attributesBuilder: TypedAttributesBuilder>.() -> Unit, + attributesBuilder: AttributesBuilder>.() -> Unit, noinline function: (Double) -> T, ): UnivariateIntegrand { require(range.endInclusive > range.start) { "The range upper bound should be higher than lower bound" } diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/Integrand.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/Integrand.kt index d465e87c7..d8f102b2f 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/Integrand.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/Integrand.kt @@ -30,7 +30,7 @@ public sealed class IntegrandValue private constructor(): IntegrandAttribute< } } -public fun TypedAttributesBuilder>.value(value: T) { +public fun AttributesBuilder>.value(value: T) { IntegrandValue.forType().invoke(value) } diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/MultivariateIntegrand.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/MultivariateIntegrand.kt index 2081947ec..5c8a21f61 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/MultivariateIntegrand.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/MultivariateIntegrand.kt @@ -25,10 +25,10 @@ public fun MultivariateIntegrand.withAttribute( ): MultivariateIntegrand = withAttributes(attributes.withAttribute(attribute, value)) public fun MultivariateIntegrand.withAttributes( - block: TypedAttributesBuilder>.() -> Unit, + block: AttributesBuilder>.() -> Unit, ): MultivariateIntegrand = withAttributes(attributes.modify(block)) public inline fun MultivariateIntegrand( - attributeBuilder: TypedAttributesBuilder>.() -> Unit, + attributeBuilder: AttributesBuilder>.() -> Unit, noinline function: (Point) -> T, ): MultivariateIntegrand = MultivariateIntegrand(safeTypeOf(), Attributes(attributeBuilder), function) diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/UnivariateIntegrand.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/UnivariateIntegrand.kt index 9c39e7edc..05a765858 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/UnivariateIntegrand.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/UnivariateIntegrand.kt @@ -26,11 +26,11 @@ public fun UnivariateIntegrand.withAttribute( ): UnivariateIntegrand = withAttributes(attributes.withAttribute(attribute, value)) public fun UnivariateIntegrand.withAttributes( - block: TypedAttributesBuilder>.() -> Unit, + block: AttributesBuilder>.() -> Unit, ): UnivariateIntegrand = withAttributes(attributes.modify(block)) public inline fun UnivariateIntegrand( - attributeBuilder: TypedAttributesBuilder>.() -> Unit, + attributeBuilder: AttributesBuilder>.() -> Unit, noinline function: (Double) -> T, ): UnivariateIntegrand = UnivariateIntegrand(safeTypeOf(), Attributes(attributeBuilder), function) @@ -58,7 +58,7 @@ public class UnivariateIntegrandRanges(public val ranges: List> -public fun TypedAttributesBuilder>.integrationNodes(vararg nodes: Double) { +public fun AttributesBuilder>.integrationNodes(vararg nodes: Double) { UnivariateIntegrationNodes(Float64Buffer(nodes)) } @@ -68,7 +68,7 @@ public fun TypedAttributesBuilder>.integrationNodes(varar */ @UnstableKMathAPI public inline fun UnivariateIntegrator.integrate( - attributesBuilder: TypedAttributesBuilder>.() -> Unit, + attributesBuilder: AttributesBuilder>.() -> Unit, noinline function: (Double) -> T, ): UnivariateIntegrand = integrate(UnivariateIntegrand(attributesBuilder, function)) @@ -79,7 +79,7 @@ public inline fun UnivariateIntegrator.integrate( @UnstableKMathAPI public inline fun UnivariateIntegrator.integrate( range: ClosedRange, - attributeBuilder: TypedAttributesBuilder>.() -> Unit = {}, + attributeBuilder: AttributesBuilder>.() -> Unit = {}, noinline function: (Double) -> T, ): UnivariateIntegrand { diff --git a/kmath-jafama/build.gradle.kts b/kmath-jafama/build.gradle.kts index 5a77a97ed..0390224ba 100644 --- a/kmath-jafama/build.gradle.kts +++ b/kmath-jafama/build.gradle.kts @@ -14,7 +14,7 @@ repositories { } readme { - maturity = space.kscience.gradle.Maturity.PROTOTYPE + maturity = space.kscience.gradle.Maturity.DEPRECATED propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) feature("jafama-double", "src/main/kotlin/space/kscience/kmath/jafama/") { diff --git a/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt b/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt index f9b8287b4..1c52456f3 100644 --- a/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt +++ b/kmath-jafama/src/main/kotlin/space/kscience/kmath/jafama/KMathJafama.kt @@ -7,16 +7,17 @@ package space.kscience.kmath.jafama import net.jafama.FastMath import net.jafama.StrictFastMath -import space.kscience.kmath.operations.ExtendedField -import space.kscience.kmath.operations.Norm -import space.kscience.kmath.operations.PowerOperations -import space.kscience.kmath.operations.ScaleOperations +import space.kscience.kmath.operations.* +import space.kscience.kmath.structures.MutableBufferFactory /** * A field for [Double] (using FastMath) without boxing. Does not produce appropriate field element. */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object JafamaDoubleField : ExtendedField, Norm, ScaleOperations { + + override val bufferFactory: MutableBufferFactory get() = DoubleField.bufferFactory + override inline val zero: Double get() = 0.0 override inline val one: Double get() = 1.0 @@ -68,6 +69,9 @@ public object JafamaDoubleField : ExtendedField, Norm, S */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") public object StrictJafamaDoubleField : ExtendedField, Norm, ScaleOperations { + + override val bufferFactory: MutableBufferFactory get() = DoubleField.bufferFactory + override inline val zero: Double get() = 0.0 override inline val one: Double get() = 1.0 diff --git a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt index 8b463a230..60413f88c 100644 --- a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt @@ -20,7 +20,7 @@ public class MultikDoubleAlgebra( ) : MultikDivisionTensorAlgebra(multikEngine), TrigonometricOperations>, ExponentialOperations> { override val elementAlgebra: Float64Field get() = Float64Field - override val type: DataType get() = DataType.DoubleDataType + override val dataType: DataType get() = DataType.DoubleDataType override fun sin(arg: StructureND): MultikTensor = multikMath.mathEx.sin(arg.asMultik().array).wrap() diff --git a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt index 7a3dda94b..26331bd6b 100644 --- a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt @@ -15,7 +15,7 @@ public class MultikFloatAlgebra( multikEngine: Engine ) : MultikDivisionTensorAlgebra(multikEngine) { override val elementAlgebra: Float32Field get() = Float32Field - override val type: DataType get() = DataType.FloatDataType + override val dataType: DataType get() = DataType.FloatDataType override fun scalar(value: Float): MultikTensor = Multik.ndarrayOf(value).wrap() } diff --git a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt index 5bd1b3388..46acbdf9d 100644 --- a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt @@ -15,7 +15,7 @@ public class MultikIntAlgebra( multikEngine: Engine ) : MultikTensorAlgebra(multikEngine) { override val elementAlgebra: Int32Ring get() = Int32Ring - override val type: DataType get() = DataType.IntDataType + override val dataType: DataType get() = DataType.IntDataType override fun scalar(value: Int): MultikTensor = Multik.ndarrayOf(value).wrap() } diff --git a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt index 69a8ec042..97e86d86a 100644 --- a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt @@ -15,7 +15,7 @@ public class MultikLongAlgebra( multikEngine: Engine ) : MultikTensorAlgebra(multikEngine) { override val elementAlgebra: Int64Ring get() = Int64Ring - override val type: DataType get() = DataType.LongDataType + override val dataType: DataType get() = DataType.LongDataType override fun scalar(value: Long): MultikTensor = Multik.ndarrayOf(value).wrap() } diff --git a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt index 7c8740665..27d43f7b8 100644 --- a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt @@ -15,7 +15,7 @@ public class MultikShortAlgebra( multikEngine: Engine ) : MultikTensorAlgebra(multikEngine) { override val elementAlgebra: Int16Ring get() = Int16Ring - override val type: DataType get() = DataType.ShortDataType + override val dataType: DataType get() = DataType.ShortDataType override fun scalar(value: Short): MultikTensor = Multik.ndarrayOf(value).wrap() } diff --git a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensor.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensor.kt index 59a9a1bf3..5ed8ea767 100644 --- a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensor.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensor.kt @@ -6,13 +6,33 @@ package space.kscience.kmath.multik import org.jetbrains.kotlinx.multik.ndarray.data.* +import space.kscience.attributes.SafeType +import space.kscience.attributes.safeTypeOf import space.kscience.kmath.PerformancePitfall +import space.kscience.kmath.complex.ComplexField import space.kscience.kmath.nd.ShapeND +import space.kscience.kmath.operations.* import space.kscience.kmath.tensors.api.Tensor import kotlin.jvm.JvmInline +public val DataType.type: SafeType<*> + get() = when (this) { + DataType.ByteDataType -> ByteRing.type + DataType.ShortDataType -> ShortRing.type + DataType.IntDataType -> IntRing.type + DataType.LongDataType -> LongRing.type + DataType.FloatDataType -> Float32Field.type + DataType.DoubleDataType -> Float64Field.type + DataType.ComplexFloatDataType -> safeTypeOf>() + DataType.ComplexDoubleDataType -> ComplexField.type + } + + @JvmInline public value class MultikTensor(public val array: MutableMultiArray) : Tensor { + @Suppress("UNCHECKED_CAST") + override val type: SafeType get() = array.dtype.type as SafeType + override val shape: ShapeND get() = ShapeND(array.shape) @PerformancePitfall diff --git a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt index c5bbebfd8..468f1652d 100644 --- a/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt @@ -26,7 +26,7 @@ public abstract class MultikTensorAlgebra>( private val multikEngine: Engine, ) : TensorAlgebra where T : Number, T : Comparable { - public abstract val type: DataType + public abstract val dataType: DataType protected val multikMath: Math = multikEngine.getMath() protected val multikLinAl: LinAlg = multikEngine.getLinAlg() @@ -35,7 +35,7 @@ public abstract class MultikTensorAlgebra>( @OptIn(UnsafeKMathAPI::class) override fun mutableStructureND(shape: ShapeND, initializer: A.(IntArray) -> T): MultikTensor { val strides = ColumnStrides(shape) - val memoryView = initMemoryView(strides.linearSize, type) + val memoryView = initMemoryView(strides.linearSize, dataType) strides.asSequence().forEachIndexed { linearIndex, tensorIndex -> memoryView[linearIndex] = elementAlgebra.initializer(tensorIndex) } @@ -44,7 +44,7 @@ public abstract class MultikTensorAlgebra>( @OptIn(PerformancePitfall::class, UnsafeKMathAPI::class) override fun StructureND.map(transform: A.(T) -> T): MultikTensor = if (this is MultikTensor) { - val data = initMemoryView(array.size, type) + val data = initMemoryView(array.size, dataType) var count = 0 for (el in array) data[count++] = elementAlgebra.transform(el) NDArray(data, shape = shape.asArray(), dim = array.dim).wrap() @@ -58,7 +58,7 @@ public abstract class MultikTensorAlgebra>( override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor = if (this is MultikTensor) { val array = asMultik().array - val data = initMemoryView(array.size, type) + val data = initMemoryView(array.size, dataType) val indexIter = array.multiIndices.iterator() var index = 0 for (item in array) { @@ -95,7 +95,7 @@ public abstract class MultikTensorAlgebra>( require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException val leftArray = left.asMultik().array val rightArray = right.asMultik().array - val data = initMemoryView(leftArray.size, type) + val data = initMemoryView(leftArray.size, dataType) var counter = 0 val leftIterator = leftArray.iterator() val rightIterator = rightArray.iterator() @@ -114,7 +114,7 @@ public abstract class MultikTensorAlgebra>( public fun StructureND.asMultik(): MultikTensor = if (this is MultikTensor) { this } else { - val res = mk.zeros(shape.asArray(), type).asDNArray() + val res = mk.zeros(shape.asArray(), dataType).asDNArray() for (index in res.multiIndices) { res[index] = this[index] } @@ -296,7 +296,7 @@ public abstract class MultikDivisionTensorAlgebra>( @OptIn(UnsafeKMathAPI::class) override fun T.div(arg: StructureND): MultikTensor = - Multik.ones(arg.shape.asArray(), type).apply { divAssign(arg.asMultik().array) }.wrap() + Multik.ones(arg.shape.asArray(), dataType).apply { divAssign(arg.asMultik().array) }.wrap() override fun StructureND.div(arg: T): MultikTensor = asMultik().array.div(arg).wrap() diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt index 8025428e6..527c8abae 100644 --- a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt @@ -9,20 +9,21 @@ import space.kscience.attributes.* import space.kscience.kmath.expressions.DifferentiableExpression import space.kscience.kmath.expressions.Symbol -public class OptimizationValue(public val value: T) : OptimizationFeature { - override fun toString(): String = "Value($value)" -} +public class OptimizationValue(type: SafeType) : PolymorphicAttribute(type) -public enum class FunctionOptimizationTarget { +public enum class OptimizationDirection { MAXIMIZE, MINIMIZE } +public object FunctionOptimizationTarget: OptimizationAttribute + public class FunctionOptimization( - override val attributes: Attributes, public val expression: DifferentiableExpression, + override val attributes: Attributes, ) : OptimizationProblem { + override val type: SafeType get() = expression.type override fun equals(other: Any?): Boolean { if (this === other) return true @@ -47,36 +48,52 @@ public class FunctionOptimization( public companion object } +public fun FunctionOptimization( + expression: DifferentiableExpression, + attributeBuilder: AttributesBuilder>.() -> Unit, +): FunctionOptimization = FunctionOptimization(expression, Attributes(attributeBuilder)) - -public class OptimizationPrior(type: SafeType): +public class OptimizationPrior : PolymorphicAttribute>(safeTypeOf()), Attribute> -//public val FunctionOptimization.Companion.Optimization get() = - - -public fun FunctionOptimization.withFeatures( - vararg newFeature: OptimizationFeature, +public fun FunctionOptimization.withAttributes( + modifier: AttributesBuilder>.() -> Unit, ): FunctionOptimization = FunctionOptimization( - attributes.with(*newFeature), expression, + attributes.modify(modifier), ) /** * Optimizes differentiable expression using specific [optimizer] form given [startingPoint]. */ -public suspend fun DifferentiableExpression.optimizeWith( +public suspend fun DifferentiableExpression.optimizeWith( optimizer: Optimizer>, startingPoint: Map, - vararg features: OptimizationFeature, + modifier: AttributesBuilder>.() -> Unit = {}, ): FunctionOptimization { - val problem = FunctionOptimization(FeatureSet.of(OptimizationStartPoint(startingPoint), *features), this) + val problem = FunctionOptimization(this){ + startAt(startingPoint) + modifier() + } return optimizer.optimize(problem) } public val FunctionOptimization.resultValueOrNull: T? - get() = getFeature>()?.point?.let { expression(it) } + get() = attributes[OptimizationResult()]?.let { expression(it) } public val FunctionOptimization.resultValue: T - get() = resultValueOrNull ?: error("Result is not present in $this") \ No newline at end of file + get() = resultValueOrNull ?: error("Result is not present in $this") + + +public suspend fun DifferentiableExpression.optimizeWith( + optimizer: Optimizer>, + vararg startingPoint: Pair, + builder: AttributesBuilder>.() -> Unit = {}, +): FunctionOptimization { + val problem = FunctionOptimization(this) { + startAt(mapOf(*startingPoint)) + builder() + } + return optimizer.optimize(problem) +} diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationBuilder.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationBuilder.kt deleted file mode 100644 index 0459d46ee..000000000 --- a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationBuilder.kt +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright 2018-2022 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. - */ - -package space.kscience.kmath.optimization - -import space.kscience.kmath.UnstableKMathAPI -import space.kscience.kmath.data.XYColumnarData -import space.kscience.kmath.expressions.DifferentiableExpression -import space.kscience.kmath.expressions.Symbol -import space.kscience.kmath.misc.FeatureSet - -public abstract class OptimizationBuilder> { - public val features: MutableList = ArrayList() - - public fun addFeature(feature: OptimizationFeature) { - features.add(feature) - } - - public inline fun updateFeature(update: (T?) -> T) { - val existing = features.find { it.key == T::class } as? T - val new = update(existing) - if (existing != null) { - features.remove(existing) - } - addFeature(new) - } - - public abstract fun build(): R -} - -public fun OptimizationBuilder.startAt(startingPoint: Map) { - addFeature(OptimizationStartPoint(startingPoint)) -} - -public class FunctionOptimizationBuilder( - private val expression: DifferentiableExpression, -) : OptimizationBuilder>() { - override fun build(): FunctionOptimization = FunctionOptimization(FeatureSet.of(features), expression) -} - -public fun FunctionOptimization( - expression: DifferentiableExpression, - builder: FunctionOptimizationBuilder.() -> Unit, -): FunctionOptimization = FunctionOptimizationBuilder(expression).apply(builder).build() - -public suspend fun DifferentiableExpression.optimizeWith( - optimizer: Optimizer>, - startingPoint: Map, - builder: FunctionOptimizationBuilder.() -> Unit = {}, -): FunctionOptimization { - val problem = FunctionOptimization(this) { - startAt(startingPoint) - builder() - } - return optimizer.optimize(problem) -} - -public suspend fun DifferentiableExpression.optimizeWith( - optimizer: Optimizer>, - vararg startingPoint: Pair, - builder: FunctionOptimizationBuilder.() -> Unit = {}, -): FunctionOptimization { - val problem = FunctionOptimization(this) { - startAt(mapOf(*startingPoint)) - builder() - } - return optimizer.optimize(problem) -} - - -@OptIn(UnstableKMathAPI::class) -public class XYOptimizationBuilder( - public val data: XYColumnarData, - public val model: DifferentiableExpression, -) : OptimizationBuilder() { - - public var pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY - public var pointWeight: PointWeight = PointWeight.byYSigma - - override fun build(): XYFit = XYFit( - data, - model, - FeatureSet.of(features), - pointToCurveDistance, - pointWeight - ) -} - -@OptIn(UnstableKMathAPI::class) -public fun XYOptimization( - data: XYColumnarData, - model: DifferentiableExpression, - builder: XYOptimizationBuilder.() -> Unit, -): XYFit = XYOptimizationBuilder(data, model).apply(builder).build() \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationProblem.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationProblem.kt index 46ba8c1c0..49f3bb266 100644 --- a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationProblem.kt +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationProblem.kt @@ -6,64 +6,53 @@ package space.kscience.kmath.optimization import space.kscience.attributes.* -import space.kscience.kmath.expressions.DifferentiableExpression import space.kscience.kmath.expressions.NamedMatrix import space.kscience.kmath.expressions.Symbol -import space.kscience.kmath.misc.* -import kotlin.reflect.KClass +import space.kscience.kmath.misc.Loggable -public interface OptimizationAttribute: Attribute +public interface OptimizationAttribute : Attribute -public interface OptimizationProblem : AttributeContainer +public interface OptimizationProblem : AttributeContainer, WithType -public inline fun OptimizationProblem<*>.getFeature(): F? = getFeature(F::class) - -public open class OptimizationStartPoint(public val point: Map) : OptimizationFeature { - override fun toString(): String = "StartPoint($point)" -} - -/** - * Covariance matrix for - */ -public class OptimizationCovariance(public val covariance: NamedMatrix) : OptimizationFeature { - override fun toString(): String = "Covariance($covariance)" -} +public class OptimizationStartPoint : OptimizationAttribute>, + PolymorphicAttribute>(safeTypeOf()) /** * Get the starting point for optimization. Throws error if not defined. */ public val OptimizationProblem.startPoint: Map - get() = getFeature>()?.point - ?: error("Starting point not defined in $this") + get() = attributes[OptimizationStartPoint()] ?: error("Starting point not defined in $this") -public open class OptimizationResult(public val point: Map) : OptimizationFeature { - override fun toString(): String = "Result($point)" +public fun AttributesBuilder>.startAt(startingPoint: Map) { + set(::OptimizationStartPoint, startingPoint) } -public val OptimizationProblem.resultPointOrNull: Map? - get() = getFeature>()?.point -public val OptimizationProblem.resultPoint: Map - get() = resultPointOrNull ?: error("Result is not present in $this") +/** + * Covariance matrix for optimization + */ +public class OptimizationCovariance : OptimizationAttribute>, + PolymorphicAttribute>(safeTypeOf()) -public class OptimizationLog(private val loggable: Loggable) : Loggable by loggable, OptimizationFeature { - override fun toString(): String = "Log($loggable)" -} + +public class OptimizationResult() : OptimizationAttribute>, + PolymorphicAttribute>(safeTypeOf()) + +public val OptimizationProblem.resultOrNull: Map? get() = attributes[OptimizationResult()] + +public val OptimizationProblem.result: Map + get() = resultOrNull ?: error("Result is not present in $this") + +public object OptimizationLog : OptimizationAttribute /** * Free parameters of the optimization */ -public class OptimizationParameters(public val symbols: List) : OptimizationFeature { - public constructor(vararg symbols: Symbol) : this(listOf(*symbols)) - - override fun toString(): String = "Parameters($symbols)" -} +public object OptimizationParameters : OptimizationAttribute> /** * Maximum allowed number of iterations */ -public class OptimizationIterations(public val maxIterations: Int) : OptimizationFeature { - override fun toString(): String = "Iterations($maxIterations)" -} +public object OptimizationIterations : OptimizationAttribute diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/QowOptimizer.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/QowOptimizer.kt index e922fd423..9b715f95d 100644 --- a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/QowOptimizer.kt +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/QowOptimizer.kt @@ -16,13 +16,7 @@ import space.kscience.kmath.structures.Float64Buffer import kotlin.math.abs -public class QowRuns(public val runs: Int) : OptimizationFeature { - init { - require(runs >= 1) { "Number of runs must be more than zero" } - } - - override fun toString(): String = "QowRuns(runs=$runs)" -} +public object QowRuns: OptimizationAttribute /** @@ -69,7 +63,7 @@ public object QowOptimizer : Optimizer { } val prior: DifferentiableExpression? - get() = problem.getFeature>()?.withDefaultArgs(allParameters) + get() = problem.attributes[OptimizationPrior()]?.withDefaultArgs(allParameters) override fun toString(): String = freeParameters.toString() } @@ -176,7 +170,7 @@ public object QowOptimizer : Optimizer { fast: Boolean = false, ): QoWeight { - val logger = problem.getFeature() + val logger = problem.attributes[OptimizationLog] var dis: Double //discrepancy value @@ -231,7 +225,7 @@ public object QowOptimizer : Optimizer { } private fun QoWeight.covariance(): NamedMatrix { - val logger = problem.getFeature() + val logger = problem.attributes[OptimizationLog] logger?.log { """ @@ -257,11 +251,11 @@ public object QowOptimizer : Optimizer { } override suspend fun optimize(problem: XYFit): XYFit { - val qowRuns = problem.getFeature()?.runs ?: 2 - val iterations = problem.getFeature()?.maxIterations ?: 50 + val qowRuns = problem.attributes[QowRuns] ?: 2 + val iterations = problem.attributes[OptimizationIterations] ?: 50 - val freeParameters: Map = problem.getFeature()?.let { op -> - problem.startPoint.filterKeys { it in op.symbols } + val freeParameters: Map = problem.attributes[OptimizationParameters]?.let { symbols -> + problem.startPoint.filterKeys { it in symbols } } ?: problem.startPoint var qow = QoWeight(problem, freeParameters) @@ -271,6 +265,9 @@ public object QowOptimizer : Optimizer { res = qow.newtonianRun(maxSteps = iterations) } val covariance = res.covariance() - return res.problem.withFeature(OptimizationResult(res.freeParameters), OptimizationCovariance(covariance)) + return res.problem.withAttributes { + set(OptimizationResult(), res.freeParameters) + set(OptimizationCovariance(), covariance) + } } } \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt index cda37af4f..59143c338 100644 --- a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt @@ -2,37 +2,41 @@ * Copyright 2018-2022 KMath contributors. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. */ -@file:OptIn(UnstableKMathAPI::class) +@file:OptIn(UnstableKMathAPI::class, UnstableKMathAPI::class) package space.kscience.kmath.optimization +import space.kscience.attributes.* import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.data.XYColumnarData import space.kscience.kmath.data.indices import space.kscience.kmath.expressions.* -import space.kscience.kmath.misc.FeatureSet import space.kscience.kmath.misc.Loggable +import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.ExtendedField +import space.kscience.kmath.operations.Float64Field import space.kscience.kmath.operations.bindSymbol import kotlin.math.pow /** * Specify the way to compute distance from point to the curve as DifferentiableExpression */ -public interface PointToCurveDistance : OptimizationFeature { +public interface PointToCurveDistance { public fun distance(problem: XYFit, index: Int): DifferentiableExpression - public companion object { + public companion object : OptimizationAttribute { public val byY: PointToCurveDistance = object : PointToCurveDistance { override fun distance(problem: XYFit, index: Int): DifferentiableExpression { val x = problem.data.x[index] val y = problem.data.y[index] return object : DifferentiableExpression { + override val type: SafeType get() = DoubleField.type + override fun derivativeOrNull( symbols: List, ): Expression? = problem.model.derivativeOrNull(symbols)?.let { derivExpression -> - Expression { arguments -> + Expression(DoubleField.type) { arguments -> derivExpression.invoke(arguments + (Symbol.x to x)) } } @@ -51,18 +55,21 @@ public interface PointToCurveDistance : OptimizationFeature { * Compute a wight of the point. The more the weight, the more impact this point will have on the fit. * By default, uses Dispersion^-1 */ -public interface PointWeight : OptimizationFeature { +public interface PointWeight { public fun weight(problem: XYFit, index: Int): DifferentiableExpression - public companion object { + public companion object : OptimizationAttribute { public fun bySigma(sigmaSymbol: Symbol): PointWeight = object : PointWeight { override fun weight(problem: XYFit, index: Int): DifferentiableExpression = object : DifferentiableExpression { + override val type: SafeType get() = DoubleField.type + override fun invoke(arguments: Map): Double { return problem.data[sigmaSymbol]?.get(index)?.pow(-2) ?: 1.0 } - override fun derivativeOrNull(symbols: List): Expression = Expression { 0.0 } + override fun derivativeOrNull(symbols: List): Expression = + Expression(DoubleField.type) { 0.0 } } override fun toString(): String = "PointWeightBySigma($sigmaSymbol)" @@ -79,41 +86,52 @@ public interface PointWeight : OptimizationFeature { public class XYFit( public val data: XYColumnarData, public val model: DifferentiableExpression, - override val attributes: FeatureSet, + override val attributes: Attributes, internal val pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY, internal val pointWeight: PointWeight = PointWeight.byYSigma, public val xSymbol: Symbol = Symbol.x, ) : OptimizationProblem { + + override val type: SafeType get() = Float64Field.type + public fun distance(index: Int): DifferentiableExpression = pointToCurveDistance.distance(this, index) public fun weight(index: Int): DifferentiableExpression = pointWeight.weight(this, index) } -public fun XYFit.withFeature(vararg features: OptimizationFeature): XYFit { - return XYFit(data, model, this.attributes.with(*features), pointToCurveDistance, pointWeight) -} + +public fun XYOptimization( + data: XYColumnarData, + model: DifferentiableExpression, + builder: AttributesBuilder.() -> Unit, +): XYFit = XYFit(data, model, Attributes(builder)) + +public fun XYFit.withAttributes( + modifier: AttributesBuilder.() -> Unit, +): XYFit = XYFit(data, model, attributes.modify(modifier), pointToCurveDistance, pointWeight, xSymbol) public suspend fun XYColumnarData.fitWith( optimizer: Optimizer, modelExpression: DifferentiableExpression, startingPoint: Map, - vararg features: OptimizationFeature = emptyArray(), + attributes: Attributes = Attributes.EMPTY, xSymbol: Symbol = Symbol.x, pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY, pointWeight: PointWeight = PointWeight.byYSigma, ): XYFit { - var actualFeatures = FeatureSet.of(*features, OptimizationStartPoint(startingPoint)) - if (actualFeatures.getFeature() == null) { - actualFeatures = actualFeatures.with(OptimizationLog(Loggable.console)) - } val problem = XYFit( this, modelExpression, - actualFeatures, + attributes.modify { + set(::OptimizationStartPoint, startingPoint) + if (!hasAny()) { + set(OptimizationLog, Loggable.console) + } + }, pointToCurveDistance, pointWeight, - xSymbol + xSymbol, ) return optimizer.optimize(problem) } @@ -125,7 +143,7 @@ public suspend fun XYColumnarData.fitWith( optimizer: Optimizer, processor: AutoDiffProcessor, startingPoint: Map, - vararg features: OptimizationFeature = emptyArray(), + attributes: Attributes = Attributes.EMPTY, xSymbol: Symbol = Symbol.x, pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY, pointWeight: PointWeight = PointWeight.byYSigma, @@ -140,7 +158,7 @@ public suspend fun XYColumnarData.fitWith( optimizer = optimizer, modelExpression = modelExpression, startingPoint = startingPoint, - features = features, + attributes = attributes, xSymbol = xSymbol, pointToCurveDistance = pointToCurveDistance, pointWeight = pointWeight @@ -152,7 +170,7 @@ public suspend fun XYColumnarData.fitWith( */ public val XYFit.chiSquaredOrNull: Double? get() { - val result = startPoint + (resultPointOrNull ?: return null) + val result = startPoint + (resultOrNull ?: return null) return data.indices.sumOf { index -> @@ -167,4 +185,4 @@ public val XYFit.chiSquaredOrNull: Double? } public val XYFit.dof: Int - get() = data.size - (getFeature()?.symbols?.size ?: startPoint.size) \ No newline at end of file + get() = data.size - (attributes[OptimizationParameters]?.size ?: startPoint.size) \ No newline at end of file diff --git a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/logLikelihood.kt b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/logLikelihood.kt index 40081ed81..aa77cdda6 100644 --- a/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/logLikelihood.kt +++ b/kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/logLikelihood.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.optimization +import space.kscience.attributes.AttributesBuilder +import space.kscience.attributes.SafeType import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.data.XYColumnarData import space.kscience.kmath.data.indices @@ -12,6 +14,7 @@ import space.kscience.kmath.expressions.DifferentiableExpression import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.derivative +import space.kscience.kmath.operations.Float64Field import kotlin.math.PI import kotlin.math.ln import kotlin.math.pow @@ -22,7 +25,9 @@ private val oneOver2Pi = 1.0 / sqrt(2 * PI) @UnstableKMathAPI internal fun XYFit.logLikelihood(): DifferentiableExpression = object : DifferentiableExpression { - override fun derivativeOrNull(symbols: List): Expression = Expression { arguments -> + override val type: SafeType get() = Float64Field.type + + override fun derivativeOrNull(symbols: List): Expression = Expression(type) { arguments -> data.indices.sumOf { index -> val d = distance(index)(arguments) val weight = weight(index)(arguments) @@ -53,14 +58,18 @@ internal fun XYFit.logLikelihood(): DifferentiableExpression = object : */ @UnstableKMathAPI public suspend fun Optimizer>.maximumLogLikelihood(problem: XYFit): XYFit { - val functionOptimization = FunctionOptimization(problem.attributes, problem.logLikelihood()) - val result = optimize(functionOptimization.withFeatures(FunctionOptimizationTarget.MAXIMIZE)) - return XYFit(problem.data, problem.model, result.attributes) + val functionOptimization = FunctionOptimization(problem.logLikelihood(), problem.attributes) + val result = optimize( + functionOptimization.withAttributes { + FunctionOptimizationTarget(OptimizationDirection.MAXIMIZE) + } + ) + return XYFit(problem.data,problem.model, result.attributes) } @UnstableKMathAPI public suspend fun Optimizer>.maximumLogLikelihood( data: XYColumnarData, model: DifferentiableExpression, - builder: XYOptimizationBuilder.() -> Unit, + builder: AttributesBuilder.() -> Unit, ): XYFit = maximumLogLikelihood(XYOptimization(data, model, builder)) diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensor.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensor.kt index 8a97114c3..d960fe210 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensor.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensor.kt @@ -5,10 +5,12 @@ package space.kscience.kmath.tensors.core +import space.kscience.attributes.SafeType import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.nd.MutableStructureNDOfDouble import space.kscience.kmath.nd.ShapeND +import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.structures.* import space.kscience.kmath.tensors.core.internal.toPrettyString @@ -88,6 +90,8 @@ public open class DoubleTensor( final override val source: OffsetDoubleBuffer, ) : BufferedTensor(shape), MutableStructureNDOfDouble { + override val type: SafeType get() = DoubleField.type + init { require(linearSize == source.size) { "Source buffer size must be equal tensor size" } } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensor.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensor.kt index 1066793a7..005bf5d2a 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensor.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensor.kt @@ -5,8 +5,10 @@ package space.kscience.kmath.tensors.core +import space.kscience.attributes.SafeType import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.nd.ShapeND +import space.kscience.kmath.operations.IntRing import space.kscience.kmath.structures.* /** @@ -24,6 +26,8 @@ public class OffsetIntBuffer( require(offset + size <= source.size) { "Maximum index must be inside source dimension" } } + override val type: SafeType get() = IntRing.type + override fun set(index: Int, value: Int) { require(index in 0 until size) { "Index must be in [0, size)" } source[index + offset] = value @@ -83,6 +87,8 @@ public class IntTensor( require(linearSize == source.size) { "Source buffer size must be equal tensor size" } } + override val type: SafeType get() = IntRing.type + public constructor(shape: ShapeND, buffer: Int32Buffer) : this(shape, OffsetIntBuffer(buffer, 0, buffer.size)) @OptIn(PerformancePitfall::class)