This commit is contained in:
Alexander Nozik 2023-11-18 22:29:59 +03:00
parent 2f2f552648
commit 5c82a5e1fa
54 changed files with 541 additions and 433 deletions

View File

@ -3,7 +3,7 @@
## Unreleased ## Unreleased
### Added ### Added
- Explicit `SafeType` for algebras and buffers. - Reification. Explicit `SafeType` for algebras and buffers.
- Integer division algebras. - Integer division algebras.
- Float32 geometries. - Float32 geometries.
- New Attributes-kt module that could be used as stand-alone. It declares. type-safe attributes containers. - New Attributes-kt module that could be used as stand-alone. It declares. type-safe attributes containers.
@ -16,6 +16,7 @@
- kmath-geometry is split into `euclidean2d` and `euclidean3d` - kmath-geometry is split into `euclidean2d` and `euclidean3d`
- Features replaced with Attributes. - Features replaced with Attributes.
- Transposed refactored. - Transposed refactored.
- Kmath-memory is moved on top of core.
### Deprecated ### Deprecated

View File

@ -24,14 +24,3 @@ public interface AttributeWithDefault<T> : Attribute<T> {
*/ */
public interface SetAttribute<V> : Attribute<Set<V>> public interface SetAttribute<V> : Attribute<Set<V>>
/**
* An attribute that has a type parameter for value
* @param type parameter-type
*/
public abstract class PolymorphicAttribute<T>(public val type: SafeType<T>) : Attribute<T> {
override fun equals(other: Any?): Boolean = other != null &&
(this::class == other::class) &&
(other as? PolymorphicAttribute<*>)?.type == this.type
override fun hashCode(): Int = this::class.hashCode() + type.hashCode()
}

View File

@ -6,8 +6,14 @@
package space.kscience.attributes package space.kscience.attributes
/** /**
* A container for attributes. [attributes] could be made mutable by implementation * A container for [Attributes]
*/ */
public interface AttributeContainer { public interface AttributeContainer {
public val attributes: Attributes public val attributes: Attributes
} }
/**
* A scope, where attribute keys could be resolved
*/
public interface AttributeScope<O>

View File

@ -7,21 +7,27 @@ package space.kscience.attributes
import kotlin.jvm.JvmInline import kotlin.jvm.JvmInline
@JvmInline /**
public value class Attributes internal constructor(public val content: Map<out Attribute<*>, Any?>) { * A set of attributes. The implementation must guarantee that [content] keys correspond to its value types.
*/
public interface Attributes {
public val content: Map<out Attribute<*>, Any?>
public val keys: Set<Attribute<*>> get() = content.keys public val keys: Set<Attribute<*>> get() = content.keys
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public operator fun <T> get(attribute: Attribute<T>): T? = content[attribute] as? T public operator fun <T> get(attribute: Attribute<T>): T? = content[attribute] as? T
override fun toString(): String = "Attributes(value=${content.entries})"
public companion object { public companion object {
public val EMPTY: Attributes = Attributes(emptyMap()) public val EMPTY: Attributes = AttributesImpl(emptyMap())
} }
} }
@JvmInline
internal value class AttributesImpl(override val content: Map<out Attribute<*>, Any?>) : Attributes {
override fun toString(): String = "Attributes(value=${content.entries})"
}
public fun Attributes.isEmpty(): Boolean = content.isEmpty() public fun Attributes.isEmpty(): Boolean = content.isEmpty()
/** /**
@ -33,19 +39,19 @@ public fun <T> Attributes.getOrDefault(attribute: AttributeWithDefault<T>): T =
* Check if there is an attribute that matches given key by type and adheres to [predicate]. * Check if there is an attribute that matches given key by type and adheres to [predicate].
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public inline fun <T, reified A : Attribute<T>> Attributes.any(predicate: (value: T) -> Boolean): Boolean = public inline fun <T, reified A : Attribute<T>> Attributes.hasAny(predicate: (value: T) -> Boolean): Boolean =
content.any { (mapKey, mapValue) -> mapKey is A && predicate(mapValue as T) } content.any { (mapKey, mapValue) -> mapKey is A && predicate(mapValue as T) }
/** /**
* Check if there is an attribute of given type (subtypes included) * Check if there is an attribute of given type (subtypes included)
*/ */
public inline fun <T, reified A : Attribute<T>> Attributes.any(): Boolean = public inline fun <reified A : Attribute<*>> Attributes.hasAny(): Boolean =
content.any { (mapKey, _) -> mapKey is A } content.any { (mapKey, _) -> mapKey is A }
/** /**
* Check if [Attributes] contains a flag. Multiple keys that are instances of a flag could be present * Check if [Attributes] contains a flag. Multiple keys that are instances of a flag could be present
*/ */
public inline fun <reified A : FlagAttribute> Attributes.has(): Boolean = public inline fun <reified A : FlagAttribute> Attributes.hasFlag(): Boolean =
content.keys.any { it is A } content.keys.any { it is A }
/** /**
@ -54,7 +60,7 @@ public inline fun <reified A : FlagAttribute> Attributes.has(): Boolean =
public fun <T, A : Attribute<T>> Attributes.withAttribute( public fun <T, A : Attribute<T>> Attributes.withAttribute(
attribute: A, attribute: A,
attrValue: T, attrValue: T,
): Attributes = Attributes(content + (attribute to attrValue)) ): Attributes = AttributesImpl(content + (attribute to attrValue))
public fun <A : Attribute<Unit>> Attributes.withAttribute(attribute: A): Attributes = public fun <A : Attribute<Unit>> Attributes.withAttribute(attribute: A): Attributes =
withAttribute(attribute, Unit) withAttribute(attribute, Unit)
@ -62,7 +68,7 @@ public fun <A : Attribute<Unit>> Attributes.withAttribute(attribute: A): Attribu
/** /**
* Create a new [Attributes] by modifying the current one * Create a new [Attributes] by modifying the current one
*/ */
public fun Attributes.modify(block: AttributesBuilder.() -> Unit): Attributes = Attributes { public fun <T> Attributes.modify(block: AttributesBuilder<T>.() -> Unit): Attributes = Attributes<T> {
from(this@modify) from(this@modify)
block() block()
} }
@ -70,7 +76,7 @@ public fun Attributes.modify(block: AttributesBuilder.() -> Unit): Attributes =
/** /**
* Create new [Attributes] by removing [attribute] key * Create new [Attributes] by removing [attribute] key
*/ */
public fun Attributes.withoutAttribute(attribute: Attribute<*>): Attributes = Attributes(content.minus(attribute)) public fun Attributes.withoutAttribute(attribute: Attribute<*>): Attributes = AttributesImpl(content.minus(attribute))
/** /**
* Add an element to a [SetAttribute] * Add an element to a [SetAttribute]
@ -80,7 +86,7 @@ public fun <T, A : SetAttribute<T>> Attributes.withAttributeElement(
attrValue: T, attrValue: T,
): Attributes { ): Attributes {
val currentSet: Set<T> = get(attribute) ?: emptySet() val currentSet: Set<T> = get(attribute) ?: emptySet()
return Attributes( return AttributesImpl(
content + (attribute to (currentSet + attrValue)) content + (attribute to (currentSet + attrValue))
) )
} }
@ -93,9 +99,7 @@ public fun <T, A : SetAttribute<T>> Attributes.withoutAttributeElement(
attrValue: T, attrValue: T,
): Attributes { ): Attributes {
val currentSet: Set<T> = get(attribute) ?: emptySet() val currentSet: Set<T> = get(attribute) ?: emptySet()
return Attributes( return AttributesImpl(content + (attribute to (currentSet - attrValue)))
content + (attribute to (currentSet - attrValue))
)
} }
/** /**
@ -104,13 +108,13 @@ public fun <T, A : SetAttribute<T>> Attributes.withoutAttributeElement(
public fun <T, A : Attribute<T>> Attributes( public fun <T, A : Attribute<T>> Attributes(
attribute: A, attribute: A,
attrValue: T, attrValue: T,
): Attributes = Attributes(mapOf(attribute to attrValue)) ): Attributes = AttributesImpl(mapOf(attribute to attrValue))
/** /**
* Create Attributes with a single [Unit] valued attribute * Create Attributes with a single [Unit] valued attribute
*/ */
public fun <A : Attribute<Unit>> Attributes( public fun <A : Attribute<Unit>> Attributes(
attribute: A attribute: A,
): Attributes = Attributes(mapOf(attribute to Unit)) ): Attributes = AttributesImpl(mapOf(attribute to Unit))
public operator fun Attributes.plus(other: Attributes): Attributes = Attributes(content + other.content) public operator fun Attributes.plus(other: Attributes): Attributes = AttributesImpl(content + other.content)

View File

@ -10,19 +10,24 @@ package space.kscience.attributes
* *
* @param O type marker of an owner object, for which these attributes are made * @param O type marker of an owner object, for which these attributes are made
*/ */
public class TypedAttributesBuilder<in O> internal constructor(private val map: MutableMap<Attribute<*>, Any?>) { public class AttributesBuilder<out O> internal constructor(
private val map: MutableMap<Attribute<*>, Any?>,
) : Attributes {
public constructor() : this(mutableMapOf()) public constructor() : this(mutableMapOf())
@Suppress("UNCHECKED_CAST") override val content: Map<out Attribute<*>, Any?> get() = map
public operator fun <T> get(attribute: Attribute<T>): T? = map[attribute] as? T
public operator fun <T> set(attribute: Attribute<T>, value: T?) {
if (value == null) {
map.remove(attribute)
} else {
map[attribute] = value
}
}
public operator fun <V> Attribute<V>.invoke(value: V?) { public operator fun <V> Attribute<V>.invoke(value: V?) {
if (value == null) { set(this, value)
map.remove(this)
} else {
map[this] = value
}
} }
public fun from(attributes: Attributes) { public fun from(attributes: Attributes) {
@ -46,14 +51,8 @@ public class TypedAttributesBuilder<in O> internal constructor(private val map:
map[this] = currentSet - attrValue map[this] = currentSet - attrValue
} }
public fun build(): Attributes = Attributes(map) public fun build(): Attributes = AttributesImpl(map)
} }
public typealias AttributesBuilder = TypedAttributesBuilder<Any?> public inline fun <O> Attributes(builder: AttributesBuilder<O>.() -> Unit): Attributes =
AttributesBuilder<O>().apply(builder).build()
public fun AttributesBuilder(
attributes: Attributes,
): AttributesBuilder = AttributesBuilder(attributes.content.toMutableMap())
public inline fun Attributes(builder: AttributesBuilder.() -> Unit): Attributes =
AttributesBuilder().apply(builder).build()

View File

@ -0,0 +1,31 @@
/*
* Copyright 2018-2023 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.attributes
/**
* An attribute that has a type parameter for value
* @param type parameter-type
*/
public abstract class PolymorphicAttribute<T>(public val type: SafeType<T>) : Attribute<T> {
override fun equals(other: Any?): Boolean = other != null &&
(this::class == other::class) &&
(other as? PolymorphicAttribute<*>)?.type == this.type
override fun hashCode(): Int = this::class.hashCode() + type.hashCode()
}
/**
* Get a polymorphic attribute using attribute factory
*/
public operator fun <T> Attributes.get(attributeKeyBuilder: () -> PolymorphicAttribute<T>): T? = get(attributeKeyBuilder())
/**
* Set a polymorphic attribute using its factory
*/
public operator fun <O, T> AttributesBuilder<O>.set(attributeKeyBuilder: () -> PolymorphicAttribute<T>, value: T) {
set(attributeKeyBuilder(), value)
}

View File

@ -19,6 +19,8 @@ public class Ejml${type}Vector<out M : $ejmlMatrixType>(override val origin: M)
require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" }
} }
override val type: SafeType<${type}> get() = safeTypeOf()
override operator fun get(index: Int): $type = origin[0, index] override operator fun get(index: Int): $type = origin[0, index]
}""" }"""
appendLine(text) appendLine(text)
@ -30,6 +32,8 @@ private fun Appendable.appendEjmlMatrix(type: String, ejmlMatrixType: String) {
* [EjmlMatrix] specialization for [$type]. * [EjmlMatrix] specialization for [$type].
*/ */
public class Ejml${type}Matrix<out M : $ejmlMatrixType>(override val origin: M) : EjmlMatrix<$type, M>(origin) { public class Ejml${type}Matrix<out M : $ejmlMatrixType>(override val origin: M) : EjmlMatrix<$type, M>(origin) {
override val type: SafeType<${type}> get() = safeTypeOf()
override operator fun get(i: Int, j: Int): $type = origin[i, j] override operator fun get(i: Int, j: Int): $type = origin[i, j]
}""" }"""
appendLine(text) appendLine(text)
@ -46,7 +50,9 @@ private fun Appendable.appendEjmlLinearSpace(
denseOps: String, denseOps: String,
isDense: Boolean, isDense: Boolean,
) { ) {
@Language("kotlin") val text = """/** @Language("kotlin") val text = """
/**
* [EjmlLinearSpace] implementation based on [CommonOps_$ops], [DecompositionFactory_${ops}] operations and * [EjmlLinearSpace] implementation based on [CommonOps_$ops], [DecompositionFactory_${ops}] operations and
* [${ejmlMatrixType}] matrices. * [${ejmlMatrixType}] matrices.
*/ */
@ -56,7 +62,7 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra},
*/ */
override val elementAlgebra: $kmathAlgebra get() = $kmathAlgebra override val elementAlgebra: $kmathAlgebra get() = $kmathAlgebra
override val elementType: KType get() = typeOf<$type>() override val type: SafeType<${type}> get() = safeTypeOf()
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun Matrix<${type}>.toEjml(): Ejml${type}Matrix<${ejmlMatrixType}> = when { override fun Matrix<${type}>.toEjml(): Ejml${type}Matrix<${ejmlMatrixType}> = when {
@ -385,6 +391,8 @@ import org.ejml.sparse.csc.factory.DecompositionFactory_DSCC
import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC
import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC
import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.linear.* import space.kscience.kmath.linear.*
import space.kscience.kmath.linear.Matrix import space.kscience.kmath.linear.Matrix
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI

View File

@ -15,7 +15,7 @@ import space.kscience.kmath.operations.asIterable
import space.kscience.kmath.operations.toList import space.kscience.kmath.operations.toList
import space.kscience.kmath.optimization.FunctionOptimizationTarget import space.kscience.kmath.optimization.FunctionOptimizationTarget
import space.kscience.kmath.optimization.optimizeWith import space.kscience.kmath.optimization.optimizeWith
import space.kscience.kmath.optimization.resultPoint import space.kscience.kmath.optimization.result
import space.kscience.kmath.optimization.resultValue import space.kscience.kmath.optimization.resultValue
import space.kscience.kmath.random.RandomGenerator import space.kscience.kmath.random.RandomGenerator
import space.kscience.kmath.real.DoubleVector import space.kscience.kmath.real.DoubleVector
@ -98,7 +98,7 @@ suspend fun main() {
scatter { scatter {
mode = ScatterMode.lines mode = ScatterMode.lines
x(x) x(x)
y(x.map { result.resultPoint[a]!! * it.pow(2) + result.resultPoint[b]!! * it + 1 }) y(x.map { result.result[a]!! * it.pow(2) + result.result[b]!! * it + 1 })
name = "fit" name = "fit"
} }
} }

View File

@ -94,13 +94,13 @@ suspend fun main() {
scatter { scatter {
mode = ScatterMode.lines mode = ScatterMode.lines
x(x) x(x)
y(x.map { result.model(result.startPoint + result.resultPoint + (Symbol.x to it)) }) y(x.map { result.model(result.startPoint + result.result + (Symbol.x to it)) })
name = "fit" name = "fit"
} }
} }
br() br()
h3 { h3 {
+"Fit result: ${result.resultPoint}" +"Fit result: ${result.result}"
} }
h3 { h3 {
+"Chi2/dof = ${result.chiSquaredOrNull!! / result.dof}" +"Chi2/dof = ${result.chiSquaredOrNull!! / result.dof}"

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.ast package space.kscience.kmath.ast
import space.kscience.attributes.SafeType
import space.kscience.attributes.WithType
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
@ -15,7 +17,7 @@ import space.kscience.kmath.operations.NumericAlgebra
* *
* @param T the type. * @param T the type.
*/ */
public sealed interface TypedMst<T> { public sealed interface TypedMst<T> : WithType<T> {
/** /**
* A node containing a unary operation. * A node containing a unary operation.
* *
@ -24,8 +26,13 @@ public sealed interface TypedMst<T> {
* @property function The function implementing this operation. * @property function The function implementing this operation.
* @property value The argument of this operation. * @property value The argument of this operation.
*/ */
public class Unary<T>(public val operation: String, public val function: (T) -> T, public val value: TypedMst<T>) : public class Unary<T>(
TypedMst<T> { public val operation: String,
public val function: (T) -> T,
public val value: TypedMst<T>,
) : TypedMst<T> {
override val type: SafeType<T> get() = value.type
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
if (other == null || this::class != other::class) return false if (other == null || this::class != other::class) return false
@ -59,6 +66,13 @@ public sealed interface TypedMst<T> {
public val left: TypedMst<T>, public val left: TypedMst<T>,
public val right: TypedMst<T>, public val right: TypedMst<T>,
) : TypedMst<T> { ) : TypedMst<T> {
init {
require(left.type==right.type){"Left and right expressions must be of the same type"}
}
override val type: SafeType<T> get() = left.type
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
if (other == null || this::class != other::class) return false if (other == null || this::class != other::class) return false
@ -89,7 +103,12 @@ public sealed interface TypedMst<T> {
* @property value The held value. * @property value The held value.
* @property number The number this value corresponds. * @property number The number this value corresponds.
*/ */
public class Constant<T>(public val value: T, public val number: Number?) : TypedMst<T> { public class Constant<T>(
override val type: SafeType<T>,
public val value: T,
public val number: Number?,
) : TypedMst<T> {
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
if (other == null || this::class != other::class) return false if (other == null || this::class != other::class) return false
@ -114,7 +133,7 @@ public sealed interface TypedMst<T> {
* @param T the type. * @param T the type.
* @property symbol The symbol of the variable. * @property symbol The symbol of the variable.
*/ */
public class Variable<T>(public val symbol: Symbol) : TypedMst<T> { public class Variable<T>(override val type: SafeType<T>, public val symbol: Symbol) : TypedMst<T> {
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
if (other == null || this::class != other::class) return false if (other == null || this::class != other::class) return false
@ -167,6 +186,7 @@ public fun <T> TypedMst<T>.interpret(algebra: Algebra<T>, vararg arguments: Pair
/** /**
* Interpret this [TypedMst] node as expression. * Interpret this [TypedMst] node as expression.
*/ */
public fun <T : Any> TypedMst<T>.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments -> public fun <T : Any> TypedMst<T>.toExpression(algebra: Algebra<T>): Expression<T> =
Expression(algebra.type) { arguments ->
interpret(algebra, arguments) interpret(algebra, arguments)
} }

View File

@ -16,6 +16,7 @@ import space.kscience.kmath.operations.bindSymbolOrNull
*/ */
public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (this) { public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (this) {
is MST.Numeric -> TypedMst.Constant( is MST.Numeric -> TypedMst.Constant(
algebra.type,
(algebra as? NumericAlgebra<T>)?.number(value) ?: error("Numeric nodes are not supported by $algebra"), (algebra as? NumericAlgebra<T>)?.number(value) ?: error("Numeric nodes are not supported by $algebra"),
value, value,
) )
@ -27,7 +28,7 @@ public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (t
arg.value, arg.value,
) )
TypedMst.Constant(value, if (value is Number) value else null) TypedMst.Constant(algebra.type, value, if (value is Number) value else null)
} }
else -> TypedMst.Unary(operation, algebra.unaryOperationFunction(operation), arg) else -> TypedMst.Unary(operation, algebra.unaryOperationFunction(operation), arg)
@ -59,7 +60,7 @@ public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (t
) )
} }
TypedMst.Constant(value, if (value is Number) value else null) TypedMst.Constant(algebra.type, value, if (value is Number) value else null)
} }
algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> TypedMst.Binary( algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> TypedMst.Binary(
@ -84,8 +85,8 @@ public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (t
val boundSymbol = algebra.bindSymbolOrNull(this) val boundSymbol = algebra.bindSymbolOrNull(this)
if (boundSymbol != null) if (boundSymbol != null)
TypedMst.Constant(boundSymbol, if (boundSymbol is Number) boundSymbol else null) TypedMst.Constant(algebra.type, boundSymbol, if (boundSymbol is Number) boundSymbol else null)
else else
TypedMst.Variable(this) TypedMst.Variable(algebra.type, this)
} }
} }

View File

@ -22,7 +22,7 @@ import space.kscience.kmath.operations.Algebra
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> { public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> {
val typed = evaluateConstants(algebra) val typed = evaluateConstants(algebra)
if (typed is TypedMst.Constant<T>) return Expression { typed.value } if (typed is TypedMst.Constant<T>) return Expression(algebra.type) { typed.value }
fun ESTreeBuilder<T>.visit(node: TypedMst<T>): BaseExpression = when (node) { fun ESTreeBuilder<T>.visit(node: TypedMst<T>): BaseExpression = when (node) {
is TypedMst.Constant -> constant(node.value) is TypedMst.Constant -> constant(node.value)
@ -36,7 +36,7 @@ public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T>
) )
} }
return ESTreeBuilder { visit(typed) }.instance return ESTreeBuilder(algebra.type) { visit(typed) }.instance
} }
/** /**

View File

@ -5,13 +5,22 @@
package space.kscience.kmath.estree.internal package space.kscience.kmath.estree.internal
import space.kscience.attributes.SafeType
import space.kscience.attributes.WithType
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.internal.astring.generate import space.kscience.kmath.internal.astring.generate
import space.kscience.kmath.internal.estree.* import space.kscience.kmath.internal.estree.*
internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExpression) { internal class ESTreeBuilder<T>(
private class GeneratedExpression<T>(val executable: dynamic, val constants: Array<dynamic>) : Expression<T> { override val type: SafeType<T>,
val bodyCallback: ESTreeBuilder<T>.() -> BaseExpression,
) : WithType<T> {
private class GeneratedExpression<T>(
override val type: SafeType<T>,
val executable: dynamic,
val constants: Array<dynamic>,
) : Expression<T> {
@Suppress("UNUSED_VARIABLE") @Suppress("UNUSED_VARIABLE")
override fun invoke(arguments: Map<Symbol, T>): T { override fun invoke(arguments: Map<Symbol, T>): T {
val e = executable val e = executable
@ -30,7 +39,7 @@ internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExp
) )
val code = generate(node) val code = generate(node)
GeneratedExpression(js("new Function('constants', 'arguments_0', code)"), constants.toTypedArray()) GeneratedExpression(type, js("new Function('constants', 'arguments_0', code)"), constants.toTypedArray())
} }
private val constants = mutableListOf<Any>() private val constants = mutableListOf<Any>()

View File

@ -29,7 +29,7 @@ import space.kscience.kmath.operations.Int64Ring
@PublishedApi @PublishedApi
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> { internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
val typed = evaluateConstants(algebra) val typed = evaluateConstants(algebra)
if (typed is TypedMst.Constant<T>) return Expression { typed.value } if (typed is TypedMst.Constant<T>) return Expression(algebra.type) { typed.value }
fun GenericAsmBuilder<T>.variablesVisitor(node: TypedMst<T>): Unit = when (node) { fun GenericAsmBuilder<T>.variablesVisitor(node: TypedMst<T>): Unit = when (node) {
is TypedMst.Unary -> variablesVisitor(node.value) is TypedMst.Unary -> variablesVisitor(node.value)

View File

@ -8,10 +8,13 @@
package space.kscience.kmath.commons.expressions package space.kscience.kmath.commons.expressions
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import space.kscience.attributes.SafeType
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedField import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOps import space.kscience.kmath.operations.NumbersAddOps
import space.kscience.kmath.structures.MutableBufferFactory
/** /**
* A field over commons-math [DerivativeStructure]. * A field over commons-math [DerivativeStructure].
@ -26,6 +29,9 @@ public class CmDsField(
bindings: Map<Symbol, Double>, bindings: Map<Symbol, Double>,
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>, ) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
NumbersAddOps<DerivativeStructure> { NumbersAddOps<DerivativeStructure> {
override val bufferFactory: MutableBufferFactory<DerivativeStructure> = MutableBufferFactory()
public val numberOfVariables: Int = bindings.size public val numberOfVariables: Int = bindings.size
override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
@ -77,7 +83,9 @@ public class CmDsField(
override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value) override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value)
override fun multiply(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.multiply(right) override fun multiply(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure =
left.multiply(right)
override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.divide(right) override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.divide(right)
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
@ -125,13 +133,16 @@ public object CmDsProcessor : AutoDiffProcessor<Double, DerivativeStructure, CmD
public class CmDsExpression( public class CmDsExpression(
public val function: CmDsField.() -> DerivativeStructure, public val function: CmDsField.() -> DerivativeStructure,
) : DifferentiableExpression<Double> { ) : DifferentiableExpression<Double> {
override val type: SafeType<Double> get() = DoubleField.type
override operator fun invoke(arguments: Map<Symbol, Double>): Double = override operator fun invoke(arguments: Map<Symbol, Double>): Double =
CmDsField(0, arguments).function().value CmDsField(0, arguments).function().value
/** /**
* Get the derivative expression with given orders * Get the derivative expression with given orders
*/ */
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments -> override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression(type) { arguments ->
with(CmDsField(symbols.size, arguments)) { function().derivative(symbols) } with(CmDsField(symbols.size, arguments)) { function().derivative(symbols) }
} }
} }

View File

@ -16,7 +16,7 @@ public class CMGaussRuleIntegrator(
private var type: GaussRule = GaussRule.LEGENDRE, private var type: GaussRule = GaussRule.LEGENDRE,
) : UnivariateIntegrator<Double> { ) : UnivariateIntegrator<Double> {
override fun process(integrand: UnivariateIntegrand<Double>): UnivariateIntegrand<Double> { override fun integrate(integrand: UnivariateIntegrand<Double>): UnivariateIntegrand<Double> {
val range = integrand[IntegrationRange] val range = integrand[IntegrationRange]
?: error("Integration range is not provided") ?: error("Integration range is not provided")
val integrator: GaussIntegrator = getIntegrator(range) val integrator: GaussIntegrator = getIntegrator(range)
@ -79,7 +79,7 @@ public class CMGaussRuleIntegrator(
numPoints: Int = 100, numPoints: Int = 100,
type: GaussRule = GaussRule.LEGENDRE, type: GaussRule = GaussRule.LEGENDRE,
function: (Double) -> Double, function: (Double) -> Double,
): Double = CMGaussRuleIntegrator(numPoints, type).process( ): Double = CMGaussRuleIntegrator(numPoints, type).integrate(
UnivariateIntegrand({IntegrationRange(range)},function) UnivariateIntegrand({IntegrationRange(range)},function)
).value ).value
} }

View File

@ -7,6 +7,7 @@ package space.kscience.kmath.commons.integration
import org.apache.commons.math3.analysis.integration.IterativeLegendreGaussIntegrator import org.apache.commons.math3.analysis.integration.IterativeLegendreGaussIntegrator
import org.apache.commons.math3.analysis.integration.SimpsonIntegrator import org.apache.commons.math3.analysis.integration.SimpsonIntegrator
import space.kscience.attributes.Attributes
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.integration.* import space.kscience.kmath.integration.*
import org.apache.commons.math3.analysis.integration.UnivariateIntegrator as CMUnivariateIntegrator import org.apache.commons.math3.analysis.integration.UnivariateIntegrator as CMUnivariateIntegrator
@ -19,7 +20,7 @@ public class CMIntegrator(
public val integratorBuilder: (Integrand<Double>) -> CMUnivariateIntegrator, public val integratorBuilder: (Integrand<Double>) -> CMUnivariateIntegrator,
) : UnivariateIntegrator<Double> { ) : UnivariateIntegrator<Double> {
override fun process(integrand: UnivariateIntegrand<Double>): UnivariateIntegrand<Double> { override fun integrate(integrand: UnivariateIntegrand<Double>): UnivariateIntegrand<Double> {
val integrator = integratorBuilder(integrand) val integrator = integratorBuilder(integrand)
val maxCalls = integrand[IntegrandMaxCalls] ?: defaultMaxCalls val maxCalls = integrand[IntegrandMaxCalls] ?: defaultMaxCalls
val remainingCalls = maxCalls - integrand.calls val remainingCalls = maxCalls - integrand.calls
@ -73,15 +74,9 @@ public class CMIntegrator(
} }
@UnstableKMathAPI @UnstableKMathAPI
public var MutableList<IntegrandFeature>.targetAbsoluteAccuracy: Double? public val Attributes.targetAbsoluteAccuracy: Double?
get() = filterIsInstance<IntegrandAbsoluteAccuracy>().lastOrNull()?.accuracy get() = get(IntegrandAbsoluteAccuracy)
set(value) {
value?.let { add(IntegrandAbsoluteAccuracy(value)) }
}
@UnstableKMathAPI @UnstableKMathAPI
public var MutableList<IntegrandFeature>.targetRelativeAccuracy: Double? public val Attributes.targetRelativeAccuracy: Double?
get() = filterIsInstance<IntegrandRelativeAccuracy>().lastOrNull()?.accuracy get() = get(IntegrandRelativeAccuracy)
set(value) {
value?.let { add(IntegrandRelativeAccuracy(value)) }
}

View File

@ -6,18 +6,20 @@
package space.kscience.kmath.commons.linear package space.kscience.kmath.commons.linear
import org.apache.commons.math3.linear.* import org.apache.commons.math3.linear.*
import org.apache.commons.math3.linear.LUDecomposition
import space.kscience.attributes.SafeType
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.linear.* import space.kscience.kmath.linear.*
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.nd.StructureAttribute import space.kscience.kmath.nd.StructureAttribute
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Float64Field import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.Float64Buffer import space.kscience.kmath.structures.Float64Buffer
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.cast import kotlin.reflect.cast
import kotlin.reflect.typeOf
public class CMMatrix(public val origin: RealMatrix) : Matrix<Double> { public class CMMatrix(public val origin: RealMatrix) : Matrix<Double> {
override val type: SafeType<Double> get() = DoubleField.type
override val rowNum: Int get() = origin.rowDimension override val rowNum: Int get() = origin.rowDimension
override val colNum: Int get() = origin.columnDimension override val colNum: Int get() = origin.columnDimension
@ -26,6 +28,7 @@ public class CMMatrix(public val origin: RealMatrix) : Matrix<Double> {
@JvmInline @JvmInline
public value class CMVector(public val origin: RealVector) : Point<Double> { public value class CMVector(public val origin: RealVector) : Point<Double> {
override val type: SafeType<Double> get() = DoubleField.type
override val size: Int get() = origin.dimension override val size: Int get() = origin.dimension
override operator fun get(index: Int): Double = origin.getEntry(index) override operator fun get(index: Int): Double = origin.getEntry(index)
@ -40,7 +43,7 @@ public fun RealVector.toPoint(): CMVector = CMVector(this)
public object CMLinearSpace : LinearSpace<Double, Float64Field> { public object CMLinearSpace : LinearSpace<Double, Float64Field> {
override val elementAlgebra: Float64Field get() = Float64Field override val elementAlgebra: Float64Field get() = Float64Field
override val elementType: KType = typeOf<Double>() override val type: SafeType<Double> get() = DoubleField.type
override fun buildMatrix( override fun buildMatrix(
rows: Int, rows: Int,
@ -102,19 +105,14 @@ public object CMLinearSpace : LinearSpace<Double, Float64Field> {
override fun Double.times(v: Point<Double>): CMVector = override fun Double.times(v: Point<Double>): CMVector =
v * this v * this
@UnstableKMathAPI override fun <V, A : StructureAttribute<V>> computeAttribute(structure: Structure2D<Double>, attribute: A): V? {
override fun <F : StructureAttribute> computeFeature(structure: Matrix<Double>, type: KClass<out F>): F? {
//Return the feature if it is intrinsic to the structure
structure.getFeature(type)?.let { return it }
val origin = structure.toCM().origin val origin = structure.toCM().origin
return when (type) { return when (attribute) {
IsDiagonal::class -> if (origin is DiagonalMatrix) IsDiagonal else null IsDiagonal -> if (origin is DiagonalMatrix) IsDiagonal else null
Determinant -> LUDecomposition(origin).determinant
Determinant::class, LupDecompositionAttribute::class -> object : LUP -> GenericLupDecomposition {
Determinant<Double>,
LupDecompositionAttribute<Double> {
private val lup by lazy { LUDecomposition(origin) } private val lup by lazy { LUDecomposition(origin) }
override val determinant: Double by lazy { lup.determinant } override val determinant: Double by lazy { lup.determinant }
override val l: Matrix<Double> by lazy<Matrix<Double>> { CMMatrix(lup.l).withAttribute(LowerTriangular) } override val l: Matrix<Double> by lazy<Matrix<Double>> { CMMatrix(lup.l).withAttribute(LowerTriangular) }
@ -122,20 +120,24 @@ public object CMLinearSpace : LinearSpace<Double, Float64Field> {
override val p: Matrix<Double> by lazy { CMMatrix(lup.p) } override val p: Matrix<Double> by lazy { CMMatrix(lup.p) }
} }
CholeskyDecompositionAttribute::class -> object : CholeskyDecompositionAttribute<Double> { CholeskyDecompositionAttribute -> object : CholeskyDecompositionAttribute<Double> {
override val l: Matrix<Double> by lazy<Matrix<Double>> { override val l: Matrix<Double> by lazy<Matrix<Double>> {
val cholesky = CholeskyDecomposition(origin) val cholesky = CholeskyDecomposition(origin)
CMMatrix(cholesky.l).withAttribute(LowerTriangular) CMMatrix(cholesky.l).withAttribute(LowerTriangular)
} }
} }
QRDecompositionAttribute::class -> object : QRDecompositionAttribute<Double> { QRDecompositionAttribute -> object : QRDecompositionAttribute<Double> {
private val qr by lazy { QRDecomposition(origin) } private val qr by lazy { QRDecomposition(origin) }
override val q: Matrix<Double> by lazy<Matrix<Double>> { CMMatrix(qr.q).withAttribute(OrthogonalAttribute) } override val q: Matrix<Double> by lazy<Matrix<Double>> {
CMMatrix(qr.q).withAttribute(
OrthogonalAttribute
)
}
override val r: Matrix<Double> by lazy<Matrix<Double>> { CMMatrix(qr.r).withAttribute(UpperTriangular) } override val r: Matrix<Double> by lazy<Matrix<Double>> { CMMatrix(qr.r).withAttribute(UpperTriangular) }
} }
SVDAttribute::class -> object : SVDAttribute<Double> { SVDAttribute -> object : SVDAttribute<Double> {
private val sv by lazy { SingularValueDecomposition(origin) } private val sv by lazy { SingularValueDecomposition(origin) }
override val u: Matrix<Double> by lazy { CMMatrix(sv.u) } override val u: Matrix<Double> by lazy { CMMatrix(sv.u) }
override val s: Matrix<Double> by lazy { CMMatrix(sv.s) } override val s: Matrix<Double> by lazy { CMMatrix(sv.s) }

View File

@ -13,6 +13,8 @@ import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
import space.kscience.attributes.AttributesBuilder
import space.kscience.attributes.SetAttribute
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.SymbolIndexer import space.kscience.kmath.expressions.SymbolIndexer
@ -26,34 +28,25 @@ import kotlin.reflect.KClass
public operator fun PointValuePair.component1(): DoubleArray = point public operator fun PointValuePair.component1(): DoubleArray = point
public operator fun PointValuePair.component2(): Double = value public operator fun PointValuePair.component2(): Double = value
public class CMOptimizerEngine(public val optimizerBuilder: () -> MultivariateOptimizer) : OptimizationFeature { public object CMOptimizerEngine: OptimizationAttribute<() -> MultivariateOptimizer>
override fun toString(): String = "CMOptimizer($optimizerBuilder)"
}
/** /**
* Specify a Commons-maths optimization engine * Specify a Commons-maths optimization engine
*/ */
public fun FunctionOptimizationBuilder<Double>.cmEngine(optimizerBuilder: () -> MultivariateOptimizer) { public fun AttributesBuilder<FunctionOptimization<Double>>.cmEngine(optimizerBuilder: () -> MultivariateOptimizer) {
addFeature(CMOptimizerEngine(optimizerBuilder)) set(CMOptimizerEngine, optimizerBuilder)
} }
public class CMOptimizerData(public val data: List<SymbolIndexer.() -> OptimizationData>) : OptimizationFeature { public object CMOptimizerData: SetAttribute<SymbolIndexer.() -> OptimizationData>
public constructor(vararg data: (SymbolIndexer.() -> OptimizationData)) : this(data.toList())
override fun toString(): String = "CMOptimizerData($data)"
}
/** /**
* Specify Commons-maths optimization data. * Specify Commons-maths optimization data.
*/ */
public fun FunctionOptimizationBuilder<Double>.cmOptimizationData(data: SymbolIndexer.() -> OptimizationData) { public fun AttributesBuilder<FunctionOptimization<Double>>.cmOptimizationData(data: SymbolIndexer.() -> OptimizationData) {
updateFeature<CMOptimizerData> { CMOptimizerData.add(data)
val newData = (it?.data ?: emptyList()) + data
CMOptimizerData(newData)
}
} }
public fun FunctionOptimizationBuilder<Double>.simplexSteps(vararg steps: Pair<Symbol, Double>) { public fun AttributesBuilder<FunctionOptimization<Double>>.simplexSteps(vararg steps: Pair<Symbol, Double>) {
//TODO use convergence checker from features //TODO use convergence checker from features
cmEngine { SimplexOptimizer(CMOptimizer.defaultConvergenceChecker) } cmEngine { SimplexOptimizer(CMOptimizer.defaultConvergenceChecker) }
cmOptimizationData { NelderMeadSimplex(mapOf(*steps).toDoubleArray()) } cmOptimizationData { NelderMeadSimplex(mapOf(*steps).toDoubleArray()) }
@ -78,8 +71,8 @@ public object CMOptimizer : Optimizer<Double, FunctionOptimization<Double>> {
): FunctionOptimization<Double> { ): FunctionOptimization<Double> {
val startPoint = problem.startPoint val startPoint = problem.startPoint
val parameters = problem.getFeature<OptimizationParameters>()?.symbols val parameters = problem.attributes[OptimizationParameters]
?: problem.getFeature<OptimizationStartPoint<Double>>()?.point?.keys ?: problem.attributes[OptimizationStartPoint<Double>()]?.keys
?: startPoint.keys ?: startPoint.keys
@ -90,7 +83,7 @@ public object CMOptimizer : Optimizer<Double, FunctionOptimization<Double>> {
DEFAULT_MAX_ITER DEFAULT_MAX_ITER
) )
val cmOptimizer: MultivariateOptimizer = problem.getFeature<CMOptimizerEngine>()?.optimizerBuilder?.invoke() val cmOptimizer: MultivariateOptimizer = problem.attributes[CMOptimizerEngine]?.invoke()
?: NonLinearConjugateGradientOptimizer( ?: NonLinearConjugateGradientOptimizer(
NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES, NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
convergenceChecker convergenceChecker
@ -123,7 +116,7 @@ public object CMOptimizer : Optimizer<Double, FunctionOptimization<Double>> {
} }
addOptimizationData(gradientFunction) addOptimizationData(gradientFunction)
val logger = problem.getFeature<OptimizationLog>() val logger = problem.attributes[OptimizationLog]
for (feature in problem.attributes) { for (feature in problem.attributes) {
when (feature) { when (feature) {
@ -139,7 +132,7 @@ public object CMOptimizer : Optimizer<Double, FunctionOptimization<Double>> {
} }
val (point, value) = cmOptimizer.optimize(*optimizationData.values.toTypedArray()) val (point, value) = cmOptimizer.optimize(*optimizationData.values.toTypedArray())
return problem.withFeatures(OptimizationResult(point.toMap()), OptimizationValue(value)) return problem.withAttributes(OptimizationResult(point.toMap()), OptimizationValue(value))
} }
} }
} }

View File

@ -31,7 +31,7 @@ internal class OptimizeTest {
@Test @Test
fun testGradientOptimization() = runBlocking { fun testGradientOptimization() = runBlocking {
val result = normal.optimizeWith(CMOptimizer, x to 1.0, y to 1.0) val result = normal.optimizeWith(CMOptimizer, x to 1.0, y to 1.0)
println(result.resultPoint) println(result.result)
println(result.resultValue) println(result.resultValue)
} }
@ -42,7 +42,7 @@ internal class OptimizeTest {
//this sets simplex optimizer //this sets simplex optimizer
} }
println(result.resultPoint) println(result.result)
println(result.resultValue) println(result.resultValue)
} }

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.expressions package space.kscience.kmath.expressions
import space.kscience.attributes.SafeType
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
@ -331,10 +332,13 @@ public class DerivativeStructureRingExpression<T, A>(
public val elementBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory, public val elementBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory,
public val function: DSRing<T, A>.() -> DS<T, A>, public val function: DSRing<T, A>.() -> DS<T, A>,
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> { ) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
override val type: SafeType<T> get() = elementBufferFactory.type
override operator fun invoke(arguments: Map<Symbol, T>): T = override operator fun invoke(arguments: Map<Symbol, T>): T =
DSRing(algebra, 0, arguments).function().value DSRing(algebra, 0, arguments).function().value
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments -> override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression(type) { arguments ->
with( with(
DSRing( DSRing(
algebra, algebra,
@ -443,10 +447,13 @@ public class DSFieldExpression<T, A : ExtendedField<T>>(
public val algebra: A, public val algebra: A,
public val function: DSField<T, A>.() -> DS<T, A>, public val function: DSField<T, A>.() -> DS<T, A>,
) : DifferentiableExpression<T> { ) : DifferentiableExpression<T> {
override val type: SafeType<T> get() = algebra.type
override operator fun invoke(arguments: Map<Symbol, T>): T = override operator fun invoke(arguments: Map<Symbol, T>): T =
DSField(algebra, 0, arguments).function().value DSField(algebra, 0, arguments).function().value
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments -> override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression(type) { arguments ->
DSField( DSField(
algebra, algebra,
symbols.size, symbols.size,

View File

@ -5,8 +5,13 @@
package space.kscience.kmath.expressions package space.kscience.kmath.expressions
import space.kscience.attributes.SafeType
import space.kscience.attributes.WithType
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.LongRing
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
import kotlin.properties.ReadOnlyProperty import kotlin.properties.ReadOnlyProperty
@ -15,7 +20,7 @@ import kotlin.properties.ReadOnlyProperty
* *
* @param T the type this expression takes as argument and returns. * @param T the type this expression takes as argument and returns.
*/ */
public fun interface Expression<T> { public interface Expression<T> : WithType<T> {
/** /**
* Calls this expression from arguments. * Calls this expression from arguments.
* *
@ -25,11 +30,20 @@ public fun interface Expression<T> {
public operator fun invoke(arguments: Map<Symbol, T>): T public operator fun invoke(arguments: Map<Symbol, T>): T
} }
public fun <T> Expression(type: SafeType<T>, block: (Map<Symbol, T>) -> T): Expression<T> = object : Expression<T> {
override fun invoke(arguments: Map<Symbol, T>): T = block(arguments)
override val type: SafeType<T> = type
}
/** /**
* Specialization of [Expression] for [Double] allowing better performance because of using array. * Specialization of [Expression] for [Double] allowing better performance because of using array.
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public interface DoubleExpression : Expression<Double> { public interface DoubleExpression : Expression<Double> {
override val type: SafeType<Double> get() = DoubleField.type
/** /**
* The indexer of this expression's arguments that should be used to build array for [invoke]. * The indexer of this expression's arguments that should be used to build array for [invoke].
* *
@ -59,6 +73,9 @@ public interface DoubleExpression : Expression<Double> {
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public interface IntExpression : Expression<Int> { public interface IntExpression : Expression<Int> {
override val type: SafeType<Int> get() = IntRing.type
/** /**
* The indexer of this expression's arguments that should be used to build array for [invoke]. * The indexer of this expression's arguments that should be used to build array for [invoke].
* *
@ -88,6 +105,9 @@ public interface IntExpression : Expression<Int> {
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public interface LongExpression : Expression<Long> { public interface LongExpression : Expression<Long> {
override val type: SafeType<Long> get() = LongRing.type
/** /**
* The indexer of this expression's arguments that should be used to build array for [invoke]. * The indexer of this expression's arguments that should be used to build array for [invoke].
* *
@ -158,7 +178,6 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
) )
/** /**
* Calls this expression without providing any arguments. * Calls this expression without providing any arguments.
* *

View File

@ -5,10 +5,15 @@
package space.kscience.kmath.expressions package space.kscience.kmath.expressions
import space.kscience.attributes.SafeType
public class ExpressionWithDefault<T>( public class ExpressionWithDefault<T>(
private val origin: Expression<T>, private val origin: Expression<T>,
private val defaultArgs: Map<Symbol, T>, private val defaultArgs: Map<Symbol, T>,
) : Expression<T> { ) : Expression<T> {
override val type: SafeType<T>
get() = origin.type
override fun invoke(arguments: Map<Symbol, T>): T = origin.invoke(defaultArgs + arguments) override fun invoke(arguments: Map<Symbol, T>): T = origin.invoke(defaultArgs + arguments)
} }
@ -21,6 +26,9 @@ public class DiffExpressionWithDefault<T>(
private val defaultArgs: Map<Symbol, T>, private val defaultArgs: Map<Symbol, T>,
) : DifferentiableExpression<T> { ) : DifferentiableExpression<T> {
override val type: SafeType<T>
get() = origin.type
override fun invoke(arguments: Map<Symbol, T>): T = origin.invoke(defaultArgs + arguments) override fun invoke(arguments: Map<Symbol, T>): T = origin.invoke(defaultArgs + arguments)
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? = override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? =

View File

@ -23,26 +23,27 @@ public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
/** /**
* Builds an Expression of constant expression that does not depend on arguments. * Builds an Expression of constant expression that does not depend on arguments.
*/ */
override fun const(value: T): Expression<T> = Expression { value } override fun const(value: T): Expression<T> = Expression(algebra.type) { value }
/** /**
* Builds an Expression to access a variable. * Builds an Expression to access a variable.
*/ */
override fun bindSymbolOrNull(value: String): Expression<T>? = Expression { arguments -> override fun bindSymbolOrNull(value: String): Expression<T>? = Expression(algebra.type) { arguments ->
algebra.bindSymbolOrNull(value) algebra.bindSymbolOrNull(value)
?: arguments[StringSymbol(value)] ?: arguments[StringSymbol(value)]
?: error("Symbol '$value' is not supported in $this") ?: error("Symbol '$value' is not supported in $this")
} }
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> = override fun binaryOperationFunction(
{ left, right -> operation: String,
Expression { arguments -> ): (left: Expression<T>, right: Expression<T>) -> Expression<T> = { left, right ->
Expression(algebra.type) { arguments ->
algebra.binaryOperationFunction(operation)(left(arguments), right(arguments)) algebra.binaryOperationFunction(operation)(left(arguments), right(arguments))
} }
} }
override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg -> override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
Expression { arguments -> algebra.unaryOperation(operation, arg(arguments)) } Expression(algebra.type) { arguments -> algebra.unaryOperation(operation, arg(arguments)) }
} }
} }
@ -124,7 +125,7 @@ public open class FunctionalExpressionField<T, out A : Field<T>>(
super<FunctionalExpressionRing>.binaryOperationFunction(operation) super<FunctionalExpressionRing>.binaryOperationFunction(operation)
override fun scale(a: Expression<T>, value: Double): Expression<T> = algebra { override fun scale(a: Expression<T>, value: Double): Expression<T> = algebra {
Expression { args -> a(args) * value } Expression(algebra.type) { args -> a(args) * value }
} }
override fun bindSymbolOrNull(value: String): Expression<T>? = override fun bindSymbolOrNull(value: String): Expression<T>? =

View File

@ -108,4 +108,4 @@ public fun <T> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol,
* Interpret this [MST] as expression. * Interpret this [MST] as expression.
*/ */
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> = public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> =
Expression { arguments -> interpret(algebra, arguments) } Expression(algebra.type) { arguments -> interpret(algebra, arguments) }

View File

@ -243,12 +243,15 @@ public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
public val field: F, public val field: F,
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>, public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
) : FirstDerivativeExpression<T>() { ) : FirstDerivativeExpression<T>() {
override val type: SafeType<T> get() = this.field.type
override operator fun invoke(arguments: Map<Symbol, T>): T { override operator fun invoke(arguments: Map<Symbol, T>): T {
//val bindings = arguments.entries.map { it.key.bind(it.value) } //val bindings = arguments.entries.map { it.key.bind(it.value) }
return SimpleAutoDiffField(field, arguments).function().value return SimpleAutoDiffField(field, arguments).function().value
} }
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments -> override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression(type) { arguments ->
//val bindings = arguments.entries.map { it.key.bind(it.value) } //val bindings = arguments.entries.map { it.key.bind(it.value) }
val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function) val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function)
derivationResult.derivative(symbol) derivationResult.derivative(symbol)

View File

@ -11,7 +11,7 @@ package space.kscience.kmath.linear
* *
* @param T the type of items. * @param T the type of items.
*/ */
public interface LinearSolver<T : Any> { public interface LinearSolver<T> {
/** /**
* Solve a dot x = b matrix equation and return x * Solve a dot x = b matrix equation and return x
*/ */

View File

@ -5,10 +5,7 @@
package space.kscience.kmath.linear package space.kscience.kmath.linear
import space.kscience.attributes.Attributes import space.kscience.attributes.*
import space.kscience.attributes.SafeType
import space.kscience.attributes.WithType
import space.kscience.attributes.withAttribute
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.BufferRingOps import space.kscience.kmath.operations.BufferRingOps
@ -31,19 +28,13 @@ public typealias MutableMatrix<T> = MutableStructure2D<T>
*/ */
public typealias Point<T> = Buffer<T> public typealias Point<T> = Buffer<T>
/**
* A marker interface for algebras that operate on matrices
* @param T type of matrix element
*/
public interface MatrixOperations<T> : WithType<T>
/** /**
* Basic operations on matrices and vectors. * Basic operations on matrices and vectors.
* *
* @param T the type of items in the matrices. * @param T the type of items in the matrices.
* @param A the type of ring over [T]. * @param A the type of ring over [T].
*/ */
public interface LinearSpace<T, out A : Ring<T>> : MatrixOperations<T> { public interface LinearSpace<T, out A : Ring<T>> : MatrixScope<T> {
public val elementAlgebra: A public val elementAlgebra: A
override val type: SafeType<T> get() = elementAlgebra.type override val type: SafeType<T> get() = elementAlgebra.type
@ -177,10 +168,10 @@ public interface LinearSpace<T, out A : Ring<T>> : MatrixOperations<T> {
/** /**
* Compute an [attribute] value for given [structure]. Return null if the attribute could not be computed. * Compute an [attribute] value for given [structure]. Return null if the attribute could not be computed.
*/ */
public fun <V, A : StructureAttribute<V>> computeAttribute(structure: StructureND<*>, attribute: A): V? = null public fun <V, A : StructureAttribute<V>> computeAttribute(structure: Structure2D<T>, attribute: A): V? = null
@UnstableKMathAPI @UnstableKMathAPI
public fun <V, A : StructureAttribute<V>> StructureND<*>.getOrComputeAttribute(attribute: A): V? { public fun <V, A : StructureAttribute<V>> Structure2D<T>.getOrComputeAttribute(attribute: A): V? {
return attributes[attribute] ?: computeAttribute(this, attribute) return attributes[attribute] ?: computeAttribute(this, attribute)
} }
@ -225,7 +216,7 @@ public inline operator fun <LS : LinearSpace<*, *>, R> LS.invoke(block: LS.() ->
/** /**
* Convert matrix to vector if it is possible. * Convert matrix to vector if it is possible.
*/ */
public fun <T : Any> Matrix<T>.asVector(): Point<T> = public fun <T> Matrix<T>.asVector(): Point<T> =
if (this.colNum == 1) as1D() if (this.colNum == 1) as1D()
else error("Can't convert matrix with more than one column to vector") else error("Can't convert matrix with more than one column to vector")
@ -236,4 +227,4 @@ public fun <T : Any> Matrix<T>.asVector(): Point<T> =
* @receiver a buffer. * @receiver a buffer.
* @return the new matrix. * @return the new matrix.
*/ */
public fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(type, size, 1) { i, _ -> get(i) } public fun <T> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(type, size, 1) { i, _ -> get(i) }

View File

@ -9,12 +9,20 @@ package space.kscience.kmath.linear
import space.kscience.attributes.Attributes import space.kscience.attributes.Attributes
import space.kscience.attributes.PolymorphicAttribute import space.kscience.attributes.PolymorphicAttribute
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.*
public interface LupDecomposition<T> {
public val linearSpace: LinearSpace<T, Field<T>>
public val elementAlgebra: Field<T> get() = linearSpace.elementAlgebra
public val pivot: IntBuffer
public val l: Matrix<T>
public val u: Matrix<T>
}
/** /**
* Matrices with this feature support LU factorization with partial pivoting: *[p] &middot; a = [l] &middot; [u]* where * Matrices with this feature support LU factorization with partial pivoting: *[p] &middot; a = [l] &middot; [u]* where
* *a* is the owning matrix. * *a* is the owning matrix.
@ -22,15 +30,14 @@ import space.kscience.kmath.structures.*
* @param T the type of matrices' items. * @param T the type of matrices' items.
* @param lu combined L and U matrix * @param lu combined L and U matrix
*/ */
public class LupDecomposition<T>( public class GenericLupDecomposition<T>(
public val linearSpace: LinearSpace<T, Ring<T>>, override val linearSpace: LinearSpace<T, Field<T>>,
private val lu: Matrix<T>, private val lu: Matrix<T>,
public val pivot: IntBuffer, override val pivot: IntBuffer,
private val even: Boolean, private val even: Boolean,
) { ) : LupDecomposition<T> {
public val elementAlgebra: Ring<T> get() = linearSpace.elementAlgebra
public val l: Matrix<T> override val l: Matrix<T>
get() = VirtualMatrix(lu.type, lu.rowNum, lu.colNum, attributes = Attributes(LowerTriangular)) { i, j -> get() = VirtualMatrix(lu.type, lu.rowNum, lu.colNum, attributes = Attributes(LowerTriangular)) { i, j ->
when { when {
j < i -> lu[i, j] j < i -> lu[i, j]
@ -39,7 +46,7 @@ public class LupDecomposition<T>(
} }
} }
public val u: Matrix<T> override val u: Matrix<T>
get() = VirtualMatrix(lu.type, lu.rowNum, lu.colNum, attributes = Attributes(UpperTriangular)) { i, j -> get() = VirtualMatrix(lu.type, lu.rowNum, lu.colNum, attributes = Attributes(UpperTriangular)) { i, j ->
if (j >= i) lu[i, j] else elementAlgebra.zero if (j >= i) lu[i, j] else elementAlgebra.zero
} }
@ -55,13 +62,12 @@ public class LupDecomposition<T>(
} }
public class LupDecompositionAttribute<T> :
public class LupDecompositionAttribute<T>(type: SafeType<LupDecomposition<T>>) : PolymorphicAttribute<LupDecomposition<T>>(safeTypeOf()),
PolymorphicAttribute<LupDecomposition<T>>(type),
MatrixAttribute<LupDecomposition<T>> MatrixAttribute<LupDecomposition<T>>
public val <T> MatrixOperations<T>.LUP: LupDecompositionAttribute<T> public val <T> MatrixScope<T>.LUP: LupDecompositionAttribute<T>
get() = LupDecompositionAttribute(safeTypeOf()) get() = LupDecompositionAttribute()
@PublishedApi @PublishedApi
internal fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.abs(value: T): T = internal fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.abs(value: T): T =
@ -142,18 +148,17 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
for (row in col + 1 until m) lu[row, col] /= luDiag for (row in col + 1 until m) lu[row, col] /= luDiag
} }
return LupDecomposition(this@lup, lu.toStructure2D(), pivot.asBuffer(), even) return GenericLupDecomposition(this@lup, lu.toStructure2D(), pivot.asBuffer(), even)
} }
} }
public fun LinearSpace<Double, Float64Field>.lup( public fun LinearSpace<Double, Float64Field>.lup(
matrix: Matrix<Double>, matrix: Matrix<Double>,
singularityThreshold: Double = 1e-11, singularityThreshold: Double = 1e-11,
): LupDecomposition<Double> = lup(matrix) { it < singularityThreshold } ): LupDecomposition<Double> = lup(matrix) { it < singularityThreshold }
internal fun <T : Any, A : Field<T>> LinearSpace<T, A>.solve( internal fun <T> LinearSpace<T, Field<T>>.solve(
lup: LupDecomposition<T>, lup: LupDecomposition<T>,
matrix: Matrix<T>, matrix: Matrix<T>,
): Matrix<T> { ): Matrix<T> {
@ -205,7 +210,7 @@ internal fun <T : Any, A : Field<T>> LinearSpace<T, A>.solve(
* Produce a generic solver based on LUP decomposition * Produce a generic solver based on LUP decomposition
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public fun <T : Comparable<T>, F : Field<T>> LinearSpace<T, F>.lupSolver( public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lupSolver(
singularityCheck: (T) -> Boolean, singularityCheck: (T) -> Boolean,
): LinearSolver<T> = object : LinearSolver<T> { ): LinearSolver<T> = object : LinearSolver<T> {
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> { override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {

View File

@ -35,7 +35,7 @@ public val <T : Any> Matrix<T>.origin: Matrix<T>
/** /**
* Add a single feature to a [Matrix] * Add a single feature to a [Matrix]
*/ */
public fun <T : Any, A : Attribute<T>> Matrix<T>.withAttribute( public fun <T, A : Attribute<T>> Matrix<T>.withAttribute(
attribute: A, attribute: A,
attrValue: T, attrValue: T,
): MatrixWrapper<T> = if (this is MatrixWrapper) { ): MatrixWrapper<T> = if (this is MatrixWrapper) {
@ -44,7 +44,7 @@ public fun <T : Any, A : Attribute<T>> Matrix<T>.withAttribute(
MatrixWrapper(this, Attributes(attribute, attrValue)) MatrixWrapper(this, Attributes(attribute, attrValue))
} }
public fun <T : Any, A : Attribute<Unit>> Matrix<T>.withAttribute( public fun <T, A : Attribute<Unit>> Matrix<T>.withAttribute(
attribute: A, attribute: A,
): MatrixWrapper<T> = if (this is MatrixWrapper) { ): MatrixWrapper<T> = if (this is MatrixWrapper) {
MatrixWrapper(origin, attributes.withAttribute(attribute)) MatrixWrapper(origin, attributes.withAttribute(attribute))
@ -55,7 +55,7 @@ public fun <T : Any, A : Attribute<Unit>> Matrix<T>.withAttribute(
/** /**
* Modify matrix attributes * Modify matrix attributes
*/ */
public fun <T : Any> Matrix<T>.modifyAttributes(modifier: (Attributes) -> Attributes): MatrixWrapper<T> = public fun <T> Matrix<T>.modifyAttributes(modifier: (Attributes) -> Attributes): MatrixWrapper<T> =
if (this is MatrixWrapper) { if (this is MatrixWrapper) {
MatrixWrapper(origin, modifier(attributes)) MatrixWrapper(origin, modifier(attributes))
} else { } else {
@ -65,7 +65,7 @@ public fun <T : Any> Matrix<T>.modifyAttributes(modifier: (Attributes) -> Attrib
/** /**
* Diagonal matrix of ones. The matrix is virtual, no actual matrix is created. * Diagonal matrix of ones. The matrix is virtual, no actual matrix is created.
*/ */
public fun <T : Any> LinearSpace<T, Ring<T>>.one( public fun <T> LinearSpace<T, Ring<T>>.one(
rows: Int, rows: Int,
columns: Int, columns: Int,
): MatrixWrapper<T> = VirtualMatrix(type, rows, columns) { i, j -> ): MatrixWrapper<T> = VirtualMatrix(type, rows, columns) { i, j ->
@ -76,7 +76,7 @@ public fun <T : Any> LinearSpace<T, Ring<T>>.one(
/** /**
* A virtual matrix of zeroes * A virtual matrix of zeroes
*/ */
public fun <T : Any> LinearSpace<T, Ring<T>>.zero( public fun <T> LinearSpace<T, Ring<T>>.zero(
rows: Int, rows: Int,
columns: Int, columns: Int,
): MatrixWrapper<T> = VirtualMatrix(type, rows, columns) { _, _ -> ): MatrixWrapper<T> = VirtualMatrix(type, rows, columns) { _, _ ->

View File

@ -10,6 +10,13 @@ package space.kscience.kmath.linear
import space.kscience.attributes.* import space.kscience.attributes.*
import space.kscience.kmath.nd.StructureAttribute import space.kscience.kmath.nd.StructureAttribute
/**
* A marker interface for algebras that operate on matrices
* @param T type of matrix element
*/
public interface MatrixScope<T> : AttributeScope<Matrix<T>>, WithType<T>
/** /**
* A marker interface representing some properties of matrices or additional transformations of them. Features are used * A marker interface representing some properties of matrices or additional transformations of them. Features are used
* to optimize matrix operations performance in some cases or retrieve the APIs. * to optimize matrix operations performance in some cases or retrieve the APIs.
@ -38,11 +45,11 @@ public object IsUnit : IsDiagonal
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public class Inverted<T>(type: SafeType<Matrix<T>>) : public class Inverted<T>() :
PolymorphicAttribute<Matrix<T>>(type), PolymorphicAttribute<Matrix<T>>(safeTypeOf()),
MatrixAttribute<Matrix<T>> MatrixAttribute<Matrix<T>>
public val <T> MatrixOperations<T>.Inverted: Inverted<T> get() = Inverted(safeTypeOf()) public val <T> MatrixScope<T>.Inverted: Inverted<T> get() = Inverted()
/** /**
* Matrices with this feature can compute their determinant. * Matrices with this feature can compute their determinant.
@ -53,7 +60,7 @@ public class Determinant<T>(type: SafeType<T>) :
PolymorphicAttribute<T>(type), PolymorphicAttribute<T>(type),
MatrixAttribute<T> MatrixAttribute<T>
public val <T> MatrixOperations<T>.Determinant: Determinant<T> get() = Determinant(type) public val <T> MatrixScope<T>.Determinant: Determinant<T> get() = Determinant(type)
/** /**
* Matrices with this feature are lower triangular ones. * Matrices with this feature are lower triangular ones.
@ -77,11 +84,11 @@ public data class LUDecomposition<T>(val l: Matrix<T>, val u: Matrix<T>)
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public class LuDecompositionAttribute<T>(type: SafeType<LUDecomposition<T>>) : public class LuDecompositionAttribute<T> :
PolymorphicAttribute<LUDecomposition<T>>(type), PolymorphicAttribute<LUDecomposition<T>>(safeTypeOf()),
MatrixAttribute<LUDecomposition<T>> MatrixAttribute<LUDecomposition<T>>
public val <T> MatrixOperations<T>.LU: LuDecompositionAttribute<T> get() = LuDecompositionAttribute(safeTypeOf()) public val <T> MatrixScope<T>.LU: LuDecompositionAttribute<T> get() = LuDecompositionAttribute()
/** /**
@ -108,12 +115,12 @@ public interface QRDecomposition<out T> {
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public class QRDecompositionAttribute<T>(type: SafeType<QRDecomposition<T>>) : public class QRDecompositionAttribute<T>() :
PolymorphicAttribute<QRDecomposition<T>>(type), PolymorphicAttribute<QRDecomposition<T>>(safeTypeOf()),
MatrixAttribute<QRDecomposition<T>> MatrixAttribute<QRDecomposition<T>>
public val <T> MatrixOperations<T>.QR: QRDecompositionAttribute<T> public val <T> MatrixScope<T>.QR: QRDecompositionAttribute<T>
get() = QRDecompositionAttribute(safeTypeOf()) get() = QRDecompositionAttribute()
public interface CholeskyDecomposition<T> { public interface CholeskyDecomposition<T> {
/** /**
@ -128,12 +135,12 @@ public interface CholeskyDecomposition<T> {
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public class CholeskyDecompositionAttribute<T>(type: SafeType<CholeskyDecomposition<T>>) : public class CholeskyDecompositionAttribute<T> :
PolymorphicAttribute<CholeskyDecomposition<T>>(type), PolymorphicAttribute<CholeskyDecomposition<T>>(safeTypeOf()),
MatrixAttribute<CholeskyDecomposition<T>> MatrixAttribute<CholeskyDecomposition<T>>
public val <T> MatrixOperations<T>.Cholesky: CholeskyDecompositionAttribute<T> public val <T> MatrixScope<T>.Cholesky: CholeskyDecompositionAttribute<T>
get() = CholeskyDecompositionAttribute(safeTypeOf()) get() = CholeskyDecompositionAttribute()
public interface SingularValueDecomposition<T> { public interface SingularValueDecomposition<T> {
/** /**
@ -163,12 +170,11 @@ public interface SingularValueDecomposition<T> {
* *
* @param T the type of matrices' items. * @param T the type of matrices' items.
*/ */
public class SVDAttribute<T>(type: SafeType<SingularValueDecomposition<T>>) : public class SVDAttribute<T>() :
PolymorphicAttribute<SingularValueDecomposition<T>>(type), PolymorphicAttribute<SingularValueDecomposition<T>>(safeTypeOf()),
MatrixAttribute<SingularValueDecomposition<T>> MatrixAttribute<SingularValueDecomposition<T>>
public val <T> MatrixOperations<T>.SVD: SVDAttribute<T> public val <T> MatrixScope<T>.SVD: SVDAttribute<T> get() = SVDAttribute()
get() = SVDAttribute(safeTypeOf())
//TODO add sparse matrix feature //TODO add sparse matrix feature

View File

@ -39,5 +39,5 @@ public abstract class EjmlLinearSpace<T : Any, out A : Ring<T>, out M : org.ejml
@UnstableKMathAPI @UnstableKMathAPI
public fun EjmlMatrix<T, *>.inverted(): Matrix<Double> = public fun EjmlMatrix<T, *>.inverted(): Matrix<Double> =
attributeForOrNull(this, Float64Field.linearSpace.Inverted) computeAttribute(this, Float64Field.linearSpace.Inverted)!!
} }

View File

@ -19,9 +19,13 @@ import org.ejml.sparse.csc.factory.DecompositionFactory_DSCC
import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC
import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC
import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.linear.* import space.kscience.kmath.linear.*
import space.kscience.kmath.linear.Matrix import space.kscience.kmath.linear.Matrix
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.nd.StructureAttribute
import space.kscience.kmath.nd.StructureFeature import space.kscience.kmath.nd.StructureFeature
import space.kscience.kmath.operations.Float32Field import space.kscience.kmath.operations.Float32Field
import space.kscience.kmath.operations.Float64Field import space.kscience.kmath.operations.Float64Field
@ -39,6 +43,8 @@ public class EjmlDoubleVector<out M : DMatrix>(override val origin: M) : EjmlVec
require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" }
} }
override val type: SafeType<Double> get() = safeTypeOf()
override operator fun get(index: Int): Double = origin[0, index] override operator fun get(index: Int): Double = origin[0, index]
} }
@ -50,6 +56,8 @@ public class EjmlFloatVector<out M : FMatrix>(override val origin: M) : EjmlVect
require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" } require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" }
} }
override val type: SafeType<Float> get() = safeTypeOf()
override operator fun get(index: Int): Float = origin[0, index] override operator fun get(index: Int): Float = origin[0, index]
} }
@ -57,6 +65,8 @@ public class EjmlFloatVector<out M : FMatrix>(override val origin: M) : EjmlVect
* [EjmlMatrix] specialization for [Double]. * [EjmlMatrix] specialization for [Double].
*/ */
public class EjmlDoubleMatrix<out M : DMatrix>(override val origin: M) : EjmlMatrix<Double, M>(origin) { public class EjmlDoubleMatrix<out M : DMatrix>(override val origin: M) : EjmlMatrix<Double, M>(origin) {
override val type: SafeType<Double> get() = safeTypeOf()
override operator fun get(i: Int, j: Int): Double = origin[i, j] override operator fun get(i: Int, j: Int): Double = origin[i, j]
} }
@ -64,9 +74,12 @@ public class EjmlDoubleMatrix<out M : DMatrix>(override val origin: M) : EjmlMat
* [EjmlMatrix] specialization for [Float]. * [EjmlMatrix] specialization for [Float].
*/ */
public class EjmlFloatMatrix<out M : FMatrix>(override val origin: M) : EjmlMatrix<Float, M>(origin) { public class EjmlFloatMatrix<out M : FMatrix>(override val origin: M) : EjmlMatrix<Float, M>(origin) {
override val type: SafeType<Float> get() = safeTypeOf()
override operator fun get(i: Int, j: Int): Float = origin[i, j] override operator fun get(i: Int, j: Int): Float = origin[i, j]
} }
/** /**
* [EjmlLinearSpace] implementation based on [CommonOps_DDRM], [DecompositionFactory_DDRM] operations and * [EjmlLinearSpace] implementation based on [CommonOps_DDRM], [DecompositionFactory_DDRM] operations and
* [DMatrixRMaj] matrices. * [DMatrixRMaj] matrices.
@ -77,7 +90,7 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, Float64Field, DMatri
*/ */
override val elementAlgebra: Float64Field get() = Float64Field override val elementAlgebra: Float64Field get() = Float64Field
override val elementType: KType get() = typeOf<Double>() override val type: SafeType<Double> get() = safeTypeOf()
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun Matrix<Double>.toEjml(): EjmlDoubleMatrix<DMatrixRMaj> = when { override fun Matrix<Double>.toEjml(): EjmlDoubleMatrix<DMatrixRMaj> = when {
@ -205,6 +218,18 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, Float64Field, DMatri
override fun Double.times(v: Point<Double>): EjmlDoubleVector<DMatrixRMaj> = v * this override fun Double.times(v: Point<Double>): EjmlDoubleVector<DMatrixRMaj> = v * this
override fun <V, A : StructureAttribute<V>> computeAttribute(structure: Structure2D<Double>, attribute: A): V? {
val origin = structure.toEjml().origin
return when(attribute){
Inverted -> {
val res = origin.copy()
CommonOps_DDRM.invert(res)
res.wrapMatrix()
}
else->
}
}
@UnstableKMathAPI @UnstableKMathAPI
override fun <F : StructureFeature> computeFeature(structure: Matrix<Double>, type: KClass<out F>): F? { override fun <F : StructureFeature> computeFeature(structure: Matrix<Double>, type: KClass<out F>): F? {
structure.getFeature(type)?.let { return it } structure.getFeature(type)?.let { return it }
@ -305,6 +330,8 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, Float64Field, DMatri
} }
} }
import org.checkerframework.checker.guieffect.qual.SafeType
/** /**
* [EjmlLinearSpace] implementation based on [CommonOps_FDRM], [DecompositionFactory_FDRM] operations and * [EjmlLinearSpace] implementation based on [CommonOps_FDRM], [DecompositionFactory_FDRM] operations and
* [FMatrixRMaj] matrices. * [FMatrixRMaj] matrices.
@ -315,7 +342,7 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace<Float, Float32Field, FMatrix
*/ */
override val elementAlgebra: Float32Field get() = Float32Field override val elementAlgebra: Float32Field get() = Float32Field
override val elementType: KType get() = typeOf<Float>() override val type: SafeType<Float> get() = safeTypeOf()
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun Matrix<Float>.toEjml(): EjmlFloatMatrix<FMatrixRMaj> = when { override fun Matrix<Float>.toEjml(): EjmlFloatMatrix<FMatrixRMaj> = when {
@ -543,6 +570,8 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace<Float, Float32Field, FMatrix
} }
} }
import org.checkerframework.checker.guieffect.qual.SafeType
/** /**
* [EjmlLinearSpace] implementation based on [CommonOps_DSCC], [DecompositionFactory_DSCC] operations and * [EjmlLinearSpace] implementation based on [CommonOps_DSCC], [DecompositionFactory_DSCC] operations and
* [DMatrixSparseCSC] matrices. * [DMatrixSparseCSC] matrices.
@ -553,7 +582,7 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace<Double, Float64Field, DMatri
*/ */
override val elementAlgebra: Float64Field get() = Float64Field override val elementAlgebra: Float64Field get() = Float64Field
override val elementType: KType get() = typeOf<Double>() override val type: SafeType<Double> get() = safeTypeOf()
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun Matrix<Double>.toEjml(): EjmlDoubleMatrix<DMatrixSparseCSC> = when { override fun Matrix<Double>.toEjml(): EjmlDoubleMatrix<DMatrixSparseCSC> = when {
@ -776,6 +805,8 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace<Double, Float64Field, DMatri
} }
} }
import org.checkerframework.checker.guieffect.qual.SafeType
/** /**
* [EjmlLinearSpace] implementation based on [CommonOps_FSCC], [DecompositionFactory_FSCC] operations and * [EjmlLinearSpace] implementation based on [CommonOps_FSCC], [DecompositionFactory_FSCC] operations and
* [FMatrixSparseCSC] matrices. * [FMatrixSparseCSC] matrices.
@ -786,7 +817,7 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace<Float, Float32Field, FMatrix
*/ */
override val elementAlgebra: Float32Field get() = Float32Field override val elementAlgebra: Float32Field get() = Float32Field
override val elementType: KType get() = typeOf<Float>() override val type: SafeType<Float> get() = safeTypeOf()
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
override fun Matrix<Float>.toEjml(): EjmlFloatMatrix<FMatrixSparseCSC> = when { override fun Matrix<Float>.toEjml(): EjmlFloatMatrix<FMatrixSparseCSC> = when {

View File

@ -4,7 +4,7 @@
*/ */
package space.kscience.kmath.integration package space.kscience.kmath.integration
import space.kscience.attributes.TypedAttributesBuilder import space.kscience.attributes.AttributesBuilder
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.operations.Field import space.kscience.kmath.operations.Field
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
@ -92,7 +92,7 @@ public inline fun <reified T : Any> GaussIntegrator<T>.integrate(
range: ClosedRange<Double>, range: ClosedRange<Double>,
order: Int = 10, order: Int = 10,
intervals: Int = 10, intervals: Int = 10,
attributesBuilder: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit, attributesBuilder: AttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
noinline function: (Double) -> T, noinline function: (Double) -> T,
): UnivariateIntegrand<T> { ): UnivariateIntegrand<T> {
require(range.endInclusive > range.start) { "The range upper bound should be higher than lower bound" } require(range.endInclusive > range.start) { "The range upper bound should be higher than lower bound" }

View File

@ -30,7 +30,7 @@ public sealed class IntegrandValue<T> private constructor(): IntegrandAttribute<
} }
} }
public fun <T> TypedAttributesBuilder<Integrand<T>>.value(value: T) { public fun <T> AttributesBuilder<Integrand<T>>.value(value: T) {
IntegrandValue.forType<T>().invoke(value) IntegrandValue.forType<T>().invoke(value)
} }

View File

@ -25,10 +25,10 @@ public fun <T, A : Any> MultivariateIntegrand<T>.withAttribute(
): MultivariateIntegrand<T> = withAttributes(attributes.withAttribute(attribute, value)) ): MultivariateIntegrand<T> = withAttributes(attributes.withAttribute(attribute, value))
public fun <T> MultivariateIntegrand<T>.withAttributes( public fun <T> MultivariateIntegrand<T>.withAttributes(
block: TypedAttributesBuilder<MultivariateIntegrand<T>>.() -> Unit, block: AttributesBuilder<MultivariateIntegrand<T>>.() -> Unit,
): MultivariateIntegrand<T> = withAttributes(attributes.modify(block)) ): MultivariateIntegrand<T> = withAttributes(attributes.modify(block))
public inline fun <reified T : Any> MultivariateIntegrand( public inline fun <reified T : Any> MultivariateIntegrand(
attributeBuilder: TypedAttributesBuilder<MultivariateIntegrand<T>>.() -> Unit, attributeBuilder: AttributesBuilder<MultivariateIntegrand<T>>.() -> Unit,
noinline function: (Point<T>) -> T, noinline function: (Point<T>) -> T,
): MultivariateIntegrand<T> = MultivariateIntegrand(safeTypeOf<T>(), Attributes(attributeBuilder), function) ): MultivariateIntegrand<T> = MultivariateIntegrand(safeTypeOf<T>(), Attributes(attributeBuilder), function)

View File

@ -26,11 +26,11 @@ public fun <T, A : Any> UnivariateIntegrand<T>.withAttribute(
): UnivariateIntegrand<T> = withAttributes(attributes.withAttribute(attribute, value)) ): UnivariateIntegrand<T> = withAttributes(attributes.withAttribute(attribute, value))
public fun <T> UnivariateIntegrand<T>.withAttributes( public fun <T> UnivariateIntegrand<T>.withAttributes(
block: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit, block: AttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
): UnivariateIntegrand<T> = withAttributes(attributes.modify(block)) ): UnivariateIntegrand<T> = withAttributes(attributes.modify(block))
public inline fun <reified T : Any> UnivariateIntegrand( public inline fun <reified T : Any> UnivariateIntegrand(
attributeBuilder: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit, attributeBuilder: AttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
noinline function: (Double) -> T, noinline function: (Double) -> T,
): UnivariateIntegrand<T> = UnivariateIntegrand(safeTypeOf(), Attributes(attributeBuilder), function) ): UnivariateIntegrand<T> = UnivariateIntegrand(safeTypeOf(), Attributes(attributeBuilder), function)
@ -58,7 +58,7 @@ public class UnivariateIntegrandRanges(public val ranges: List<Pair<ClosedRange<
public object UnivariateIntegrationNodes : IntegrandAttribute<Buffer<Double>> public object UnivariateIntegrationNodes : IntegrandAttribute<Buffer<Double>>
public fun TypedAttributesBuilder<UnivariateIntegrand<*>>.integrationNodes(vararg nodes: Double) { public fun AttributesBuilder<UnivariateIntegrand<*>>.integrationNodes(vararg nodes: Double) {
UnivariateIntegrationNodes(Float64Buffer(nodes)) UnivariateIntegrationNodes(Float64Buffer(nodes))
} }
@ -68,7 +68,7 @@ public fun TypedAttributesBuilder<UnivariateIntegrand<*>>.integrationNodes(varar
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate( public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate(
attributesBuilder: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit, attributesBuilder: AttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
noinline function: (Double) -> T, noinline function: (Double) -> T,
): UnivariateIntegrand<T> = integrate(UnivariateIntegrand(attributesBuilder, function)) ): UnivariateIntegrand<T> = integrate(UnivariateIntegrand(attributesBuilder, function))
@ -79,7 +79,7 @@ public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate(
@UnstableKMathAPI @UnstableKMathAPI
public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate( public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate(
range: ClosedRange<Double>, range: ClosedRange<Double>,
attributeBuilder: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit = {}, attributeBuilder: AttributesBuilder<UnivariateIntegrand<T>>.() -> Unit = {},
noinline function: (Double) -> T, noinline function: (Double) -> T,
): UnivariateIntegrand<T> { ): UnivariateIntegrand<T> {

View File

@ -14,7 +14,7 @@ repositories {
} }
readme { readme {
maturity = space.kscience.gradle.Maturity.PROTOTYPE maturity = space.kscience.gradle.Maturity.DEPRECATED
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
feature("jafama-double", "src/main/kotlin/space/kscience/kmath/jafama/") { feature("jafama-double", "src/main/kotlin/space/kscience/kmath/jafama/") {

View File

@ -7,16 +7,17 @@ package space.kscience.kmath.jafama
import net.jafama.FastMath import net.jafama.FastMath
import net.jafama.StrictFastMath import net.jafama.StrictFastMath
import space.kscience.kmath.operations.ExtendedField import space.kscience.kmath.operations.*
import space.kscience.kmath.operations.Norm import space.kscience.kmath.structures.MutableBufferFactory
import space.kscience.kmath.operations.PowerOperations
import space.kscience.kmath.operations.ScaleOperations
/** /**
* A field for [Double] (using FastMath) without boxing. Does not produce appropriate field element. * A field for [Double] (using FastMath) without boxing. Does not produce appropriate field element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> { public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> {
override val bufferFactory: MutableBufferFactory<Double> get() = DoubleField.bufferFactory
override inline val zero: Double get() = 0.0 override inline val zero: Double get() = 0.0
override inline val one: Double get() = 1.0 override inline val one: Double get() = 1.0
@ -68,6 +69,9 @@ public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, S
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> { public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> {
override val bufferFactory: MutableBufferFactory<Double> get() = DoubleField.bufferFactory
override inline val zero: Double get() = 0.0 override inline val zero: Double get() = 0.0
override inline val one: Double get() = 1.0 override inline val one: Double get() = 1.0

View File

@ -20,7 +20,7 @@ public class MultikDoubleAlgebra(
) : MultikDivisionTensorAlgebra<Double, Float64Field>(multikEngine), ) : MultikDivisionTensorAlgebra<Double, Float64Field>(multikEngine),
TrigonometricOperations<StructureND<Double>>, ExponentialOperations<StructureND<Double>> { TrigonometricOperations<StructureND<Double>>, ExponentialOperations<StructureND<Double>> {
override val elementAlgebra: Float64Field get() = Float64Field override val elementAlgebra: Float64Field get() = Float64Field
override val type: DataType get() = DataType.DoubleDataType override val dataType: DataType get() = DataType.DoubleDataType
override fun sin(arg: StructureND<Double>): MultikTensor<Double> = multikMath.mathEx.sin(arg.asMultik().array).wrap() override fun sin(arg: StructureND<Double>): MultikTensor<Double> = multikMath.mathEx.sin(arg.asMultik().array).wrap()

View File

@ -15,7 +15,7 @@ public class MultikFloatAlgebra(
multikEngine: Engine multikEngine: Engine
) : MultikDivisionTensorAlgebra<Float, Float32Field>(multikEngine) { ) : MultikDivisionTensorAlgebra<Float, Float32Field>(multikEngine) {
override val elementAlgebra: Float32Field get() = Float32Field override val elementAlgebra: Float32Field get() = Float32Field
override val type: DataType get() = DataType.FloatDataType override val dataType: DataType get() = DataType.FloatDataType
override fun scalar(value: Float): MultikTensor<Float> = Multik.ndarrayOf(value).wrap() override fun scalar(value: Float): MultikTensor<Float> = Multik.ndarrayOf(value).wrap()
} }

View File

@ -15,7 +15,7 @@ public class MultikIntAlgebra(
multikEngine: Engine multikEngine: Engine
) : MultikTensorAlgebra<Int, Int32Ring>(multikEngine) { ) : MultikTensorAlgebra<Int, Int32Ring>(multikEngine) {
override val elementAlgebra: Int32Ring get() = Int32Ring override val elementAlgebra: Int32Ring get() = Int32Ring
override val type: DataType get() = DataType.IntDataType override val dataType: DataType get() = DataType.IntDataType
override fun scalar(value: Int): MultikTensor<Int> = Multik.ndarrayOf(value).wrap() override fun scalar(value: Int): MultikTensor<Int> = Multik.ndarrayOf(value).wrap()
} }

View File

@ -15,7 +15,7 @@ public class MultikLongAlgebra(
multikEngine: Engine multikEngine: Engine
) : MultikTensorAlgebra<Long, Int64Ring>(multikEngine) { ) : MultikTensorAlgebra<Long, Int64Ring>(multikEngine) {
override val elementAlgebra: Int64Ring get() = Int64Ring override val elementAlgebra: Int64Ring get() = Int64Ring
override val type: DataType get() = DataType.LongDataType override val dataType: DataType get() = DataType.LongDataType
override fun scalar(value: Long): MultikTensor<Long> = Multik.ndarrayOf(value).wrap() override fun scalar(value: Long): MultikTensor<Long> = Multik.ndarrayOf(value).wrap()
} }

View File

@ -15,7 +15,7 @@ public class MultikShortAlgebra(
multikEngine: Engine multikEngine: Engine
) : MultikTensorAlgebra<Short, Int16Ring>(multikEngine) { ) : MultikTensorAlgebra<Short, Int16Ring>(multikEngine) {
override val elementAlgebra: Int16Ring get() = Int16Ring override val elementAlgebra: Int16Ring get() = Int16Ring
override val type: DataType get() = DataType.ShortDataType override val dataType: DataType get() = DataType.ShortDataType
override fun scalar(value: Short): MultikTensor<Short> = Multik.ndarrayOf(value).wrap() override fun scalar(value: Short): MultikTensor<Short> = Multik.ndarrayOf(value).wrap()
} }

View File

@ -6,13 +6,33 @@
package space.kscience.kmath.multik package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.ndarray.data.* import org.jetbrains.kotlinx.multik.ndarray.data.*
import space.kscience.attributes.SafeType
import space.kscience.attributes.safeTypeOf
import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.operations.*
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import kotlin.jvm.JvmInline import kotlin.jvm.JvmInline
public val DataType.type: SafeType<*>
get() = when (this) {
DataType.ByteDataType -> ByteRing.type
DataType.ShortDataType -> ShortRing.type
DataType.IntDataType -> IntRing.type
DataType.LongDataType -> LongRing.type
DataType.FloatDataType -> Float32Field.type
DataType.DoubleDataType -> Float64Field.type
DataType.ComplexFloatDataType -> safeTypeOf<Pair<Float, Float>>()
DataType.ComplexDoubleDataType -> ComplexField.type
}
@JvmInline @JvmInline
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> { public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
@Suppress("UNCHECKED_CAST")
override val type: SafeType<T> get() = array.dtype.type as SafeType<T>
override val shape: ShapeND get() = ShapeND(array.shape) override val shape: ShapeND get() = ShapeND(array.shape)
@PerformancePitfall @PerformancePitfall

View File

@ -26,7 +26,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
private val multikEngine: Engine, private val multikEngine: Engine,
) : TensorAlgebra<T, A> where T : Number, T : Comparable<T> { ) : TensorAlgebra<T, A> where T : Number, T : Comparable<T> {
public abstract val type: DataType public abstract val dataType: DataType
protected val multikMath: Math = multikEngine.getMath() protected val multikMath: Math = multikEngine.getMath()
protected val multikLinAl: LinAlg = multikEngine.getLinAlg() protected val multikLinAl: LinAlg = multikEngine.getLinAlg()
@ -35,7 +35,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
@OptIn(UnsafeKMathAPI::class) @OptIn(UnsafeKMathAPI::class)
override fun mutableStructureND(shape: ShapeND, initializer: A.(IntArray) -> T): MultikTensor<T> { override fun mutableStructureND(shape: ShapeND, initializer: A.(IntArray) -> T): MultikTensor<T> {
val strides = ColumnStrides(shape) val strides = ColumnStrides(shape)
val memoryView = initMemoryView<T>(strides.linearSize, type) val memoryView = initMemoryView<T>(strides.linearSize, dataType)
strides.asSequence().forEachIndexed { linearIndex, tensorIndex -> strides.asSequence().forEachIndexed { linearIndex, tensorIndex ->
memoryView[linearIndex] = elementAlgebra.initializer(tensorIndex) memoryView[linearIndex] = elementAlgebra.initializer(tensorIndex)
} }
@ -44,7 +44,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
@OptIn(PerformancePitfall::class, UnsafeKMathAPI::class) @OptIn(PerformancePitfall::class, UnsafeKMathAPI::class)
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> = if (this is MultikTensor) { override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> = if (this is MultikTensor) {
val data = initMemoryView<T>(array.size, type) val data = initMemoryView<T>(array.size, dataType)
var count = 0 var count = 0
for (el in array) data[count++] = elementAlgebra.transform(el) for (el in array) data[count++] = elementAlgebra.transform(el)
NDArray(data, shape = shape.asArray(), dim = array.dim).wrap() NDArray(data, shape = shape.asArray(), dim = array.dim).wrap()
@ -58,7 +58,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> = override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> =
if (this is MultikTensor) { if (this is MultikTensor) {
val array = asMultik().array val array = asMultik().array
val data = initMemoryView<T>(array.size, type) val data = initMemoryView<T>(array.size, dataType)
val indexIter = array.multiIndices.iterator() val indexIter = array.multiIndices.iterator()
var index = 0 var index = 0
for (item in array) { for (item in array) {
@ -95,7 +95,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
val leftArray = left.asMultik().array val leftArray = left.asMultik().array
val rightArray = right.asMultik().array val rightArray = right.asMultik().array
val data = initMemoryView<T>(leftArray.size, type) val data = initMemoryView<T>(leftArray.size, dataType)
var counter = 0 var counter = 0
val leftIterator = leftArray.iterator() val leftIterator = leftArray.iterator()
val rightIterator = rightArray.iterator() val rightIterator = rightArray.iterator()
@ -114,7 +114,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) { public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
this this
} else { } else {
val res = mk.zeros<T, DN>(shape.asArray(), type).asDNArray() val res = mk.zeros<T, DN>(shape.asArray(), dataType).asDNArray()
for (index in res.multiIndices) { for (index in res.multiIndices) {
res[index] = this[index] res[index] = this[index]
} }
@ -296,7 +296,7 @@ public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>(
@OptIn(UnsafeKMathAPI::class) @OptIn(UnsafeKMathAPI::class)
override fun T.div(arg: StructureND<T>): MultikTensor<T> = override fun T.div(arg: StructureND<T>): MultikTensor<T> =
Multik.ones<T, DN>(arg.shape.asArray(), type).apply { divAssign(arg.asMultik().array) }.wrap() Multik.ones<T, DN>(arg.shape.asArray(), dataType).apply { divAssign(arg.asMultik().array) }.wrap()
override fun StructureND<T>.div(arg: T): MultikTensor<T> = override fun StructureND<T>.div(arg: T): MultikTensor<T> =
asMultik().array.div(arg).wrap() asMultik().array.div(arg).wrap()

View File

@ -9,20 +9,21 @@ import space.kscience.attributes.*
import space.kscience.kmath.expressions.DifferentiableExpression import space.kscience.kmath.expressions.DifferentiableExpression
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
public class OptimizationValue<T>(public val value: T) : OptimizationFeature { public class OptimizationValue<V>(type: SafeType<V>) : PolymorphicAttribute<V>(type)
override fun toString(): String = "Value($value)"
}
public enum class FunctionOptimizationTarget { public enum class OptimizationDirection {
MAXIMIZE, MAXIMIZE,
MINIMIZE MINIMIZE
} }
public object FunctionOptimizationTarget: OptimizationAttribute<OptimizationDirection>
public class FunctionOptimization<T>( public class FunctionOptimization<T>(
override val attributes: Attributes,
public val expression: DifferentiableExpression<T>, public val expression: DifferentiableExpression<T>,
override val attributes: Attributes,
) : OptimizationProblem<T> { ) : OptimizationProblem<T> {
override val type: SafeType<T> get() = expression.type
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true
@ -47,36 +48,52 @@ public class FunctionOptimization<T>(
public companion object public companion object
} }
public fun <T> FunctionOptimization(
expression: DifferentiableExpression<T>,
attributeBuilder: AttributesBuilder<FunctionOptimization<T>>.() -> Unit,
): FunctionOptimization<T> = FunctionOptimization(expression, Attributes(attributeBuilder))
public class OptimizationPrior<T> :
public class OptimizationPrior<T>(type: SafeType<T>):
PolymorphicAttribute<DifferentiableExpression<T>>(safeTypeOf()), PolymorphicAttribute<DifferentiableExpression<T>>(safeTypeOf()),
Attribute<DifferentiableExpression<T>> Attribute<DifferentiableExpression<T>>
//public val <T> FunctionOptimization.Companion.Optimization get() = public fun <T> FunctionOptimization<T>.withAttributes(
modifier: AttributesBuilder<FunctionOptimization<T>>.() -> Unit,
public fun <T> FunctionOptimization<T>.withFeatures(
vararg newFeature: OptimizationFeature,
): FunctionOptimization<T> = FunctionOptimization( ): FunctionOptimization<T> = FunctionOptimization(
attributes.with(*newFeature),
expression, expression,
attributes.modify(modifier),
) )
/** /**
* Optimizes differentiable expression using specific [optimizer] form given [startingPoint]. * Optimizes differentiable expression using specific [optimizer] form given [startingPoint].
*/ */
public suspend fun <T : Any> DifferentiableExpression<T>.optimizeWith( public suspend fun <T> DifferentiableExpression<T>.optimizeWith(
optimizer: Optimizer<T, FunctionOptimization<T>>, optimizer: Optimizer<T, FunctionOptimization<T>>,
startingPoint: Map<Symbol, T>, startingPoint: Map<Symbol, T>,
vararg features: OptimizationFeature, modifier: AttributesBuilder<FunctionOptimization<T>>.() -> Unit = {},
): FunctionOptimization<T> { ): FunctionOptimization<T> {
val problem = FunctionOptimization<T>(FeatureSet.of(OptimizationStartPoint(startingPoint), *features), this) val problem = FunctionOptimization(this){
startAt(startingPoint)
modifier()
}
return optimizer.optimize(problem) return optimizer.optimize(problem)
} }
public val <T> FunctionOptimization<T>.resultValueOrNull: T? public val <T> FunctionOptimization<T>.resultValueOrNull: T?
get() = getFeature<OptimizationResult<T>>()?.point?.let { expression(it) } get() = attributes[OptimizationResult<T>()]?.let { expression(it) }
public val <T> FunctionOptimization<T>.resultValue: T public val <T> FunctionOptimization<T>.resultValue: T
get() = resultValueOrNull ?: error("Result is not present in $this") get() = resultValueOrNull ?: error("Result is not present in $this")
public suspend fun <T> DifferentiableExpression<T>.optimizeWith(
optimizer: Optimizer<T, FunctionOptimization<T>>,
vararg startingPoint: Pair<Symbol, T>,
builder: AttributesBuilder<FunctionOptimization<T>>.() -> Unit = {},
): FunctionOptimization<T> {
val problem = FunctionOptimization<T>(this) {
startAt(mapOf(*startingPoint))
builder()
}
return optimizer.optimize(problem)
}

View File

@ -1,96 +0,0 @@
/*
* Copyright 2018-2022 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.optimization
import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.data.XYColumnarData
import space.kscience.kmath.expressions.DifferentiableExpression
import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.misc.FeatureSet
public abstract class OptimizationBuilder<T, R : OptimizationProblem<T>> {
public val features: MutableList<OptimizationFeature> = ArrayList()
public fun addFeature(feature: OptimizationFeature) {
features.add(feature)
}
public inline fun <reified T : OptimizationFeature> updateFeature(update: (T?) -> T) {
val existing = features.find { it.key == T::class } as? T
val new = update(existing)
if (existing != null) {
features.remove(existing)
}
addFeature(new)
}
public abstract fun build(): R
}
public fun <T> OptimizationBuilder<T, *>.startAt(startingPoint: Map<Symbol, T>) {
addFeature(OptimizationStartPoint(startingPoint))
}
public class FunctionOptimizationBuilder<T>(
private val expression: DifferentiableExpression<T>,
) : OptimizationBuilder<T, FunctionOptimization<T>>() {
override fun build(): FunctionOptimization<T> = FunctionOptimization(FeatureSet.of(features), expression)
}
public fun <T> FunctionOptimization(
expression: DifferentiableExpression<T>,
builder: FunctionOptimizationBuilder<T>.() -> Unit,
): FunctionOptimization<T> = FunctionOptimizationBuilder(expression).apply(builder).build()
public suspend fun <T> DifferentiableExpression<T>.optimizeWith(
optimizer: Optimizer<T, FunctionOptimization<T>>,
startingPoint: Map<Symbol, T>,
builder: FunctionOptimizationBuilder<T>.() -> Unit = {},
): FunctionOptimization<T> {
val problem = FunctionOptimization<T>(this) {
startAt(startingPoint)
builder()
}
return optimizer.optimize(problem)
}
public suspend fun <T> DifferentiableExpression<T>.optimizeWith(
optimizer: Optimizer<T, FunctionOptimization<T>>,
vararg startingPoint: Pair<Symbol, T>,
builder: FunctionOptimizationBuilder<T>.() -> Unit = {},
): FunctionOptimization<T> {
val problem = FunctionOptimization<T>(this) {
startAt(mapOf(*startingPoint))
builder()
}
return optimizer.optimize(problem)
}
@OptIn(UnstableKMathAPI::class)
public class XYOptimizationBuilder(
public val data: XYColumnarData<Double, Double, Double>,
public val model: DifferentiableExpression<Double>,
) : OptimizationBuilder<Double, XYFit>() {
public var pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY
public var pointWeight: PointWeight = PointWeight.byYSigma
override fun build(): XYFit = XYFit(
data,
model,
FeatureSet.of(features),
pointToCurveDistance,
pointWeight
)
}
@OptIn(UnstableKMathAPI::class)
public fun XYOptimization(
data: XYColumnarData<Double, Double, Double>,
model: DifferentiableExpression<Double>,
builder: XYOptimizationBuilder.() -> Unit,
): XYFit = XYOptimizationBuilder(data, model).apply(builder).build()

View File

@ -6,64 +6,53 @@
package space.kscience.kmath.optimization package space.kscience.kmath.optimization
import space.kscience.attributes.* import space.kscience.attributes.*
import space.kscience.kmath.expressions.DifferentiableExpression
import space.kscience.kmath.expressions.NamedMatrix import space.kscience.kmath.expressions.NamedMatrix
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.misc.* import space.kscience.kmath.misc.Loggable
import kotlin.reflect.KClass
public interface OptimizationAttribute<T> : Attribute<T> public interface OptimizationAttribute<T> : Attribute<T>
public interface OptimizationProblem<T> : AttributeContainer public interface OptimizationProblem<T> : AttributeContainer, WithType<T>
public inline fun <reified F : OptimizationFeature> OptimizationProblem<*>.getFeature(): F? = getFeature(F::class) public class OptimizationStartPoint<T> : OptimizationAttribute<Map<Symbol, T>>,
PolymorphicAttribute<Map<Symbol, T>>(safeTypeOf())
public open class OptimizationStartPoint<T>(public val point: Map<Symbol, T>) : OptimizationFeature {
override fun toString(): String = "StartPoint($point)"
}
/**
* Covariance matrix for
*/
public class OptimizationCovariance<T>(public val covariance: NamedMatrix<T>) : OptimizationFeature {
override fun toString(): String = "Covariance($covariance)"
}
/** /**
* Get the starting point for optimization. Throws error if not defined. * Get the starting point for optimization. Throws error if not defined.
*/ */
public val <T> OptimizationProblem<T>.startPoint: Map<Symbol, T> public val <T> OptimizationProblem<T>.startPoint: Map<Symbol, T>
get() = getFeature<OptimizationStartPoint<T>>()?.point get() = attributes[OptimizationStartPoint()] ?: error("Starting point not defined in $this")
?: error("Starting point not defined in $this")
public open class OptimizationResult<T>(public val point: Map<Symbol, T>) : OptimizationFeature { public fun <T> AttributesBuilder<OptimizationProblem<T>>.startAt(startingPoint: Map<Symbol, T>) {
override fun toString(): String = "Result($point)" set(::OptimizationStartPoint, startingPoint)
} }
public val <T> OptimizationProblem<T>.resultPointOrNull: Map<Symbol, T>?
get() = getFeature<OptimizationResult<T>>()?.point
public val <T> OptimizationProblem<T>.resultPoint: Map<Symbol, T> /**
get() = resultPointOrNull ?: error("Result is not present in $this") * Covariance matrix for optimization
*/
public class OptimizationCovariance<T> : OptimizationAttribute<NamedMatrix<T>>,
PolymorphicAttribute<NamedMatrix<T>>(safeTypeOf())
public class OptimizationLog(private val loggable: Loggable) : Loggable by loggable, OptimizationFeature {
override fun toString(): String = "Log($loggable)" public class OptimizationResult<T>() : OptimizationAttribute<Map<Symbol, T>>,
} PolymorphicAttribute<Map<Symbol, T>>(safeTypeOf())
public val <T> OptimizationProblem<T>.resultOrNull: Map<Symbol, T>? get() = attributes[OptimizationResult()]
public val <T> OptimizationProblem<T>.result: Map<Symbol, T>
get() = resultOrNull ?: error("Result is not present in $this")
public object OptimizationLog : OptimizationAttribute<Loggable>
/** /**
* Free parameters of the optimization * Free parameters of the optimization
*/ */
public class OptimizationParameters(public val symbols: List<Symbol>) : OptimizationFeature { public object OptimizationParameters : OptimizationAttribute<List<Symbol>>
public constructor(vararg symbols: Symbol) : this(listOf(*symbols))
override fun toString(): String = "Parameters($symbols)"
}
/** /**
* Maximum allowed number of iterations * Maximum allowed number of iterations
*/ */
public class OptimizationIterations(public val maxIterations: Int) : OptimizationFeature { public object OptimizationIterations : OptimizationAttribute<Int>
override fun toString(): String = "Iterations($maxIterations)"
}

View File

@ -16,13 +16,7 @@ import space.kscience.kmath.structures.Float64Buffer
import kotlin.math.abs import kotlin.math.abs
public class QowRuns(public val runs: Int) : OptimizationFeature { public object QowRuns: OptimizationAttribute<Int>
init {
require(runs >= 1) { "Number of runs must be more than zero" }
}
override fun toString(): String = "QowRuns(runs=$runs)"
}
/** /**
@ -69,7 +63,7 @@ public object QowOptimizer : Optimizer<Double, XYFit> {
} }
val prior: DifferentiableExpression<Double>? val prior: DifferentiableExpression<Double>?
get() = problem.getFeature<OptimizationPrior<Double>>()?.withDefaultArgs(allParameters) get() = problem.attributes[OptimizationPrior<Double>()]?.withDefaultArgs(allParameters)
override fun toString(): String = freeParameters.toString() override fun toString(): String = freeParameters.toString()
} }
@ -176,7 +170,7 @@ public object QowOptimizer : Optimizer<Double, XYFit> {
fast: Boolean = false, fast: Boolean = false,
): QoWeight { ): QoWeight {
val logger = problem.getFeature<OptimizationLog>() val logger = problem.attributes[OptimizationLog]
var dis: Double //discrepancy value var dis: Double //discrepancy value
@ -231,7 +225,7 @@ public object QowOptimizer : Optimizer<Double, XYFit> {
} }
private fun QoWeight.covariance(): NamedMatrix<Double> { private fun QoWeight.covariance(): NamedMatrix<Double> {
val logger = problem.getFeature<OptimizationLog>() val logger = problem.attributes[OptimizationLog]
logger?.log { logger?.log {
""" """
@ -257,11 +251,11 @@ public object QowOptimizer : Optimizer<Double, XYFit> {
} }
override suspend fun optimize(problem: XYFit): XYFit { override suspend fun optimize(problem: XYFit): XYFit {
val qowRuns = problem.getFeature<QowRuns>()?.runs ?: 2 val qowRuns = problem.attributes[QowRuns] ?: 2
val iterations = problem.getFeature<OptimizationIterations>()?.maxIterations ?: 50 val iterations = problem.attributes[OptimizationIterations] ?: 50
val freeParameters: Map<Symbol, Double> = problem.getFeature<OptimizationParameters>()?.let { op -> val freeParameters: Map<Symbol, Double> = problem.attributes[OptimizationParameters]?.let { symbols ->
problem.startPoint.filterKeys { it in op.symbols } problem.startPoint.filterKeys { it in symbols }
} ?: problem.startPoint } ?: problem.startPoint
var qow = QoWeight(problem, freeParameters) var qow = QoWeight(problem, freeParameters)
@ -271,6 +265,9 @@ public object QowOptimizer : Optimizer<Double, XYFit> {
res = qow.newtonianRun(maxSteps = iterations) res = qow.newtonianRun(maxSteps = iterations)
} }
val covariance = res.covariance() val covariance = res.covariance()
return res.problem.withFeature(OptimizationResult(res.freeParameters), OptimizationCovariance(covariance)) return res.problem.withAttributes {
set(OptimizationResult(), res.freeParameters)
set(OptimizationCovariance(), covariance)
}
} }
} }

View File

@ -2,37 +2,41 @@
* Copyright 2018-2022 KMath contributors. * Copyright 2018-2022 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
@file:OptIn(UnstableKMathAPI::class) @file:OptIn(UnstableKMathAPI::class, UnstableKMathAPI::class)
package space.kscience.kmath.optimization package space.kscience.kmath.optimization
import space.kscience.attributes.*
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.data.XYColumnarData import space.kscience.kmath.data.XYColumnarData
import space.kscience.kmath.data.indices import space.kscience.kmath.data.indices
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.FeatureSet
import space.kscience.kmath.misc.Loggable import space.kscience.kmath.misc.Loggable
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedField import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.Float64Field
import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.bindSymbol
import kotlin.math.pow import kotlin.math.pow
/** /**
* Specify the way to compute distance from point to the curve as DifferentiableExpression * Specify the way to compute distance from point to the curve as DifferentiableExpression
*/ */
public interface PointToCurveDistance : OptimizationFeature { public interface PointToCurveDistance {
public fun distance(problem: XYFit, index: Int): DifferentiableExpression<Double> public fun distance(problem: XYFit, index: Int): DifferentiableExpression<Double>
public companion object { public companion object : OptimizationAttribute<PointToCurveDistance> {
public val byY: PointToCurveDistance = object : PointToCurveDistance { public val byY: PointToCurveDistance = object : PointToCurveDistance {
override fun distance(problem: XYFit, index: Int): DifferentiableExpression<Double> { override fun distance(problem: XYFit, index: Int): DifferentiableExpression<Double> {
val x = problem.data.x[index] val x = problem.data.x[index]
val y = problem.data.y[index] val y = problem.data.y[index]
return object : DifferentiableExpression<Double> { return object : DifferentiableExpression<Double> {
override val type: SafeType<Double> get() = DoubleField.type
override fun derivativeOrNull( override fun derivativeOrNull(
symbols: List<Symbol>, symbols: List<Symbol>,
): Expression<Double>? = problem.model.derivativeOrNull(symbols)?.let { derivExpression -> ): Expression<Double>? = problem.model.derivativeOrNull(symbols)?.let { derivExpression ->
Expression { arguments -> Expression(DoubleField.type) { arguments ->
derivExpression.invoke(arguments + (Symbol.x to x)) derivExpression.invoke(arguments + (Symbol.x to x))
} }
} }
@ -51,18 +55,21 @@ public interface PointToCurveDistance : OptimizationFeature {
* Compute a wight of the point. The more the weight, the more impact this point will have on the fit. * Compute a wight of the point. The more the weight, the more impact this point will have on the fit.
* By default, uses Dispersion^-1 * By default, uses Dispersion^-1
*/ */
public interface PointWeight : OptimizationFeature { public interface PointWeight {
public fun weight(problem: XYFit, index: Int): DifferentiableExpression<Double> public fun weight(problem: XYFit, index: Int): DifferentiableExpression<Double>
public companion object { public companion object : OptimizationAttribute<PointWeight> {
public fun bySigma(sigmaSymbol: Symbol): PointWeight = object : PointWeight { public fun bySigma(sigmaSymbol: Symbol): PointWeight = object : PointWeight {
override fun weight(problem: XYFit, index: Int): DifferentiableExpression<Double> = override fun weight(problem: XYFit, index: Int): DifferentiableExpression<Double> =
object : DifferentiableExpression<Double> { object : DifferentiableExpression<Double> {
override val type: SafeType<Double> get() = DoubleField.type
override fun invoke(arguments: Map<Symbol, Double>): Double { override fun invoke(arguments: Map<Symbol, Double>): Double {
return problem.data[sigmaSymbol]?.get(index)?.pow(-2) ?: 1.0 return problem.data[sigmaSymbol]?.get(index)?.pow(-2) ?: 1.0
} }
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { 0.0 } override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> =
Expression(DoubleField.type) { 0.0 }
} }
override fun toString(): String = "PointWeightBySigma($sigmaSymbol)" override fun toString(): String = "PointWeightBySigma($sigmaSymbol)"
@ -79,41 +86,52 @@ public interface PointWeight : OptimizationFeature {
public class XYFit( public class XYFit(
public val data: XYColumnarData<Double, Double, Double>, public val data: XYColumnarData<Double, Double, Double>,
public val model: DifferentiableExpression<Double>, public val model: DifferentiableExpression<Double>,
override val attributes: FeatureSet<OptimizationFeature>, override val attributes: Attributes,
internal val pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY, internal val pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY,
internal val pointWeight: PointWeight = PointWeight.byYSigma, internal val pointWeight: PointWeight = PointWeight.byYSigma,
public val xSymbol: Symbol = Symbol.x, public val xSymbol: Symbol = Symbol.x,
) : OptimizationProblem<Double> { ) : OptimizationProblem<Double> {
override val type: SafeType<Double> get() = Float64Field.type
public fun distance(index: Int): DifferentiableExpression<Double> = pointToCurveDistance.distance(this, index) public fun distance(index: Int): DifferentiableExpression<Double> = pointToCurveDistance.distance(this, index)
public fun weight(index: Int): DifferentiableExpression<Double> = pointWeight.weight(this, index) public fun weight(index: Int): DifferentiableExpression<Double> = pointWeight.weight(this, index)
} }
public fun XYFit.withFeature(vararg features: OptimizationFeature): XYFit {
return XYFit(data, model, this.attributes.with(*features), pointToCurveDistance, pointWeight) public fun XYOptimization(
} data: XYColumnarData<Double, Double, Double>,
model: DifferentiableExpression<Double>,
builder: AttributesBuilder<XYFit>.() -> Unit,
): XYFit = XYFit(data, model, Attributes(builder))
public fun XYFit.withAttributes(
modifier: AttributesBuilder<XYFit>.() -> Unit,
): XYFit = XYFit(data, model, attributes.modify(modifier), pointToCurveDistance, pointWeight, xSymbol)
public suspend fun XYColumnarData<Double, Double, Double>.fitWith( public suspend fun XYColumnarData<Double, Double, Double>.fitWith(
optimizer: Optimizer<Double, XYFit>, optimizer: Optimizer<Double, XYFit>,
modelExpression: DifferentiableExpression<Double>, modelExpression: DifferentiableExpression<Double>,
startingPoint: Map<Symbol, Double>, startingPoint: Map<Symbol, Double>,
vararg features: OptimizationFeature = emptyArray(), attributes: Attributes = Attributes.EMPTY,
xSymbol: Symbol = Symbol.x, xSymbol: Symbol = Symbol.x,
pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY, pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY,
pointWeight: PointWeight = PointWeight.byYSigma, pointWeight: PointWeight = PointWeight.byYSigma,
): XYFit { ): XYFit {
var actualFeatures = FeatureSet.of(*features, OptimizationStartPoint(startingPoint))
if (actualFeatures.getFeature<OptimizationLog>() == null) {
actualFeatures = actualFeatures.with(OptimizationLog(Loggable.console))
}
val problem = XYFit( val problem = XYFit(
this, this,
modelExpression, modelExpression,
actualFeatures, attributes.modify<XYFit> {
set(::OptimizationStartPoint, startingPoint)
if (!hasAny<OptimizationLog>()) {
set(OptimizationLog, Loggable.console)
}
},
pointToCurveDistance, pointToCurveDistance,
pointWeight, pointWeight,
xSymbol xSymbol,
) )
return optimizer.optimize(problem) return optimizer.optimize(problem)
} }
@ -125,7 +143,7 @@ public suspend fun <I : Any, A> XYColumnarData<Double, Double, Double>.fitWith(
optimizer: Optimizer<Double, XYFit>, optimizer: Optimizer<Double, XYFit>,
processor: AutoDiffProcessor<Double, I, A>, processor: AutoDiffProcessor<Double, I, A>,
startingPoint: Map<Symbol, Double>, startingPoint: Map<Symbol, Double>,
vararg features: OptimizationFeature = emptyArray(), attributes: Attributes = Attributes.EMPTY,
xSymbol: Symbol = Symbol.x, xSymbol: Symbol = Symbol.x,
pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY, pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY,
pointWeight: PointWeight = PointWeight.byYSigma, pointWeight: PointWeight = PointWeight.byYSigma,
@ -140,7 +158,7 @@ public suspend fun <I : Any, A> XYColumnarData<Double, Double, Double>.fitWith(
optimizer = optimizer, optimizer = optimizer,
modelExpression = modelExpression, modelExpression = modelExpression,
startingPoint = startingPoint, startingPoint = startingPoint,
features = features, attributes = attributes,
xSymbol = xSymbol, xSymbol = xSymbol,
pointToCurveDistance = pointToCurveDistance, pointToCurveDistance = pointToCurveDistance,
pointWeight = pointWeight pointWeight = pointWeight
@ -152,7 +170,7 @@ public suspend fun <I : Any, A> XYColumnarData<Double, Double, Double>.fitWith(
*/ */
public val XYFit.chiSquaredOrNull: Double? public val XYFit.chiSquaredOrNull: Double?
get() { get() {
val result = startPoint + (resultPointOrNull ?: return null) val result = startPoint + (resultOrNull ?: return null)
return data.indices.sumOf { index -> return data.indices.sumOf { index ->
@ -167,4 +185,4 @@ public val XYFit.chiSquaredOrNull: Double?
} }
public val XYFit.dof: Int public val XYFit.dof: Int
get() = data.size - (getFeature<OptimizationParameters>()?.symbols?.size ?: startPoint.size) get() = data.size - (attributes[OptimizationParameters]?.size ?: startPoint.size)

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.optimization package space.kscience.kmath.optimization
import space.kscience.attributes.AttributesBuilder
import space.kscience.attributes.SafeType
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.data.XYColumnarData import space.kscience.kmath.data.XYColumnarData
import space.kscience.kmath.data.indices import space.kscience.kmath.data.indices
@ -12,6 +14,7 @@ import space.kscience.kmath.expressions.DifferentiableExpression
import space.kscience.kmath.expressions.Expression import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.derivative import space.kscience.kmath.expressions.derivative
import space.kscience.kmath.operations.Float64Field
import kotlin.math.PI import kotlin.math.PI
import kotlin.math.ln import kotlin.math.ln
import kotlin.math.pow import kotlin.math.pow
@ -22,7 +25,9 @@ private val oneOver2Pi = 1.0 / sqrt(2 * PI)
@UnstableKMathAPI @UnstableKMathAPI
internal fun XYFit.logLikelihood(): DifferentiableExpression<Double> = object : DifferentiableExpression<Double> { internal fun XYFit.logLikelihood(): DifferentiableExpression<Double> = object : DifferentiableExpression<Double> {
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments -> override val type: SafeType<Double> get() = Float64Field.type
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression(type) { arguments ->
data.indices.sumOf { index -> data.indices.sumOf { index ->
val d = distance(index)(arguments) val d = distance(index)(arguments)
val weight = weight(index)(arguments) val weight = weight(index)(arguments)
@ -53,8 +58,12 @@ internal fun XYFit.logLikelihood(): DifferentiableExpression<Double> = object :
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public suspend fun Optimizer<Double, FunctionOptimization<Double>>.maximumLogLikelihood(problem: XYFit): XYFit { public suspend fun Optimizer<Double, FunctionOptimization<Double>>.maximumLogLikelihood(problem: XYFit): XYFit {
val functionOptimization = FunctionOptimization(problem.attributes, problem.logLikelihood()) val functionOptimization = FunctionOptimization(problem.logLikelihood(), problem.attributes)
val result = optimize(functionOptimization.withFeatures(FunctionOptimizationTarget.MAXIMIZE)) val result = optimize(
functionOptimization.withAttributes {
FunctionOptimizationTarget(OptimizationDirection.MAXIMIZE)
}
)
return XYFit(problem.data,problem.model, result.attributes) return XYFit(problem.data,problem.model, result.attributes)
} }
@ -62,5 +71,5 @@ public suspend fun Optimizer<Double, FunctionOptimization<Double>>.maximumLogLik
public suspend fun Optimizer<Double, FunctionOptimization<Double>>.maximumLogLikelihood( public suspend fun Optimizer<Double, FunctionOptimization<Double>>.maximumLogLikelihood(
data: XYColumnarData<Double, Double, Double>, data: XYColumnarData<Double, Double, Double>,
model: DifferentiableExpression<Double>, model: DifferentiableExpression<Double>,
builder: XYOptimizationBuilder.() -> Unit, builder: AttributesBuilder<XYFit>.() -> Unit,
): XYFit = maximumLogLikelihood(XYOptimization(data, model, builder)) ): XYFit = maximumLogLikelihood(XYOptimization(data, model, builder))

View File

@ -5,10 +5,12 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.attributes.SafeType
import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.nd.MutableStructureNDOfDouble import space.kscience.kmath.nd.MutableStructureNDOfDouble
import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.*
import space.kscience.kmath.tensors.core.internal.toPrettyString import space.kscience.kmath.tensors.core.internal.toPrettyString
@ -88,6 +90,8 @@ public open class DoubleTensor(
final override val source: OffsetDoubleBuffer, final override val source: OffsetDoubleBuffer,
) : BufferedTensor<Double>(shape), MutableStructureNDOfDouble { ) : BufferedTensor<Double>(shape), MutableStructureNDOfDouble {
override val type: SafeType<Double> get() = DoubleField.type
init { init {
require(linearSize == source.size) { "Source buffer size must be equal tensor size" } require(linearSize == source.size) { "Source buffer size must be equal tensor size" }
} }

View File

@ -5,8 +5,10 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.attributes.SafeType
import space.kscience.kmath.PerformancePitfall import space.kscience.kmath.PerformancePitfall
import space.kscience.kmath.nd.ShapeND import space.kscience.kmath.nd.ShapeND
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.*
/** /**
@ -24,6 +26,8 @@ public class OffsetIntBuffer(
require(offset + size <= source.size) { "Maximum index must be inside source dimension" } require(offset + size <= source.size) { "Maximum index must be inside source dimension" }
} }
override val type: SafeType<Int> get() = IntRing.type
override fun set(index: Int, value: Int) { override fun set(index: Int, value: Int) {
require(index in 0 until size) { "Index must be in [0, size)" } require(index in 0 until size) { "Index must be in [0, size)" }
source[index + offset] = value source[index + offset] = value
@ -83,6 +87,8 @@ public class IntTensor(
require(linearSize == source.size) { "Source buffer size must be equal tensor size" } require(linearSize == source.size) { "Source buffer size must be equal tensor size" }
} }
override val type: SafeType<Int> get() = IntRing.type
public constructor(shape: ShapeND, buffer: Int32Buffer) : this(shape, OffsetIntBuffer(buffer, 0, buffer.size)) public constructor(shape: ShapeND, buffer: Int32Buffer) : this(shape, OffsetIntBuffer(buffer, 0, buffer.size))
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)