0.4 WIP
This commit is contained in:
parent
2f2f552648
commit
5c82a5e1fa
@ -3,7 +3,7 @@
|
|||||||
## Unreleased
|
## Unreleased
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
- Explicit `SafeType` for algebras and buffers.
|
- Reification. Explicit `SafeType` for algebras and buffers.
|
||||||
- Integer division algebras.
|
- Integer division algebras.
|
||||||
- Float32 geometries.
|
- Float32 geometries.
|
||||||
- New Attributes-kt module that could be used as stand-alone. It declares. type-safe attributes containers.
|
- New Attributes-kt module that could be used as stand-alone. It declares. type-safe attributes containers.
|
||||||
@ -16,6 +16,7 @@
|
|||||||
- kmath-geometry is split into `euclidean2d` and `euclidean3d`
|
- kmath-geometry is split into `euclidean2d` and `euclidean3d`
|
||||||
- Features replaced with Attributes.
|
- Features replaced with Attributes.
|
||||||
- Transposed refactored.
|
- Transposed refactored.
|
||||||
|
- Kmath-memory is moved on top of core.
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
|
@ -24,14 +24,3 @@ public interface AttributeWithDefault<T> : Attribute<T> {
|
|||||||
*/
|
*/
|
||||||
public interface SetAttribute<V> : Attribute<Set<V>>
|
public interface SetAttribute<V> : Attribute<Set<V>>
|
||||||
|
|
||||||
/**
|
|
||||||
* An attribute that has a type parameter for value
|
|
||||||
* @param type parameter-type
|
|
||||||
*/
|
|
||||||
public abstract class PolymorphicAttribute<T>(public val type: SafeType<T>) : Attribute<T> {
|
|
||||||
override fun equals(other: Any?): Boolean = other != null &&
|
|
||||||
(this::class == other::class) &&
|
|
||||||
(other as? PolymorphicAttribute<*>)?.type == this.type
|
|
||||||
|
|
||||||
override fun hashCode(): Int = this::class.hashCode() + type.hashCode()
|
|
||||||
}
|
|
||||||
|
@ -6,8 +6,14 @@
|
|||||||
package space.kscience.attributes
|
package space.kscience.attributes
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A container for attributes. [attributes] could be made mutable by implementation
|
* A container for [Attributes]
|
||||||
*/
|
*/
|
||||||
public interface AttributeContainer {
|
public interface AttributeContainer {
|
||||||
public val attributes: Attributes
|
public val attributes: Attributes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A scope, where attribute keys could be resolved
|
||||||
|
*/
|
||||||
|
public interface AttributeScope<O>
|
||||||
|
|
||||||
|
@ -7,21 +7,27 @@ package space.kscience.attributes
|
|||||||
|
|
||||||
import kotlin.jvm.JvmInline
|
import kotlin.jvm.JvmInline
|
||||||
|
|
||||||
@JvmInline
|
/**
|
||||||
public value class Attributes internal constructor(public val content: Map<out Attribute<*>, Any?>) {
|
* A set of attributes. The implementation must guarantee that [content] keys correspond to its value types.
|
||||||
|
*/
|
||||||
|
public interface Attributes {
|
||||||
|
public val content: Map<out Attribute<*>, Any?>
|
||||||
|
|
||||||
public val keys: Set<Attribute<*>> get() = content.keys
|
public val keys: Set<Attribute<*>> get() = content.keys
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public operator fun <T> get(attribute: Attribute<T>): T? = content[attribute] as? T
|
public operator fun <T> get(attribute: Attribute<T>): T? = content[attribute] as? T
|
||||||
|
|
||||||
override fun toString(): String = "Attributes(value=${content.entries})"
|
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public val EMPTY: Attributes = Attributes(emptyMap())
|
public val EMPTY: Attributes = AttributesImpl(emptyMap())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@JvmInline
|
||||||
|
internal value class AttributesImpl(override val content: Map<out Attribute<*>, Any?>) : Attributes {
|
||||||
|
override fun toString(): String = "Attributes(value=${content.entries})"
|
||||||
|
}
|
||||||
|
|
||||||
public fun Attributes.isEmpty(): Boolean = content.isEmpty()
|
public fun Attributes.isEmpty(): Boolean = content.isEmpty()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -33,19 +39,19 @@ public fun <T> Attributes.getOrDefault(attribute: AttributeWithDefault<T>): T =
|
|||||||
* Check if there is an attribute that matches given key by type and adheres to [predicate].
|
* Check if there is an attribute that matches given key by type and adheres to [predicate].
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public inline fun <T, reified A : Attribute<T>> Attributes.any(predicate: (value: T) -> Boolean): Boolean =
|
public inline fun <T, reified A : Attribute<T>> Attributes.hasAny(predicate: (value: T) -> Boolean): Boolean =
|
||||||
content.any { (mapKey, mapValue) -> mapKey is A && predicate(mapValue as T) }
|
content.any { (mapKey, mapValue) -> mapKey is A && predicate(mapValue as T) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if there is an attribute of given type (subtypes included)
|
* Check if there is an attribute of given type (subtypes included)
|
||||||
*/
|
*/
|
||||||
public inline fun <T, reified A : Attribute<T>> Attributes.any(): Boolean =
|
public inline fun <reified A : Attribute<*>> Attributes.hasAny(): Boolean =
|
||||||
content.any { (mapKey, _) -> mapKey is A }
|
content.any { (mapKey, _) -> mapKey is A }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if [Attributes] contains a flag. Multiple keys that are instances of a flag could be present
|
* Check if [Attributes] contains a flag. Multiple keys that are instances of a flag could be present
|
||||||
*/
|
*/
|
||||||
public inline fun <reified A : FlagAttribute> Attributes.has(): Boolean =
|
public inline fun <reified A : FlagAttribute> Attributes.hasFlag(): Boolean =
|
||||||
content.keys.any { it is A }
|
content.keys.any { it is A }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -54,7 +60,7 @@ public inline fun <reified A : FlagAttribute> Attributes.has(): Boolean =
|
|||||||
public fun <T, A : Attribute<T>> Attributes.withAttribute(
|
public fun <T, A : Attribute<T>> Attributes.withAttribute(
|
||||||
attribute: A,
|
attribute: A,
|
||||||
attrValue: T,
|
attrValue: T,
|
||||||
): Attributes = Attributes(content + (attribute to attrValue))
|
): Attributes = AttributesImpl(content + (attribute to attrValue))
|
||||||
|
|
||||||
public fun <A : Attribute<Unit>> Attributes.withAttribute(attribute: A): Attributes =
|
public fun <A : Attribute<Unit>> Attributes.withAttribute(attribute: A): Attributes =
|
||||||
withAttribute(attribute, Unit)
|
withAttribute(attribute, Unit)
|
||||||
@ -62,7 +68,7 @@ public fun <A : Attribute<Unit>> Attributes.withAttribute(attribute: A): Attribu
|
|||||||
/**
|
/**
|
||||||
* Create a new [Attributes] by modifying the current one
|
* Create a new [Attributes] by modifying the current one
|
||||||
*/
|
*/
|
||||||
public fun Attributes.modify(block: AttributesBuilder.() -> Unit): Attributes = Attributes {
|
public fun <T> Attributes.modify(block: AttributesBuilder<T>.() -> Unit): Attributes = Attributes<T> {
|
||||||
from(this@modify)
|
from(this@modify)
|
||||||
block()
|
block()
|
||||||
}
|
}
|
||||||
@ -70,7 +76,7 @@ public fun Attributes.modify(block: AttributesBuilder.() -> Unit): Attributes =
|
|||||||
/**
|
/**
|
||||||
* Create new [Attributes] by removing [attribute] key
|
* Create new [Attributes] by removing [attribute] key
|
||||||
*/
|
*/
|
||||||
public fun Attributes.withoutAttribute(attribute: Attribute<*>): Attributes = Attributes(content.minus(attribute))
|
public fun Attributes.withoutAttribute(attribute: Attribute<*>): Attributes = AttributesImpl(content.minus(attribute))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add an element to a [SetAttribute]
|
* Add an element to a [SetAttribute]
|
||||||
@ -80,7 +86,7 @@ public fun <T, A : SetAttribute<T>> Attributes.withAttributeElement(
|
|||||||
attrValue: T,
|
attrValue: T,
|
||||||
): Attributes {
|
): Attributes {
|
||||||
val currentSet: Set<T> = get(attribute) ?: emptySet()
|
val currentSet: Set<T> = get(attribute) ?: emptySet()
|
||||||
return Attributes(
|
return AttributesImpl(
|
||||||
content + (attribute to (currentSet + attrValue))
|
content + (attribute to (currentSet + attrValue))
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -93,9 +99,7 @@ public fun <T, A : SetAttribute<T>> Attributes.withoutAttributeElement(
|
|||||||
attrValue: T,
|
attrValue: T,
|
||||||
): Attributes {
|
): Attributes {
|
||||||
val currentSet: Set<T> = get(attribute) ?: emptySet()
|
val currentSet: Set<T> = get(attribute) ?: emptySet()
|
||||||
return Attributes(
|
return AttributesImpl(content + (attribute to (currentSet - attrValue)))
|
||||||
content + (attribute to (currentSet - attrValue))
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -104,13 +108,13 @@ public fun <T, A : SetAttribute<T>> Attributes.withoutAttributeElement(
|
|||||||
public fun <T, A : Attribute<T>> Attributes(
|
public fun <T, A : Attribute<T>> Attributes(
|
||||||
attribute: A,
|
attribute: A,
|
||||||
attrValue: T,
|
attrValue: T,
|
||||||
): Attributes = Attributes(mapOf(attribute to attrValue))
|
): Attributes = AttributesImpl(mapOf(attribute to attrValue))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create Attributes with a single [Unit] valued attribute
|
* Create Attributes with a single [Unit] valued attribute
|
||||||
*/
|
*/
|
||||||
public fun <A : Attribute<Unit>> Attributes(
|
public fun <A : Attribute<Unit>> Attributes(
|
||||||
attribute: A
|
attribute: A,
|
||||||
): Attributes = Attributes(mapOf(attribute to Unit))
|
): Attributes = AttributesImpl(mapOf(attribute to Unit))
|
||||||
|
|
||||||
public operator fun Attributes.plus(other: Attributes): Attributes = Attributes(content + other.content)
|
public operator fun Attributes.plus(other: Attributes): Attributes = AttributesImpl(content + other.content)
|
@ -10,19 +10,24 @@ package space.kscience.attributes
|
|||||||
*
|
*
|
||||||
* @param O type marker of an owner object, for which these attributes are made
|
* @param O type marker of an owner object, for which these attributes are made
|
||||||
*/
|
*/
|
||||||
public class TypedAttributesBuilder<in O> internal constructor(private val map: MutableMap<Attribute<*>, Any?>) {
|
public class AttributesBuilder<out O> internal constructor(
|
||||||
|
private val map: MutableMap<Attribute<*>, Any?>,
|
||||||
|
) : Attributes {
|
||||||
|
|
||||||
public constructor() : this(mutableMapOf())
|
public constructor() : this(mutableMapOf())
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
override val content: Map<out Attribute<*>, Any?> get() = map
|
||||||
public operator fun <T> get(attribute: Attribute<T>): T? = map[attribute] as? T
|
|
||||||
|
public operator fun <T> set(attribute: Attribute<T>, value: T?) {
|
||||||
|
if (value == null) {
|
||||||
|
map.remove(attribute)
|
||||||
|
} else {
|
||||||
|
map[attribute] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public operator fun <V> Attribute<V>.invoke(value: V?) {
|
public operator fun <V> Attribute<V>.invoke(value: V?) {
|
||||||
if (value == null) {
|
set(this, value)
|
||||||
map.remove(this)
|
|
||||||
} else {
|
|
||||||
map[this] = value
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun from(attributes: Attributes) {
|
public fun from(attributes: Attributes) {
|
||||||
@ -46,14 +51,8 @@ public class TypedAttributesBuilder<in O> internal constructor(private val map:
|
|||||||
map[this] = currentSet - attrValue
|
map[this] = currentSet - attrValue
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun build(): Attributes = Attributes(map)
|
public fun build(): Attributes = AttributesImpl(map)
|
||||||
}
|
}
|
||||||
|
|
||||||
public typealias AttributesBuilder = TypedAttributesBuilder<Any?>
|
public inline fun <O> Attributes(builder: AttributesBuilder<O>.() -> Unit): Attributes =
|
||||||
|
AttributesBuilder<O>().apply(builder).build()
|
||||||
public fun AttributesBuilder(
|
|
||||||
attributes: Attributes,
|
|
||||||
): AttributesBuilder = AttributesBuilder(attributes.content.toMutableMap())
|
|
||||||
|
|
||||||
public inline fun Attributes(builder: AttributesBuilder.() -> Unit): Attributes =
|
|
||||||
AttributesBuilder().apply(builder).build()
|
|
@ -0,0 +1,31 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2023 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.attributes
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An attribute that has a type parameter for value
|
||||||
|
* @param type parameter-type
|
||||||
|
*/
|
||||||
|
public abstract class PolymorphicAttribute<T>(public val type: SafeType<T>) : Attribute<T> {
|
||||||
|
override fun equals(other: Any?): Boolean = other != null &&
|
||||||
|
(this::class == other::class) &&
|
||||||
|
(other as? PolymorphicAttribute<*>)?.type == this.type
|
||||||
|
|
||||||
|
override fun hashCode(): Int = this::class.hashCode() + type.hashCode()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a polymorphic attribute using attribute factory
|
||||||
|
*/
|
||||||
|
public operator fun <T> Attributes.get(attributeKeyBuilder: () -> PolymorphicAttribute<T>): T? = get(attributeKeyBuilder())
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set a polymorphic attribute using its factory
|
||||||
|
*/
|
||||||
|
public operator fun <O, T> AttributesBuilder<O>.set(attributeKeyBuilder: () -> PolymorphicAttribute<T>, value: T) {
|
||||||
|
set(attributeKeyBuilder(), value)
|
||||||
|
}
|
@ -19,6 +19,8 @@ public class Ejml${type}Vector<out M : $ejmlMatrixType>(override val origin: M)
|
|||||||
require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" }
|
require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override val type: SafeType<${type}> get() = safeTypeOf()
|
||||||
|
|
||||||
override operator fun get(index: Int): $type = origin[0, index]
|
override operator fun get(index: Int): $type = origin[0, index]
|
||||||
}"""
|
}"""
|
||||||
appendLine(text)
|
appendLine(text)
|
||||||
@ -30,6 +32,8 @@ private fun Appendable.appendEjmlMatrix(type: String, ejmlMatrixType: String) {
|
|||||||
* [EjmlMatrix] specialization for [$type].
|
* [EjmlMatrix] specialization for [$type].
|
||||||
*/
|
*/
|
||||||
public class Ejml${type}Matrix<out M : $ejmlMatrixType>(override val origin: M) : EjmlMatrix<$type, M>(origin) {
|
public class Ejml${type}Matrix<out M : $ejmlMatrixType>(override val origin: M) : EjmlMatrix<$type, M>(origin) {
|
||||||
|
override val type: SafeType<${type}> get() = safeTypeOf()
|
||||||
|
|
||||||
override operator fun get(i: Int, j: Int): $type = origin[i, j]
|
override operator fun get(i: Int, j: Int): $type = origin[i, j]
|
||||||
}"""
|
}"""
|
||||||
appendLine(text)
|
appendLine(text)
|
||||||
@ -46,7 +50,9 @@ private fun Appendable.appendEjmlLinearSpace(
|
|||||||
denseOps: String,
|
denseOps: String,
|
||||||
isDense: Boolean,
|
isDense: Boolean,
|
||||||
) {
|
) {
|
||||||
@Language("kotlin") val text = """/**
|
@Language("kotlin") val text = """
|
||||||
|
|
||||||
|
/**
|
||||||
* [EjmlLinearSpace] implementation based on [CommonOps_$ops], [DecompositionFactory_${ops}] operations and
|
* [EjmlLinearSpace] implementation based on [CommonOps_$ops], [DecompositionFactory_${ops}] operations and
|
||||||
* [${ejmlMatrixType}] matrices.
|
* [${ejmlMatrixType}] matrices.
|
||||||
*/
|
*/
|
||||||
@ -56,7 +62,7 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra},
|
|||||||
*/
|
*/
|
||||||
override val elementAlgebra: $kmathAlgebra get() = $kmathAlgebra
|
override val elementAlgebra: $kmathAlgebra get() = $kmathAlgebra
|
||||||
|
|
||||||
override val elementType: KType get() = typeOf<$type>()
|
override val type: SafeType<${type}> get() = safeTypeOf()
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
override fun Matrix<${type}>.toEjml(): Ejml${type}Matrix<${ejmlMatrixType}> = when {
|
override fun Matrix<${type}>.toEjml(): Ejml${type}Matrix<${ejmlMatrixType}> = when {
|
||||||
@ -385,6 +391,8 @@ import org.ejml.sparse.csc.factory.DecompositionFactory_DSCC
|
|||||||
import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC
|
import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC
|
||||||
import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC
|
import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC
|
||||||
import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC
|
import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import space.kscience.kmath.linear.*
|
import space.kscience.kmath.linear.*
|
||||||
import space.kscience.kmath.linear.Matrix
|
import space.kscience.kmath.linear.Matrix
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
|
@ -15,7 +15,7 @@ import space.kscience.kmath.operations.asIterable
|
|||||||
import space.kscience.kmath.operations.toList
|
import space.kscience.kmath.operations.toList
|
||||||
import space.kscience.kmath.optimization.FunctionOptimizationTarget
|
import space.kscience.kmath.optimization.FunctionOptimizationTarget
|
||||||
import space.kscience.kmath.optimization.optimizeWith
|
import space.kscience.kmath.optimization.optimizeWith
|
||||||
import space.kscience.kmath.optimization.resultPoint
|
import space.kscience.kmath.optimization.result
|
||||||
import space.kscience.kmath.optimization.resultValue
|
import space.kscience.kmath.optimization.resultValue
|
||||||
import space.kscience.kmath.random.RandomGenerator
|
import space.kscience.kmath.random.RandomGenerator
|
||||||
import space.kscience.kmath.real.DoubleVector
|
import space.kscience.kmath.real.DoubleVector
|
||||||
@ -98,7 +98,7 @@ suspend fun main() {
|
|||||||
scatter {
|
scatter {
|
||||||
mode = ScatterMode.lines
|
mode = ScatterMode.lines
|
||||||
x(x)
|
x(x)
|
||||||
y(x.map { result.resultPoint[a]!! * it.pow(2) + result.resultPoint[b]!! * it + 1 })
|
y(x.map { result.result[a]!! * it.pow(2) + result.result[b]!! * it + 1 })
|
||||||
name = "fit"
|
name = "fit"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -94,13 +94,13 @@ suspend fun main() {
|
|||||||
scatter {
|
scatter {
|
||||||
mode = ScatterMode.lines
|
mode = ScatterMode.lines
|
||||||
x(x)
|
x(x)
|
||||||
y(x.map { result.model(result.startPoint + result.resultPoint + (Symbol.x to it)) })
|
y(x.map { result.model(result.startPoint + result.result + (Symbol.x to it)) })
|
||||||
name = "fit"
|
name = "fit"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
br()
|
br()
|
||||||
h3 {
|
h3 {
|
||||||
+"Fit result: ${result.resultPoint}"
|
+"Fit result: ${result.result}"
|
||||||
}
|
}
|
||||||
h3 {
|
h3 {
|
||||||
+"Chi2/dof = ${result.chiSquaredOrNull!! / result.dof}"
|
+"Chi2/dof = ${result.chiSquaredOrNull!! / result.dof}"
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.ast
|
package space.kscience.kmath.ast
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.WithType
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.Expression
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.Algebra
|
||||||
@ -15,7 +17,7 @@ import space.kscience.kmath.operations.NumericAlgebra
|
|||||||
*
|
*
|
||||||
* @param T the type.
|
* @param T the type.
|
||||||
*/
|
*/
|
||||||
public sealed interface TypedMst<T> {
|
public sealed interface TypedMst<T> : WithType<T> {
|
||||||
/**
|
/**
|
||||||
* A node containing a unary operation.
|
* A node containing a unary operation.
|
||||||
*
|
*
|
||||||
@ -24,8 +26,13 @@ public sealed interface TypedMst<T> {
|
|||||||
* @property function The function implementing this operation.
|
* @property function The function implementing this operation.
|
||||||
* @property value The argument of this operation.
|
* @property value The argument of this operation.
|
||||||
*/
|
*/
|
||||||
public class Unary<T>(public val operation: String, public val function: (T) -> T, public val value: TypedMst<T>) :
|
public class Unary<T>(
|
||||||
TypedMst<T> {
|
public val operation: String,
|
||||||
|
public val function: (T) -> T,
|
||||||
|
public val value: TypedMst<T>,
|
||||||
|
) : TypedMst<T> {
|
||||||
|
override val type: SafeType<T> get() = value.type
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
if (other == null || this::class != other::class) return false
|
if (other == null || this::class != other::class) return false
|
||||||
@ -59,6 +66,13 @@ public sealed interface TypedMst<T> {
|
|||||||
public val left: TypedMst<T>,
|
public val left: TypedMst<T>,
|
||||||
public val right: TypedMst<T>,
|
public val right: TypedMst<T>,
|
||||||
) : TypedMst<T> {
|
) : TypedMst<T> {
|
||||||
|
|
||||||
|
init {
|
||||||
|
require(left.type==right.type){"Left and right expressions must be of the same type"}
|
||||||
|
}
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = left.type
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
if (other == null || this::class != other::class) return false
|
if (other == null || this::class != other::class) return false
|
||||||
@ -89,7 +103,12 @@ public sealed interface TypedMst<T> {
|
|||||||
* @property value The held value.
|
* @property value The held value.
|
||||||
* @property number The number this value corresponds.
|
* @property number The number this value corresponds.
|
||||||
*/
|
*/
|
||||||
public class Constant<T>(public val value: T, public val number: Number?) : TypedMst<T> {
|
public class Constant<T>(
|
||||||
|
override val type: SafeType<T>,
|
||||||
|
public val value: T,
|
||||||
|
public val number: Number?,
|
||||||
|
) : TypedMst<T> {
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
if (other == null || this::class != other::class) return false
|
if (other == null || this::class != other::class) return false
|
||||||
@ -114,7 +133,7 @@ public sealed interface TypedMst<T> {
|
|||||||
* @param T the type.
|
* @param T the type.
|
||||||
* @property symbol The symbol of the variable.
|
* @property symbol The symbol of the variable.
|
||||||
*/
|
*/
|
||||||
public class Variable<T>(public val symbol: Symbol) : TypedMst<T> {
|
public class Variable<T>(override val type: SafeType<T>, public val symbol: Symbol) : TypedMst<T> {
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
if (other == null || this::class != other::class) return false
|
if (other == null || this::class != other::class) return false
|
||||||
@ -167,6 +186,7 @@ public fun <T> TypedMst<T>.interpret(algebra: Algebra<T>, vararg arguments: Pair
|
|||||||
/**
|
/**
|
||||||
* Interpret this [TypedMst] node as expression.
|
* Interpret this [TypedMst] node as expression.
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> TypedMst<T>.toExpression(algebra: Algebra<T>): Expression<T> = Expression { arguments ->
|
public fun <T : Any> TypedMst<T>.toExpression(algebra: Algebra<T>): Expression<T> =
|
||||||
|
Expression(algebra.type) { arguments ->
|
||||||
interpret(algebra, arguments)
|
interpret(algebra, arguments)
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,7 @@ import space.kscience.kmath.operations.bindSymbolOrNull
|
|||||||
*/
|
*/
|
||||||
public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (this) {
|
public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (this) {
|
||||||
is MST.Numeric -> TypedMst.Constant(
|
is MST.Numeric -> TypedMst.Constant(
|
||||||
|
algebra.type,
|
||||||
(algebra as? NumericAlgebra<T>)?.number(value) ?: error("Numeric nodes are not supported by $algebra"),
|
(algebra as? NumericAlgebra<T>)?.number(value) ?: error("Numeric nodes are not supported by $algebra"),
|
||||||
value,
|
value,
|
||||||
)
|
)
|
||||||
@ -27,7 +28,7 @@ public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (t
|
|||||||
arg.value,
|
arg.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
TypedMst.Constant(value, if (value is Number) value else null)
|
TypedMst.Constant(algebra.type, value, if (value is Number) value else null)
|
||||||
}
|
}
|
||||||
|
|
||||||
else -> TypedMst.Unary(operation, algebra.unaryOperationFunction(operation), arg)
|
else -> TypedMst.Unary(operation, algebra.unaryOperationFunction(operation), arg)
|
||||||
@ -59,7 +60,7 @@ public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (t
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
TypedMst.Constant(value, if (value is Number) value else null)
|
TypedMst.Constant(algebra.type, value, if (value is Number) value else null)
|
||||||
}
|
}
|
||||||
|
|
||||||
algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> TypedMst.Binary(
|
algebra is NumericAlgebra && left is TypedMst.Constant && left.number != null -> TypedMst.Binary(
|
||||||
@ -84,8 +85,8 @@ public fun <T> MST.evaluateConstants(algebra: Algebra<T>): TypedMst<T> = when (t
|
|||||||
val boundSymbol = algebra.bindSymbolOrNull(this)
|
val boundSymbol = algebra.bindSymbolOrNull(this)
|
||||||
|
|
||||||
if (boundSymbol != null)
|
if (boundSymbol != null)
|
||||||
TypedMst.Constant(boundSymbol, if (boundSymbol is Number) boundSymbol else null)
|
TypedMst.Constant(algebra.type, boundSymbol, if (boundSymbol is Number) boundSymbol else null)
|
||||||
else
|
else
|
||||||
TypedMst.Variable(this)
|
TypedMst.Variable(algebra.type, this)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -22,7 +22,7 @@ import space.kscience.kmath.operations.Algebra
|
|||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> {
|
public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T> {
|
||||||
val typed = evaluateConstants(algebra)
|
val typed = evaluateConstants(algebra)
|
||||||
if (typed is TypedMst.Constant<T>) return Expression { typed.value }
|
if (typed is TypedMst.Constant<T>) return Expression(algebra.type) { typed.value }
|
||||||
|
|
||||||
fun ESTreeBuilder<T>.visit(node: TypedMst<T>): BaseExpression = when (node) {
|
fun ESTreeBuilder<T>.visit(node: TypedMst<T>): BaseExpression = when (node) {
|
||||||
is TypedMst.Constant -> constant(node.value)
|
is TypedMst.Constant -> constant(node.value)
|
||||||
@ -36,7 +36,7 @@ public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T>
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ESTreeBuilder { visit(typed) }.instance
|
return ESTreeBuilder(algebra.type) { visit(typed) }.instance
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -5,13 +5,22 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.estree.internal
|
package space.kscience.kmath.estree.internal
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.WithType
|
||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.Expression
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
import space.kscience.kmath.internal.astring.generate
|
import space.kscience.kmath.internal.astring.generate
|
||||||
import space.kscience.kmath.internal.estree.*
|
import space.kscience.kmath.internal.estree.*
|
||||||
|
|
||||||
internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExpression) {
|
internal class ESTreeBuilder<T>(
|
||||||
private class GeneratedExpression<T>(val executable: dynamic, val constants: Array<dynamic>) : Expression<T> {
|
override val type: SafeType<T>,
|
||||||
|
val bodyCallback: ESTreeBuilder<T>.() -> BaseExpression,
|
||||||
|
) : WithType<T> {
|
||||||
|
private class GeneratedExpression<T>(
|
||||||
|
override val type: SafeType<T>,
|
||||||
|
val executable: dynamic,
|
||||||
|
val constants: Array<dynamic>,
|
||||||
|
) : Expression<T> {
|
||||||
@Suppress("UNUSED_VARIABLE")
|
@Suppress("UNUSED_VARIABLE")
|
||||||
override fun invoke(arguments: Map<Symbol, T>): T {
|
override fun invoke(arguments: Map<Symbol, T>): T {
|
||||||
val e = executable
|
val e = executable
|
||||||
@ -30,7 +39,7 @@ internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExp
|
|||||||
)
|
)
|
||||||
|
|
||||||
val code = generate(node)
|
val code = generate(node)
|
||||||
GeneratedExpression(js("new Function('constants', 'arguments_0', code)"), constants.toTypedArray())
|
GeneratedExpression(type, js("new Function('constants', 'arguments_0', code)"), constants.toTypedArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
private val constants = mutableListOf<Any>()
|
private val constants = mutableListOf<Any>()
|
||||||
|
@ -29,7 +29,7 @@ import space.kscience.kmath.operations.Int64Ring
|
|||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
||||||
val typed = evaluateConstants(algebra)
|
val typed = evaluateConstants(algebra)
|
||||||
if (typed is TypedMst.Constant<T>) return Expression { typed.value }
|
if (typed is TypedMst.Constant<T>) return Expression(algebra.type) { typed.value }
|
||||||
|
|
||||||
fun GenericAsmBuilder<T>.variablesVisitor(node: TypedMst<T>): Unit = when (node) {
|
fun GenericAsmBuilder<T>.variablesVisitor(node: TypedMst<T>): Unit = when (node) {
|
||||||
is TypedMst.Unary -> variablesVisitor(node.value)
|
is TypedMst.Unary -> variablesVisitor(node.value)
|
||||||
|
@ -8,10 +8,13 @@
|
|||||||
package space.kscience.kmath.commons.expressions
|
package space.kscience.kmath.commons.expressions
|
||||||
|
|
||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.ExtendedField
|
import space.kscience.kmath.operations.ExtendedField
|
||||||
import space.kscience.kmath.operations.NumbersAddOps
|
import space.kscience.kmath.operations.NumbersAddOps
|
||||||
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field over commons-math [DerivativeStructure].
|
* A field over commons-math [DerivativeStructure].
|
||||||
@ -26,6 +29,9 @@ public class CmDsField(
|
|||||||
bindings: Map<Symbol, Double>,
|
bindings: Map<Symbol, Double>,
|
||||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
|
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
|
||||||
NumbersAddOps<DerivativeStructure> {
|
NumbersAddOps<DerivativeStructure> {
|
||||||
|
|
||||||
|
override val bufferFactory: MutableBufferFactory<DerivativeStructure> = MutableBufferFactory()
|
||||||
|
|
||||||
public val numberOfVariables: Int = bindings.size
|
public val numberOfVariables: Int = bindings.size
|
||||||
|
|
||||||
override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
|
override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
|
||||||
@ -77,7 +83,9 @@ public class CmDsField(
|
|||||||
|
|
||||||
override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value)
|
override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value)
|
||||||
|
|
||||||
override fun multiply(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.multiply(right)
|
override fun multiply(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure =
|
||||||
|
left.multiply(right)
|
||||||
|
|
||||||
override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.divide(right)
|
override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.divide(right)
|
||||||
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
|
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
|
||||||
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
|
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
|
||||||
@ -125,13 +133,16 @@ public object CmDsProcessor : AutoDiffProcessor<Double, DerivativeStructure, CmD
|
|||||||
public class CmDsExpression(
|
public class CmDsExpression(
|
||||||
public val function: CmDsField.() -> DerivativeStructure,
|
public val function: CmDsField.() -> DerivativeStructure,
|
||||||
) : DifferentiableExpression<Double> {
|
) : DifferentiableExpression<Double> {
|
||||||
|
|
||||||
|
override val type: SafeType<Double> get() = DoubleField.type
|
||||||
|
|
||||||
override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
override operator fun invoke(arguments: Map<Symbol, Double>): Double =
|
||||||
CmDsField(0, arguments).function().value
|
CmDsField(0, arguments).function().value
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the derivative expression with given orders
|
* Get the derivative expression with given orders
|
||||||
*/
|
*/
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments ->
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression(type) { arguments ->
|
||||||
with(CmDsField(symbols.size, arguments)) { function().derivative(symbols) }
|
with(CmDsField(symbols.size, arguments)) { function().derivative(symbols) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,7 @@ public class CMGaussRuleIntegrator(
|
|||||||
private var type: GaussRule = GaussRule.LEGENDRE,
|
private var type: GaussRule = GaussRule.LEGENDRE,
|
||||||
) : UnivariateIntegrator<Double> {
|
) : UnivariateIntegrator<Double> {
|
||||||
|
|
||||||
override fun process(integrand: UnivariateIntegrand<Double>): UnivariateIntegrand<Double> {
|
override fun integrate(integrand: UnivariateIntegrand<Double>): UnivariateIntegrand<Double> {
|
||||||
val range = integrand[IntegrationRange]
|
val range = integrand[IntegrationRange]
|
||||||
?: error("Integration range is not provided")
|
?: error("Integration range is not provided")
|
||||||
val integrator: GaussIntegrator = getIntegrator(range)
|
val integrator: GaussIntegrator = getIntegrator(range)
|
||||||
@ -79,7 +79,7 @@ public class CMGaussRuleIntegrator(
|
|||||||
numPoints: Int = 100,
|
numPoints: Int = 100,
|
||||||
type: GaussRule = GaussRule.LEGENDRE,
|
type: GaussRule = GaussRule.LEGENDRE,
|
||||||
function: (Double) -> Double,
|
function: (Double) -> Double,
|
||||||
): Double = CMGaussRuleIntegrator(numPoints, type).process(
|
): Double = CMGaussRuleIntegrator(numPoints, type).integrate(
|
||||||
UnivariateIntegrand({IntegrationRange(range)},function)
|
UnivariateIntegrand({IntegrationRange(range)},function)
|
||||||
).value
|
).value
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ package space.kscience.kmath.commons.integration
|
|||||||
|
|
||||||
import org.apache.commons.math3.analysis.integration.IterativeLegendreGaussIntegrator
|
import org.apache.commons.math3.analysis.integration.IterativeLegendreGaussIntegrator
|
||||||
import org.apache.commons.math3.analysis.integration.SimpsonIntegrator
|
import org.apache.commons.math3.analysis.integration.SimpsonIntegrator
|
||||||
|
import space.kscience.attributes.Attributes
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.integration.*
|
import space.kscience.kmath.integration.*
|
||||||
import org.apache.commons.math3.analysis.integration.UnivariateIntegrator as CMUnivariateIntegrator
|
import org.apache.commons.math3.analysis.integration.UnivariateIntegrator as CMUnivariateIntegrator
|
||||||
@ -19,7 +20,7 @@ public class CMIntegrator(
|
|||||||
public val integratorBuilder: (Integrand<Double>) -> CMUnivariateIntegrator,
|
public val integratorBuilder: (Integrand<Double>) -> CMUnivariateIntegrator,
|
||||||
) : UnivariateIntegrator<Double> {
|
) : UnivariateIntegrator<Double> {
|
||||||
|
|
||||||
override fun process(integrand: UnivariateIntegrand<Double>): UnivariateIntegrand<Double> {
|
override fun integrate(integrand: UnivariateIntegrand<Double>): UnivariateIntegrand<Double> {
|
||||||
val integrator = integratorBuilder(integrand)
|
val integrator = integratorBuilder(integrand)
|
||||||
val maxCalls = integrand[IntegrandMaxCalls] ?: defaultMaxCalls
|
val maxCalls = integrand[IntegrandMaxCalls] ?: defaultMaxCalls
|
||||||
val remainingCalls = maxCalls - integrand.calls
|
val remainingCalls = maxCalls - integrand.calls
|
||||||
@ -73,15 +74,9 @@ public class CMIntegrator(
|
|||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public var MutableList<IntegrandFeature>.targetAbsoluteAccuracy: Double?
|
public val Attributes.targetAbsoluteAccuracy: Double?
|
||||||
get() = filterIsInstance<IntegrandAbsoluteAccuracy>().lastOrNull()?.accuracy
|
get() = get(IntegrandAbsoluteAccuracy)
|
||||||
set(value) {
|
|
||||||
value?.let { add(IntegrandAbsoluteAccuracy(value)) }
|
|
||||||
}
|
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public var MutableList<IntegrandFeature>.targetRelativeAccuracy: Double?
|
public val Attributes.targetRelativeAccuracy: Double?
|
||||||
get() = filterIsInstance<IntegrandRelativeAccuracy>().lastOrNull()?.accuracy
|
get() = get(IntegrandRelativeAccuracy)
|
||||||
set(value) {
|
|
||||||
value?.let { add(IntegrandRelativeAccuracy(value)) }
|
|
||||||
}
|
|
||||||
|
@ -6,18 +6,20 @@
|
|||||||
package space.kscience.kmath.commons.linear
|
package space.kscience.kmath.commons.linear
|
||||||
|
|
||||||
import org.apache.commons.math3.linear.*
|
import org.apache.commons.math3.linear.*
|
||||||
|
import org.apache.commons.math3.linear.LUDecomposition
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.linear.*
|
import space.kscience.kmath.linear.*
|
||||||
|
import space.kscience.kmath.nd.Structure2D
|
||||||
import space.kscience.kmath.nd.StructureAttribute
|
import space.kscience.kmath.nd.StructureAttribute
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.Float64Field
|
import space.kscience.kmath.operations.Float64Field
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.Float64Buffer
|
import space.kscience.kmath.structures.Float64Buffer
|
||||||
import kotlin.reflect.KClass
|
|
||||||
import kotlin.reflect.KType
|
|
||||||
import kotlin.reflect.cast
|
import kotlin.reflect.cast
|
||||||
import kotlin.reflect.typeOf
|
|
||||||
|
|
||||||
public class CMMatrix(public val origin: RealMatrix) : Matrix<Double> {
|
public class CMMatrix(public val origin: RealMatrix) : Matrix<Double> {
|
||||||
|
override val type: SafeType<Double> get() = DoubleField.type
|
||||||
override val rowNum: Int get() = origin.rowDimension
|
override val rowNum: Int get() = origin.rowDimension
|
||||||
override val colNum: Int get() = origin.columnDimension
|
override val colNum: Int get() = origin.columnDimension
|
||||||
|
|
||||||
@ -26,6 +28,7 @@ public class CMMatrix(public val origin: RealMatrix) : Matrix<Double> {
|
|||||||
|
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class CMVector(public val origin: RealVector) : Point<Double> {
|
public value class CMVector(public val origin: RealVector) : Point<Double> {
|
||||||
|
override val type: SafeType<Double> get() = DoubleField.type
|
||||||
override val size: Int get() = origin.dimension
|
override val size: Int get() = origin.dimension
|
||||||
|
|
||||||
override operator fun get(index: Int): Double = origin.getEntry(index)
|
override operator fun get(index: Int): Double = origin.getEntry(index)
|
||||||
@ -40,7 +43,7 @@ public fun RealVector.toPoint(): CMVector = CMVector(this)
|
|||||||
public object CMLinearSpace : LinearSpace<Double, Float64Field> {
|
public object CMLinearSpace : LinearSpace<Double, Float64Field> {
|
||||||
override val elementAlgebra: Float64Field get() = Float64Field
|
override val elementAlgebra: Float64Field get() = Float64Field
|
||||||
|
|
||||||
override val elementType: KType = typeOf<Double>()
|
override val type: SafeType<Double> get() = DoubleField.type
|
||||||
|
|
||||||
override fun buildMatrix(
|
override fun buildMatrix(
|
||||||
rows: Int,
|
rows: Int,
|
||||||
@ -102,19 +105,14 @@ public object CMLinearSpace : LinearSpace<Double, Float64Field> {
|
|||||||
override fun Double.times(v: Point<Double>): CMVector =
|
override fun Double.times(v: Point<Double>): CMVector =
|
||||||
v * this
|
v * this
|
||||||
|
|
||||||
@UnstableKMathAPI
|
override fun <V, A : StructureAttribute<V>> computeAttribute(structure: Structure2D<Double>, attribute: A): V? {
|
||||||
override fun <F : StructureAttribute> computeFeature(structure: Matrix<Double>, type: KClass<out F>): F? {
|
|
||||||
//Return the feature if it is intrinsic to the structure
|
|
||||||
structure.getFeature(type)?.let { return it }
|
|
||||||
|
|
||||||
val origin = structure.toCM().origin
|
val origin = structure.toCM().origin
|
||||||
|
|
||||||
return when (type) {
|
return when (attribute) {
|
||||||
IsDiagonal::class -> if (origin is DiagonalMatrix) IsDiagonal else null
|
IsDiagonal -> if (origin is DiagonalMatrix) IsDiagonal else null
|
||||||
|
Determinant -> LUDecomposition(origin).determinant
|
||||||
Determinant::class, LupDecompositionAttribute::class -> object :
|
LUP -> GenericLupDecomposition {
|
||||||
Determinant<Double>,
|
|
||||||
LupDecompositionAttribute<Double> {
|
|
||||||
private val lup by lazy { LUDecomposition(origin) }
|
private val lup by lazy { LUDecomposition(origin) }
|
||||||
override val determinant: Double by lazy { lup.determinant }
|
override val determinant: Double by lazy { lup.determinant }
|
||||||
override val l: Matrix<Double> by lazy<Matrix<Double>> { CMMatrix(lup.l).withAttribute(LowerTriangular) }
|
override val l: Matrix<Double> by lazy<Matrix<Double>> { CMMatrix(lup.l).withAttribute(LowerTriangular) }
|
||||||
@ -122,20 +120,24 @@ public object CMLinearSpace : LinearSpace<Double, Float64Field> {
|
|||||||
override val p: Matrix<Double> by lazy { CMMatrix(lup.p) }
|
override val p: Matrix<Double> by lazy { CMMatrix(lup.p) }
|
||||||
}
|
}
|
||||||
|
|
||||||
CholeskyDecompositionAttribute::class -> object : CholeskyDecompositionAttribute<Double> {
|
CholeskyDecompositionAttribute -> object : CholeskyDecompositionAttribute<Double> {
|
||||||
override val l: Matrix<Double> by lazy<Matrix<Double>> {
|
override val l: Matrix<Double> by lazy<Matrix<Double>> {
|
||||||
val cholesky = CholeskyDecomposition(origin)
|
val cholesky = CholeskyDecomposition(origin)
|
||||||
CMMatrix(cholesky.l).withAttribute(LowerTriangular)
|
CMMatrix(cholesky.l).withAttribute(LowerTriangular)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
QRDecompositionAttribute::class -> object : QRDecompositionAttribute<Double> {
|
QRDecompositionAttribute -> object : QRDecompositionAttribute<Double> {
|
||||||
private val qr by lazy { QRDecomposition(origin) }
|
private val qr by lazy { QRDecomposition(origin) }
|
||||||
override val q: Matrix<Double> by lazy<Matrix<Double>> { CMMatrix(qr.q).withAttribute(OrthogonalAttribute) }
|
override val q: Matrix<Double> by lazy<Matrix<Double>> {
|
||||||
|
CMMatrix(qr.q).withAttribute(
|
||||||
|
OrthogonalAttribute
|
||||||
|
)
|
||||||
|
}
|
||||||
override val r: Matrix<Double> by lazy<Matrix<Double>> { CMMatrix(qr.r).withAttribute(UpperTriangular) }
|
override val r: Matrix<Double> by lazy<Matrix<Double>> { CMMatrix(qr.r).withAttribute(UpperTriangular) }
|
||||||
}
|
}
|
||||||
|
|
||||||
SVDAttribute::class -> object : SVDAttribute<Double> {
|
SVDAttribute -> object : SVDAttribute<Double> {
|
||||||
private val sv by lazy { SingularValueDecomposition(origin) }
|
private val sv by lazy { SingularValueDecomposition(origin) }
|
||||||
override val u: Matrix<Double> by lazy { CMMatrix(sv.u) }
|
override val u: Matrix<Double> by lazy { CMMatrix(sv.u) }
|
||||||
override val s: Matrix<Double> by lazy { CMMatrix(sv.s) }
|
override val s: Matrix<Double> by lazy { CMMatrix(sv.s) }
|
||||||
|
@ -13,6 +13,8 @@ import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient
|
|||||||
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
|
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex
|
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex
|
||||||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer
|
||||||
|
import space.kscience.attributes.AttributesBuilder
|
||||||
|
import space.kscience.attributes.SetAttribute
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
import space.kscience.kmath.expressions.SymbolIndexer
|
import space.kscience.kmath.expressions.SymbolIndexer
|
||||||
@ -26,34 +28,25 @@ import kotlin.reflect.KClass
|
|||||||
public operator fun PointValuePair.component1(): DoubleArray = point
|
public operator fun PointValuePair.component1(): DoubleArray = point
|
||||||
public operator fun PointValuePair.component2(): Double = value
|
public operator fun PointValuePair.component2(): Double = value
|
||||||
|
|
||||||
public class CMOptimizerEngine(public val optimizerBuilder: () -> MultivariateOptimizer) : OptimizationFeature {
|
public object CMOptimizerEngine: OptimizationAttribute<() -> MultivariateOptimizer>
|
||||||
override fun toString(): String = "CMOptimizer($optimizerBuilder)"
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specify a Commons-maths optimization engine
|
* Specify a Commons-maths optimization engine
|
||||||
*/
|
*/
|
||||||
public fun FunctionOptimizationBuilder<Double>.cmEngine(optimizerBuilder: () -> MultivariateOptimizer) {
|
public fun AttributesBuilder<FunctionOptimization<Double>>.cmEngine(optimizerBuilder: () -> MultivariateOptimizer) {
|
||||||
addFeature(CMOptimizerEngine(optimizerBuilder))
|
set(CMOptimizerEngine, optimizerBuilder)
|
||||||
}
|
}
|
||||||
|
|
||||||
public class CMOptimizerData(public val data: List<SymbolIndexer.() -> OptimizationData>) : OptimizationFeature {
|
public object CMOptimizerData: SetAttribute<SymbolIndexer.() -> OptimizationData>
|
||||||
public constructor(vararg data: (SymbolIndexer.() -> OptimizationData)) : this(data.toList())
|
|
||||||
|
|
||||||
override fun toString(): String = "CMOptimizerData($data)"
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specify Commons-maths optimization data.
|
* Specify Commons-maths optimization data.
|
||||||
*/
|
*/
|
||||||
public fun FunctionOptimizationBuilder<Double>.cmOptimizationData(data: SymbolIndexer.() -> OptimizationData) {
|
public fun AttributesBuilder<FunctionOptimization<Double>>.cmOptimizationData(data: SymbolIndexer.() -> OptimizationData) {
|
||||||
updateFeature<CMOptimizerData> {
|
CMOptimizerData.add(data)
|
||||||
val newData = (it?.data ?: emptyList()) + data
|
|
||||||
CMOptimizerData(newData)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun FunctionOptimizationBuilder<Double>.simplexSteps(vararg steps: Pair<Symbol, Double>) {
|
public fun AttributesBuilder<FunctionOptimization<Double>>.simplexSteps(vararg steps: Pair<Symbol, Double>) {
|
||||||
//TODO use convergence checker from features
|
//TODO use convergence checker from features
|
||||||
cmEngine { SimplexOptimizer(CMOptimizer.defaultConvergenceChecker) }
|
cmEngine { SimplexOptimizer(CMOptimizer.defaultConvergenceChecker) }
|
||||||
cmOptimizationData { NelderMeadSimplex(mapOf(*steps).toDoubleArray()) }
|
cmOptimizationData { NelderMeadSimplex(mapOf(*steps).toDoubleArray()) }
|
||||||
@ -78,8 +71,8 @@ public object CMOptimizer : Optimizer<Double, FunctionOptimization<Double>> {
|
|||||||
): FunctionOptimization<Double> {
|
): FunctionOptimization<Double> {
|
||||||
val startPoint = problem.startPoint
|
val startPoint = problem.startPoint
|
||||||
|
|
||||||
val parameters = problem.getFeature<OptimizationParameters>()?.symbols
|
val parameters = problem.attributes[OptimizationParameters]
|
||||||
?: problem.getFeature<OptimizationStartPoint<Double>>()?.point?.keys
|
?: problem.attributes[OptimizationStartPoint<Double>()]?.keys
|
||||||
?: startPoint.keys
|
?: startPoint.keys
|
||||||
|
|
||||||
|
|
||||||
@ -90,7 +83,7 @@ public object CMOptimizer : Optimizer<Double, FunctionOptimization<Double>> {
|
|||||||
DEFAULT_MAX_ITER
|
DEFAULT_MAX_ITER
|
||||||
)
|
)
|
||||||
|
|
||||||
val cmOptimizer: MultivariateOptimizer = problem.getFeature<CMOptimizerEngine>()?.optimizerBuilder?.invoke()
|
val cmOptimizer: MultivariateOptimizer = problem.attributes[CMOptimizerEngine]?.invoke()
|
||||||
?: NonLinearConjugateGradientOptimizer(
|
?: NonLinearConjugateGradientOptimizer(
|
||||||
NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
|
NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES,
|
||||||
convergenceChecker
|
convergenceChecker
|
||||||
@ -123,7 +116,7 @@ public object CMOptimizer : Optimizer<Double, FunctionOptimization<Double>> {
|
|||||||
}
|
}
|
||||||
addOptimizationData(gradientFunction)
|
addOptimizationData(gradientFunction)
|
||||||
|
|
||||||
val logger = problem.getFeature<OptimizationLog>()
|
val logger = problem.attributes[OptimizationLog]
|
||||||
|
|
||||||
for (feature in problem.attributes) {
|
for (feature in problem.attributes) {
|
||||||
when (feature) {
|
when (feature) {
|
||||||
@ -139,7 +132,7 @@ public object CMOptimizer : Optimizer<Double, FunctionOptimization<Double>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
val (point, value) = cmOptimizer.optimize(*optimizationData.values.toTypedArray())
|
val (point, value) = cmOptimizer.optimize(*optimizationData.values.toTypedArray())
|
||||||
return problem.withFeatures(OptimizationResult(point.toMap()), OptimizationValue(value))
|
return problem.withAttributes(OptimizationResult(point.toMap()), OptimizationValue(value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,7 @@ internal class OptimizeTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testGradientOptimization() = runBlocking {
|
fun testGradientOptimization() = runBlocking {
|
||||||
val result = normal.optimizeWith(CMOptimizer, x to 1.0, y to 1.0)
|
val result = normal.optimizeWith(CMOptimizer, x to 1.0, y to 1.0)
|
||||||
println(result.resultPoint)
|
println(result.result)
|
||||||
println(result.resultValue)
|
println(result.resultValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ internal class OptimizeTest {
|
|||||||
//this sets simplex optimizer
|
//this sets simplex optimizer
|
||||||
}
|
}
|
||||||
|
|
||||||
println(result.resultPoint)
|
println(result.result)
|
||||||
println(result.resultValue)
|
println(result.resultValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.expressions
|
package space.kscience.kmath.expressions
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
@ -331,10 +332,13 @@ public class DerivativeStructureRingExpression<T, A>(
|
|||||||
public val elementBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory,
|
public val elementBufferFactory: MutableBufferFactory<T> = algebra.bufferFactory,
|
||||||
public val function: DSRing<T, A>.() -> DS<T, A>,
|
public val function: DSRing<T, A>.() -> DS<T, A>,
|
||||||
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
|
) : DifferentiableExpression<T> where A : Ring<T>, A : ScaleOperations<T>, A : NumericAlgebra<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = elementBufferFactory.type
|
||||||
|
|
||||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||||
DSRing(algebra, 0, arguments).function().value
|
DSRing(algebra, 0, arguments).function().value
|
||||||
|
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression(type) { arguments ->
|
||||||
with(
|
with(
|
||||||
DSRing(
|
DSRing(
|
||||||
algebra,
|
algebra,
|
||||||
@ -443,10 +447,13 @@ public class DSFieldExpression<T, A : ExtendedField<T>>(
|
|||||||
public val algebra: A,
|
public val algebra: A,
|
||||||
public val function: DSField<T, A>.() -> DS<T, A>,
|
public val function: DSField<T, A>.() -> DS<T, A>,
|
||||||
) : DifferentiableExpression<T> {
|
) : DifferentiableExpression<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = algebra.type
|
||||||
|
|
||||||
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
override operator fun invoke(arguments: Map<Symbol, T>): T =
|
||||||
DSField(algebra, 0, arguments).function().value
|
DSField(algebra, 0, arguments).function().value
|
||||||
|
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression { arguments ->
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T> = Expression(type) { arguments ->
|
||||||
DSField(
|
DSField(
|
||||||
algebra,
|
algebra,
|
||||||
symbols.size,
|
symbols.size,
|
||||||
|
@ -5,8 +5,13 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.expressions
|
package space.kscience.kmath.expressions
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.WithType
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.Algebra
|
import space.kscience.kmath.operations.Algebra
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.IntRing
|
||||||
|
import space.kscience.kmath.operations.LongRing
|
||||||
import kotlin.jvm.JvmName
|
import kotlin.jvm.JvmName
|
||||||
import kotlin.properties.ReadOnlyProperty
|
import kotlin.properties.ReadOnlyProperty
|
||||||
|
|
||||||
@ -15,7 +20,7 @@ import kotlin.properties.ReadOnlyProperty
|
|||||||
*
|
*
|
||||||
* @param T the type this expression takes as argument and returns.
|
* @param T the type this expression takes as argument and returns.
|
||||||
*/
|
*/
|
||||||
public fun interface Expression<T> {
|
public interface Expression<T> : WithType<T> {
|
||||||
/**
|
/**
|
||||||
* Calls this expression from arguments.
|
* Calls this expression from arguments.
|
||||||
*
|
*
|
||||||
@ -25,11 +30,20 @@ public fun interface Expression<T> {
|
|||||||
public operator fun invoke(arguments: Map<Symbol, T>): T
|
public operator fun invoke(arguments: Map<Symbol, T>): T
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun <T> Expression(type: SafeType<T>, block: (Map<Symbol, T>) -> T): Expression<T> = object : Expression<T> {
|
||||||
|
override fun invoke(arguments: Map<Symbol, T>): T = block(arguments)
|
||||||
|
|
||||||
|
override val type: SafeType<T> = type
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specialization of [Expression] for [Double] allowing better performance because of using array.
|
* Specialization of [Expression] for [Double] allowing better performance because of using array.
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public interface DoubleExpression : Expression<Double> {
|
public interface DoubleExpression : Expression<Double> {
|
||||||
|
|
||||||
|
override val type: SafeType<Double> get() = DoubleField.type
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The indexer of this expression's arguments that should be used to build array for [invoke].
|
* The indexer of this expression's arguments that should be used to build array for [invoke].
|
||||||
*
|
*
|
||||||
@ -49,7 +63,7 @@ public interface DoubleExpression : Expression<Double> {
|
|||||||
*/
|
*/
|
||||||
public operator fun invoke(arguments: DoubleArray): Double
|
public operator fun invoke(arguments: DoubleArray): Double
|
||||||
|
|
||||||
public companion object{
|
public companion object {
|
||||||
internal val EMPTY_DOUBLE_ARRAY = DoubleArray(0)
|
internal val EMPTY_DOUBLE_ARRAY = DoubleArray(0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -59,6 +73,9 @@ public interface DoubleExpression : Expression<Double> {
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public interface IntExpression : Expression<Int> {
|
public interface IntExpression : Expression<Int> {
|
||||||
|
|
||||||
|
override val type: SafeType<Int> get() = IntRing.type
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The indexer of this expression's arguments that should be used to build array for [invoke].
|
* The indexer of this expression's arguments that should be used to build array for [invoke].
|
||||||
*
|
*
|
||||||
@ -78,7 +95,7 @@ public interface IntExpression : Expression<Int> {
|
|||||||
*/
|
*/
|
||||||
public operator fun invoke(arguments: IntArray): Int
|
public operator fun invoke(arguments: IntArray): Int
|
||||||
|
|
||||||
public companion object{
|
public companion object {
|
||||||
internal val EMPTY_INT_ARRAY = IntArray(0)
|
internal val EMPTY_INT_ARRAY = IntArray(0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -88,6 +105,9 @@ public interface IntExpression : Expression<Int> {
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public interface LongExpression : Expression<Long> {
|
public interface LongExpression : Expression<Long> {
|
||||||
|
|
||||||
|
override val type: SafeType<Long> get() = LongRing.type
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The indexer of this expression's arguments that should be used to build array for [invoke].
|
* The indexer of this expression's arguments that should be used to build array for [invoke].
|
||||||
*
|
*
|
||||||
@ -107,7 +127,7 @@ public interface LongExpression : Expression<Long> {
|
|||||||
*/
|
*/
|
||||||
public operator fun invoke(arguments: LongArray): Long
|
public operator fun invoke(arguments: LongArray): Long
|
||||||
|
|
||||||
public companion object{
|
public companion object {
|
||||||
internal val EMPTY_LONG_ARRAY = LongArray(0)
|
internal val EMPTY_LONG_ARRAY = LongArray(0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -158,7 +178,6 @@ public operator fun <T> Expression<T>.invoke(vararg pairs: Pair<String, T>): T =
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls this expression without providing any arguments.
|
* Calls this expression without providing any arguments.
|
||||||
*
|
*
|
||||||
|
@ -5,10 +5,15 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.expressions
|
package space.kscience.kmath.expressions
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
|
||||||
public class ExpressionWithDefault<T>(
|
public class ExpressionWithDefault<T>(
|
||||||
private val origin: Expression<T>,
|
private val origin: Expression<T>,
|
||||||
private val defaultArgs: Map<Symbol, T>,
|
private val defaultArgs: Map<Symbol, T>,
|
||||||
) : Expression<T> {
|
) : Expression<T> {
|
||||||
|
override val type: SafeType<T>
|
||||||
|
get() = origin.type
|
||||||
|
|
||||||
override fun invoke(arguments: Map<Symbol, T>): T = origin.invoke(defaultArgs + arguments)
|
override fun invoke(arguments: Map<Symbol, T>): T = origin.invoke(defaultArgs + arguments)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,6 +26,9 @@ public class DiffExpressionWithDefault<T>(
|
|||||||
private val defaultArgs: Map<Symbol, T>,
|
private val defaultArgs: Map<Symbol, T>,
|
||||||
) : DifferentiableExpression<T> {
|
) : DifferentiableExpression<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T>
|
||||||
|
get() = origin.type
|
||||||
|
|
||||||
override fun invoke(arguments: Map<Symbol, T>): T = origin.invoke(defaultArgs + arguments)
|
override fun invoke(arguments: Map<Symbol, T>): T = origin.invoke(defaultArgs + arguments)
|
||||||
|
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? =
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<T>? =
|
||||||
|
@ -23,26 +23,27 @@ public abstract class FunctionalExpressionAlgebra<T, out A : Algebra<T>>(
|
|||||||
/**
|
/**
|
||||||
* Builds an Expression of constant expression that does not depend on arguments.
|
* Builds an Expression of constant expression that does not depend on arguments.
|
||||||
*/
|
*/
|
||||||
override fun const(value: T): Expression<T> = Expression { value }
|
override fun const(value: T): Expression<T> = Expression(algebra.type) { value }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression to access a variable.
|
* Builds an Expression to access a variable.
|
||||||
*/
|
*/
|
||||||
override fun bindSymbolOrNull(value: String): Expression<T>? = Expression { arguments ->
|
override fun bindSymbolOrNull(value: String): Expression<T>? = Expression(algebra.type) { arguments ->
|
||||||
algebra.bindSymbolOrNull(value)
|
algebra.bindSymbolOrNull(value)
|
||||||
?: arguments[StringSymbol(value)]
|
?: arguments[StringSymbol(value)]
|
||||||
?: error("Symbol '$value' is not supported in $this")
|
?: error("Symbol '$value' is not supported in $this")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: Expression<T>, right: Expression<T>) -> Expression<T> =
|
override fun binaryOperationFunction(
|
||||||
{ left, right ->
|
operation: String,
|
||||||
Expression { arguments ->
|
): (left: Expression<T>, right: Expression<T>) -> Expression<T> = { left, right ->
|
||||||
|
Expression(algebra.type) { arguments ->
|
||||||
algebra.binaryOperationFunction(operation)(left(arguments), right(arguments))
|
algebra.binaryOperationFunction(operation)(left(arguments), right(arguments))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
|
override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> = { arg ->
|
||||||
Expression { arguments -> algebra.unaryOperation(operation, arg(arguments)) }
|
Expression(algebra.type) { arguments -> algebra.unaryOperation(operation, arg(arguments)) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -124,7 +125,7 @@ public open class FunctionalExpressionField<T, out A : Field<T>>(
|
|||||||
super<FunctionalExpressionRing>.binaryOperationFunction(operation)
|
super<FunctionalExpressionRing>.binaryOperationFunction(operation)
|
||||||
|
|
||||||
override fun scale(a: Expression<T>, value: Double): Expression<T> = algebra {
|
override fun scale(a: Expression<T>, value: Double): Expression<T> = algebra {
|
||||||
Expression { args -> a(args) * value }
|
Expression(algebra.type) { args -> a(args) * value }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun bindSymbolOrNull(value: String): Expression<T>? =
|
override fun bindSymbolOrNull(value: String): Expression<T>? =
|
||||||
|
@ -108,4 +108,4 @@ public fun <T> MST.interpret(algebra: Algebra<T>, vararg arguments: Pair<Symbol,
|
|||||||
* Interpret this [MST] as expression.
|
* Interpret this [MST] as expression.
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> =
|
public fun <T : Any> MST.toExpression(algebra: Algebra<T>): Expression<T> =
|
||||||
Expression { arguments -> interpret(algebra, arguments) }
|
Expression(algebra.type) { arguments -> interpret(algebra, arguments) }
|
||||||
|
@ -243,12 +243,15 @@ public class SimpleAutoDiffExpression<T : Any, F : Field<T>>(
|
|||||||
public val field: F,
|
public val field: F,
|
||||||
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
public val function: SimpleAutoDiffField<T, F>.() -> AutoDiffValue<T>,
|
||||||
) : FirstDerivativeExpression<T>() {
|
) : FirstDerivativeExpression<T>() {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = this.field.type
|
||||||
|
|
||||||
override operator fun invoke(arguments: Map<Symbol, T>): T {
|
override operator fun invoke(arguments: Map<Symbol, T>): T {
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
return SimpleAutoDiffField(field, arguments).function().value
|
return SimpleAutoDiffField(field, arguments).function().value
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression { arguments ->
|
override fun derivativeOrNull(symbol: Symbol): Expression<T> = Expression(type) { arguments ->
|
||||||
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
//val bindings = arguments.entries.map { it.key.bind(it.value) }
|
||||||
val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function)
|
val derivationResult = SimpleAutoDiffField(field, arguments).differentiate(function)
|
||||||
derivationResult.derivative(symbol)
|
derivationResult.derivative(symbol)
|
||||||
|
@ -11,7 +11,7 @@ package space.kscience.kmath.linear
|
|||||||
*
|
*
|
||||||
* @param T the type of items.
|
* @param T the type of items.
|
||||||
*/
|
*/
|
||||||
public interface LinearSolver<T : Any> {
|
public interface LinearSolver<T> {
|
||||||
/**
|
/**
|
||||||
* Solve a dot x = b matrix equation and return x
|
* Solve a dot x = b matrix equation and return x
|
||||||
*/
|
*/
|
||||||
|
@ -5,10 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.linear
|
package space.kscience.kmath.linear
|
||||||
|
|
||||||
import space.kscience.attributes.Attributes
|
import space.kscience.attributes.*
|
||||||
import space.kscience.attributes.SafeType
|
|
||||||
import space.kscience.attributes.WithType
|
|
||||||
import space.kscience.attributes.withAttribute
|
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.nd.*
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.operations.BufferRingOps
|
import space.kscience.kmath.operations.BufferRingOps
|
||||||
@ -31,19 +28,13 @@ public typealias MutableMatrix<T> = MutableStructure2D<T>
|
|||||||
*/
|
*/
|
||||||
public typealias Point<T> = Buffer<T>
|
public typealias Point<T> = Buffer<T>
|
||||||
|
|
||||||
/**
|
|
||||||
* A marker interface for algebras that operate on matrices
|
|
||||||
* @param T type of matrix element
|
|
||||||
*/
|
|
||||||
public interface MatrixOperations<T> : WithType<T>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Basic operations on matrices and vectors.
|
* Basic operations on matrices and vectors.
|
||||||
*
|
*
|
||||||
* @param T the type of items in the matrices.
|
* @param T the type of items in the matrices.
|
||||||
* @param A the type of ring over [T].
|
* @param A the type of ring over [T].
|
||||||
*/
|
*/
|
||||||
public interface LinearSpace<T, out A : Ring<T>> : MatrixOperations<T> {
|
public interface LinearSpace<T, out A : Ring<T>> : MatrixScope<T> {
|
||||||
public val elementAlgebra: A
|
public val elementAlgebra: A
|
||||||
|
|
||||||
override val type: SafeType<T> get() = elementAlgebra.type
|
override val type: SafeType<T> get() = elementAlgebra.type
|
||||||
@ -177,10 +168,10 @@ public interface LinearSpace<T, out A : Ring<T>> : MatrixOperations<T> {
|
|||||||
/**
|
/**
|
||||||
* Compute an [attribute] value for given [structure]. Return null if the attribute could not be computed.
|
* Compute an [attribute] value for given [structure]. Return null if the attribute could not be computed.
|
||||||
*/
|
*/
|
||||||
public fun <V, A : StructureAttribute<V>> computeAttribute(structure: StructureND<*>, attribute: A): V? = null
|
public fun <V, A : StructureAttribute<V>> computeAttribute(structure: Structure2D<T>, attribute: A): V? = null
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun <V, A : StructureAttribute<V>> StructureND<*>.getOrComputeAttribute(attribute: A): V? {
|
public fun <V, A : StructureAttribute<V>> Structure2D<T>.getOrComputeAttribute(attribute: A): V? {
|
||||||
return attributes[attribute] ?: computeAttribute(this, attribute)
|
return attributes[attribute] ?: computeAttribute(this, attribute)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -225,7 +216,7 @@ public inline operator fun <LS : LinearSpace<*, *>, R> LS.invoke(block: LS.() ->
|
|||||||
/**
|
/**
|
||||||
* Convert matrix to vector if it is possible.
|
* Convert matrix to vector if it is possible.
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> Matrix<T>.asVector(): Point<T> =
|
public fun <T> Matrix<T>.asVector(): Point<T> =
|
||||||
if (this.colNum == 1) as1D()
|
if (this.colNum == 1) as1D()
|
||||||
else error("Can't convert matrix with more than one column to vector")
|
else error("Can't convert matrix with more than one column to vector")
|
||||||
|
|
||||||
@ -236,4 +227,4 @@ public fun <T : Any> Matrix<T>.asVector(): Point<T> =
|
|||||||
* @receiver a buffer.
|
* @receiver a buffer.
|
||||||
* @return the new matrix.
|
* @return the new matrix.
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(type, size, 1) { i, _ -> get(i) }
|
public fun <T> Point<T>.asMatrix(): VirtualMatrix<T> = VirtualMatrix(type, size, 1) { i, _ -> get(i) }
|
@ -9,12 +9,20 @@ package space.kscience.kmath.linear
|
|||||||
|
|
||||||
import space.kscience.attributes.Attributes
|
import space.kscience.attributes.Attributes
|
||||||
import space.kscience.attributes.PolymorphicAttribute
|
import space.kscience.attributes.PolymorphicAttribute
|
||||||
import space.kscience.attributes.SafeType
|
|
||||||
import space.kscience.attributes.safeTypeOf
|
import space.kscience.attributes.safeTypeOf
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.*
|
||||||
|
|
||||||
|
public interface LupDecomposition<T> {
|
||||||
|
public val linearSpace: LinearSpace<T, Field<T>>
|
||||||
|
public val elementAlgebra: Field<T> get() = linearSpace.elementAlgebra
|
||||||
|
|
||||||
|
public val pivot: IntBuffer
|
||||||
|
public val l: Matrix<T>
|
||||||
|
public val u: Matrix<T>
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Matrices with this feature support LU factorization with partial pivoting: *[p] · a = [l] · [u]* where
|
* Matrices with this feature support LU factorization with partial pivoting: *[p] · a = [l] · [u]* where
|
||||||
* *a* is the owning matrix.
|
* *a* is the owning matrix.
|
||||||
@ -22,15 +30,14 @@ import space.kscience.kmath.structures.*
|
|||||||
* @param T the type of matrices' items.
|
* @param T the type of matrices' items.
|
||||||
* @param lu combined L and U matrix
|
* @param lu combined L and U matrix
|
||||||
*/
|
*/
|
||||||
public class LupDecomposition<T>(
|
public class GenericLupDecomposition<T>(
|
||||||
public val linearSpace: LinearSpace<T, Ring<T>>,
|
override val linearSpace: LinearSpace<T, Field<T>>,
|
||||||
private val lu: Matrix<T>,
|
private val lu: Matrix<T>,
|
||||||
public val pivot: IntBuffer,
|
override val pivot: IntBuffer,
|
||||||
private val even: Boolean,
|
private val even: Boolean,
|
||||||
) {
|
) : LupDecomposition<T> {
|
||||||
public val elementAlgebra: Ring<T> get() = linearSpace.elementAlgebra
|
|
||||||
|
|
||||||
public val l: Matrix<T>
|
override val l: Matrix<T>
|
||||||
get() = VirtualMatrix(lu.type, lu.rowNum, lu.colNum, attributes = Attributes(LowerTriangular)) { i, j ->
|
get() = VirtualMatrix(lu.type, lu.rowNum, lu.colNum, attributes = Attributes(LowerTriangular)) { i, j ->
|
||||||
when {
|
when {
|
||||||
j < i -> lu[i, j]
|
j < i -> lu[i, j]
|
||||||
@ -39,7 +46,7 @@ public class LupDecomposition<T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public val u: Matrix<T>
|
override val u: Matrix<T>
|
||||||
get() = VirtualMatrix(lu.type, lu.rowNum, lu.colNum, attributes = Attributes(UpperTriangular)) { i, j ->
|
get() = VirtualMatrix(lu.type, lu.rowNum, lu.colNum, attributes = Attributes(UpperTriangular)) { i, j ->
|
||||||
if (j >= i) lu[i, j] else elementAlgebra.zero
|
if (j >= i) lu[i, j] else elementAlgebra.zero
|
||||||
}
|
}
|
||||||
@ -55,13 +62,12 @@ public class LupDecomposition<T>(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public class LupDecompositionAttribute<T> :
|
||||||
public class LupDecompositionAttribute<T>(type: SafeType<LupDecomposition<T>>) :
|
PolymorphicAttribute<LupDecomposition<T>>(safeTypeOf()),
|
||||||
PolymorphicAttribute<LupDecomposition<T>>(type),
|
|
||||||
MatrixAttribute<LupDecomposition<T>>
|
MatrixAttribute<LupDecomposition<T>>
|
||||||
|
|
||||||
public val <T> MatrixOperations<T>.LUP: LupDecompositionAttribute<T>
|
public val <T> MatrixScope<T>.LUP: LupDecompositionAttribute<T>
|
||||||
get() = LupDecompositionAttribute(safeTypeOf())
|
get() = LupDecompositionAttribute()
|
||||||
|
|
||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.abs(value: T): T =
|
internal fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.abs(value: T): T =
|
||||||
@ -79,7 +85,7 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
|||||||
val pivot = IntArray(matrix.rowNum)
|
val pivot = IntArray(matrix.rowNum)
|
||||||
|
|
||||||
//TODO just waits for multi-receivers
|
//TODO just waits for multi-receivers
|
||||||
with(BufferAccessor2D(matrix.rowNum, matrix.colNum, elementAlgebra.bufferFactory)){
|
with(BufferAccessor2D(matrix.rowNum, matrix.colNum, elementAlgebra.bufferFactory)) {
|
||||||
|
|
||||||
val lu = create(matrix)
|
val lu = create(matrix)
|
||||||
|
|
||||||
@ -142,18 +148,17 @@ public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lup(
|
|||||||
for (row in col + 1 until m) lu[row, col] /= luDiag
|
for (row in col + 1 until m) lu[row, col] /= luDiag
|
||||||
}
|
}
|
||||||
|
|
||||||
return LupDecomposition(this@lup, lu.toStructure2D(), pivot.asBuffer(), even)
|
return GenericLupDecomposition(this@lup, lu.toStructure2D(), pivot.asBuffer(), even)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public fun LinearSpace<Double, Float64Field>.lup(
|
public fun LinearSpace<Double, Float64Field>.lup(
|
||||||
matrix: Matrix<Double>,
|
matrix: Matrix<Double>,
|
||||||
singularityThreshold: Double = 1e-11,
|
singularityThreshold: Double = 1e-11,
|
||||||
): LupDecomposition<Double> = lup(matrix) { it < singularityThreshold }
|
): LupDecomposition<Double> = lup(matrix) { it < singularityThreshold }
|
||||||
|
|
||||||
internal fun <T : Any, A : Field<T>> LinearSpace<T, A>.solve(
|
internal fun <T> LinearSpace<T, Field<T>>.solve(
|
||||||
lup: LupDecomposition<T>,
|
lup: LupDecomposition<T>,
|
||||||
matrix: Matrix<T>,
|
matrix: Matrix<T>,
|
||||||
): Matrix<T> {
|
): Matrix<T> {
|
||||||
@ -205,7 +210,7 @@ internal fun <T : Any, A : Field<T>> LinearSpace<T, A>.solve(
|
|||||||
* Produce a generic solver based on LUP decomposition
|
* Produce a generic solver based on LUP decomposition
|
||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public fun <T : Comparable<T>, F : Field<T>> LinearSpace<T, F>.lupSolver(
|
public fun <T : Comparable<T>> LinearSpace<T, Field<T>>.lupSolver(
|
||||||
singularityCheck: (T) -> Boolean,
|
singularityCheck: (T) -> Boolean,
|
||||||
): LinearSolver<T> = object : LinearSolver<T> {
|
): LinearSolver<T> = object : LinearSolver<T> {
|
||||||
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
override fun solve(a: Matrix<T>, b: Matrix<T>): Matrix<T> {
|
||||||
|
@ -35,7 +35,7 @@ public val <T : Any> Matrix<T>.origin: Matrix<T>
|
|||||||
/**
|
/**
|
||||||
* Add a single feature to a [Matrix]
|
* Add a single feature to a [Matrix]
|
||||||
*/
|
*/
|
||||||
public fun <T : Any, A : Attribute<T>> Matrix<T>.withAttribute(
|
public fun <T, A : Attribute<T>> Matrix<T>.withAttribute(
|
||||||
attribute: A,
|
attribute: A,
|
||||||
attrValue: T,
|
attrValue: T,
|
||||||
): MatrixWrapper<T> = if (this is MatrixWrapper) {
|
): MatrixWrapper<T> = if (this is MatrixWrapper) {
|
||||||
@ -44,7 +44,7 @@ public fun <T : Any, A : Attribute<T>> Matrix<T>.withAttribute(
|
|||||||
MatrixWrapper(this, Attributes(attribute, attrValue))
|
MatrixWrapper(this, Attributes(attribute, attrValue))
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T : Any, A : Attribute<Unit>> Matrix<T>.withAttribute(
|
public fun <T, A : Attribute<Unit>> Matrix<T>.withAttribute(
|
||||||
attribute: A,
|
attribute: A,
|
||||||
): MatrixWrapper<T> = if (this is MatrixWrapper) {
|
): MatrixWrapper<T> = if (this is MatrixWrapper) {
|
||||||
MatrixWrapper(origin, attributes.withAttribute(attribute))
|
MatrixWrapper(origin, attributes.withAttribute(attribute))
|
||||||
@ -55,7 +55,7 @@ public fun <T : Any, A : Attribute<Unit>> Matrix<T>.withAttribute(
|
|||||||
/**
|
/**
|
||||||
* Modify matrix attributes
|
* Modify matrix attributes
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> Matrix<T>.modifyAttributes(modifier: (Attributes) -> Attributes): MatrixWrapper<T> =
|
public fun <T> Matrix<T>.modifyAttributes(modifier: (Attributes) -> Attributes): MatrixWrapper<T> =
|
||||||
if (this is MatrixWrapper) {
|
if (this is MatrixWrapper) {
|
||||||
MatrixWrapper(origin, modifier(attributes))
|
MatrixWrapper(origin, modifier(attributes))
|
||||||
} else {
|
} else {
|
||||||
@ -65,7 +65,7 @@ public fun <T : Any> Matrix<T>.modifyAttributes(modifier: (Attributes) -> Attrib
|
|||||||
/**
|
/**
|
||||||
* Diagonal matrix of ones. The matrix is virtual, no actual matrix is created.
|
* Diagonal matrix of ones. The matrix is virtual, no actual matrix is created.
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> LinearSpace<T, Ring<T>>.one(
|
public fun <T> LinearSpace<T, Ring<T>>.one(
|
||||||
rows: Int,
|
rows: Int,
|
||||||
columns: Int,
|
columns: Int,
|
||||||
): MatrixWrapper<T> = VirtualMatrix(type, rows, columns) { i, j ->
|
): MatrixWrapper<T> = VirtualMatrix(type, rows, columns) { i, j ->
|
||||||
@ -76,7 +76,7 @@ public fun <T : Any> LinearSpace<T, Ring<T>>.one(
|
|||||||
/**
|
/**
|
||||||
* A virtual matrix of zeroes
|
* A virtual matrix of zeroes
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> LinearSpace<T, Ring<T>>.zero(
|
public fun <T> LinearSpace<T, Ring<T>>.zero(
|
||||||
rows: Int,
|
rows: Int,
|
||||||
columns: Int,
|
columns: Int,
|
||||||
): MatrixWrapper<T> = VirtualMatrix(type, rows, columns) { _, _ ->
|
): MatrixWrapper<T> = VirtualMatrix(type, rows, columns) { _, _ ->
|
||||||
|
@ -10,6 +10,13 @@ package space.kscience.kmath.linear
|
|||||||
import space.kscience.attributes.*
|
import space.kscience.attributes.*
|
||||||
import space.kscience.kmath.nd.StructureAttribute
|
import space.kscience.kmath.nd.StructureAttribute
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A marker interface for algebras that operate on matrices
|
||||||
|
* @param T type of matrix element
|
||||||
|
*/
|
||||||
|
public interface MatrixScope<T> : AttributeScope<Matrix<T>>, WithType<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A marker interface representing some properties of matrices or additional transformations of them. Features are used
|
* A marker interface representing some properties of matrices or additional transformations of them. Features are used
|
||||||
* to optimize matrix operations performance in some cases or retrieve the APIs.
|
* to optimize matrix operations performance in some cases or retrieve the APIs.
|
||||||
@ -38,11 +45,11 @@ public object IsUnit : IsDiagonal
|
|||||||
*
|
*
|
||||||
* @param T the type of matrices' items.
|
* @param T the type of matrices' items.
|
||||||
*/
|
*/
|
||||||
public class Inverted<T>(type: SafeType<Matrix<T>>) :
|
public class Inverted<T>() :
|
||||||
PolymorphicAttribute<Matrix<T>>(type),
|
PolymorphicAttribute<Matrix<T>>(safeTypeOf()),
|
||||||
MatrixAttribute<Matrix<T>>
|
MatrixAttribute<Matrix<T>>
|
||||||
|
|
||||||
public val <T> MatrixOperations<T>.Inverted: Inverted<T> get() = Inverted(safeTypeOf())
|
public val <T> MatrixScope<T>.Inverted: Inverted<T> get() = Inverted()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Matrices with this feature can compute their determinant.
|
* Matrices with this feature can compute their determinant.
|
||||||
@ -53,7 +60,7 @@ public class Determinant<T>(type: SafeType<T>) :
|
|||||||
PolymorphicAttribute<T>(type),
|
PolymorphicAttribute<T>(type),
|
||||||
MatrixAttribute<T>
|
MatrixAttribute<T>
|
||||||
|
|
||||||
public val <T> MatrixOperations<T>.Determinant: Determinant<T> get() = Determinant(type)
|
public val <T> MatrixScope<T>.Determinant: Determinant<T> get() = Determinant(type)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Matrices with this feature are lower triangular ones.
|
* Matrices with this feature are lower triangular ones.
|
||||||
@ -77,11 +84,11 @@ public data class LUDecomposition<T>(val l: Matrix<T>, val u: Matrix<T>)
|
|||||||
*
|
*
|
||||||
* @param T the type of matrices' items.
|
* @param T the type of matrices' items.
|
||||||
*/
|
*/
|
||||||
public class LuDecompositionAttribute<T>(type: SafeType<LUDecomposition<T>>) :
|
public class LuDecompositionAttribute<T> :
|
||||||
PolymorphicAttribute<LUDecomposition<T>>(type),
|
PolymorphicAttribute<LUDecomposition<T>>(safeTypeOf()),
|
||||||
MatrixAttribute<LUDecomposition<T>>
|
MatrixAttribute<LUDecomposition<T>>
|
||||||
|
|
||||||
public val <T> MatrixOperations<T>.LU: LuDecompositionAttribute<T> get() = LuDecompositionAttribute(safeTypeOf())
|
public val <T> MatrixScope<T>.LU: LuDecompositionAttribute<T> get() = LuDecompositionAttribute()
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -108,12 +115,12 @@ public interface QRDecomposition<out T> {
|
|||||||
*
|
*
|
||||||
* @param T the type of matrices' items.
|
* @param T the type of matrices' items.
|
||||||
*/
|
*/
|
||||||
public class QRDecompositionAttribute<T>(type: SafeType<QRDecomposition<T>>) :
|
public class QRDecompositionAttribute<T>() :
|
||||||
PolymorphicAttribute<QRDecomposition<T>>(type),
|
PolymorphicAttribute<QRDecomposition<T>>(safeTypeOf()),
|
||||||
MatrixAttribute<QRDecomposition<T>>
|
MatrixAttribute<QRDecomposition<T>>
|
||||||
|
|
||||||
public val <T> MatrixOperations<T>.QR: QRDecompositionAttribute<T>
|
public val <T> MatrixScope<T>.QR: QRDecompositionAttribute<T>
|
||||||
get() = QRDecompositionAttribute(safeTypeOf())
|
get() = QRDecompositionAttribute()
|
||||||
|
|
||||||
public interface CholeskyDecomposition<T> {
|
public interface CholeskyDecomposition<T> {
|
||||||
/**
|
/**
|
||||||
@ -128,12 +135,12 @@ public interface CholeskyDecomposition<T> {
|
|||||||
*
|
*
|
||||||
* @param T the type of matrices' items.
|
* @param T the type of matrices' items.
|
||||||
*/
|
*/
|
||||||
public class CholeskyDecompositionAttribute<T>(type: SafeType<CholeskyDecomposition<T>>) :
|
public class CholeskyDecompositionAttribute<T> :
|
||||||
PolymorphicAttribute<CholeskyDecomposition<T>>(type),
|
PolymorphicAttribute<CholeskyDecomposition<T>>(safeTypeOf()),
|
||||||
MatrixAttribute<CholeskyDecomposition<T>>
|
MatrixAttribute<CholeskyDecomposition<T>>
|
||||||
|
|
||||||
public val <T> MatrixOperations<T>.Cholesky: CholeskyDecompositionAttribute<T>
|
public val <T> MatrixScope<T>.Cholesky: CholeskyDecompositionAttribute<T>
|
||||||
get() = CholeskyDecompositionAttribute(safeTypeOf())
|
get() = CholeskyDecompositionAttribute()
|
||||||
|
|
||||||
public interface SingularValueDecomposition<T> {
|
public interface SingularValueDecomposition<T> {
|
||||||
/**
|
/**
|
||||||
@ -163,12 +170,11 @@ public interface SingularValueDecomposition<T> {
|
|||||||
*
|
*
|
||||||
* @param T the type of matrices' items.
|
* @param T the type of matrices' items.
|
||||||
*/
|
*/
|
||||||
public class SVDAttribute<T>(type: SafeType<SingularValueDecomposition<T>>) :
|
public class SVDAttribute<T>() :
|
||||||
PolymorphicAttribute<SingularValueDecomposition<T>>(type),
|
PolymorphicAttribute<SingularValueDecomposition<T>>(safeTypeOf()),
|
||||||
MatrixAttribute<SingularValueDecomposition<T>>
|
MatrixAttribute<SingularValueDecomposition<T>>
|
||||||
|
|
||||||
public val <T> MatrixOperations<T>.SVD: SVDAttribute<T>
|
public val <T> MatrixScope<T>.SVD: SVDAttribute<T> get() = SVDAttribute()
|
||||||
get() = SVDAttribute(safeTypeOf())
|
|
||||||
|
|
||||||
|
|
||||||
//TODO add sparse matrix feature
|
//TODO add sparse matrix feature
|
||||||
|
@ -39,5 +39,5 @@ public abstract class EjmlLinearSpace<T : Any, out A : Ring<T>, out M : org.ejml
|
|||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun EjmlMatrix<T, *>.inverted(): Matrix<Double> =
|
public fun EjmlMatrix<T, *>.inverted(): Matrix<Double> =
|
||||||
attributeForOrNull(this, Float64Field.linearSpace.Inverted)
|
computeAttribute(this, Float64Field.linearSpace.Inverted)!!
|
||||||
}
|
}
|
||||||
|
@ -19,9 +19,13 @@ import org.ejml.sparse.csc.factory.DecompositionFactory_DSCC
|
|||||||
import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC
|
import org.ejml.sparse.csc.factory.DecompositionFactory_FSCC
|
||||||
import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC
|
import org.ejml.sparse.csc.factory.LinearSolverFactory_DSCC
|
||||||
import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC
|
import org.ejml.sparse.csc.factory.LinearSolverFactory_FSCC
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.linear.*
|
import space.kscience.kmath.linear.*
|
||||||
import space.kscience.kmath.linear.Matrix
|
import space.kscience.kmath.linear.Matrix
|
||||||
|
import space.kscience.kmath.nd.Structure2D
|
||||||
|
import space.kscience.kmath.nd.StructureAttribute
|
||||||
import space.kscience.kmath.nd.StructureFeature
|
import space.kscience.kmath.nd.StructureFeature
|
||||||
import space.kscience.kmath.operations.Float32Field
|
import space.kscience.kmath.operations.Float32Field
|
||||||
import space.kscience.kmath.operations.Float64Field
|
import space.kscience.kmath.operations.Float64Field
|
||||||
@ -39,6 +43,8 @@ public class EjmlDoubleVector<out M : DMatrix>(override val origin: M) : EjmlVec
|
|||||||
require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" }
|
require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override val type: SafeType<Double> get() = safeTypeOf()
|
||||||
|
|
||||||
override operator fun get(index: Int): Double = origin[0, index]
|
override operator fun get(index: Int): Double = origin[0, index]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,6 +56,8 @@ public class EjmlFloatVector<out M : FMatrix>(override val origin: M) : EjmlVect
|
|||||||
require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" }
|
require(origin.numRows == 1) { "The origin matrix must have only one row to form a vector" }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override val type: SafeType<Float> get() = safeTypeOf()
|
||||||
|
|
||||||
override operator fun get(index: Int): Float = origin[0, index]
|
override operator fun get(index: Int): Float = origin[0, index]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,6 +65,8 @@ public class EjmlFloatVector<out M : FMatrix>(override val origin: M) : EjmlVect
|
|||||||
* [EjmlMatrix] specialization for [Double].
|
* [EjmlMatrix] specialization for [Double].
|
||||||
*/
|
*/
|
||||||
public class EjmlDoubleMatrix<out M : DMatrix>(override val origin: M) : EjmlMatrix<Double, M>(origin) {
|
public class EjmlDoubleMatrix<out M : DMatrix>(override val origin: M) : EjmlMatrix<Double, M>(origin) {
|
||||||
|
override val type: SafeType<Double> get() = safeTypeOf()
|
||||||
|
|
||||||
override operator fun get(i: Int, j: Int): Double = origin[i, j]
|
override operator fun get(i: Int, j: Int): Double = origin[i, j]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,9 +74,12 @@ public class EjmlDoubleMatrix<out M : DMatrix>(override val origin: M) : EjmlMat
|
|||||||
* [EjmlMatrix] specialization for [Float].
|
* [EjmlMatrix] specialization for [Float].
|
||||||
*/
|
*/
|
||||||
public class EjmlFloatMatrix<out M : FMatrix>(override val origin: M) : EjmlMatrix<Float, M>(origin) {
|
public class EjmlFloatMatrix<out M : FMatrix>(override val origin: M) : EjmlMatrix<Float, M>(origin) {
|
||||||
|
override val type: SafeType<Float> get() = safeTypeOf()
|
||||||
|
|
||||||
override operator fun get(i: Int, j: Int): Float = origin[i, j]
|
override operator fun get(i: Int, j: Int): Float = origin[i, j]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [EjmlLinearSpace] implementation based on [CommonOps_DDRM], [DecompositionFactory_DDRM] operations and
|
* [EjmlLinearSpace] implementation based on [CommonOps_DDRM], [DecompositionFactory_DDRM] operations and
|
||||||
* [DMatrixRMaj] matrices.
|
* [DMatrixRMaj] matrices.
|
||||||
@ -77,7 +90,7 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, Float64Field, DMatri
|
|||||||
*/
|
*/
|
||||||
override val elementAlgebra: Float64Field get() = Float64Field
|
override val elementAlgebra: Float64Field get() = Float64Field
|
||||||
|
|
||||||
override val elementType: KType get() = typeOf<Double>()
|
override val type: SafeType<Double> get() = safeTypeOf()
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
override fun Matrix<Double>.toEjml(): EjmlDoubleMatrix<DMatrixRMaj> = when {
|
override fun Matrix<Double>.toEjml(): EjmlDoubleMatrix<DMatrixRMaj> = when {
|
||||||
@ -205,6 +218,18 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, Float64Field, DMatri
|
|||||||
|
|
||||||
override fun Double.times(v: Point<Double>): EjmlDoubleVector<DMatrixRMaj> = v * this
|
override fun Double.times(v: Point<Double>): EjmlDoubleVector<DMatrixRMaj> = v * this
|
||||||
|
|
||||||
|
override fun <V, A : StructureAttribute<V>> computeAttribute(structure: Structure2D<Double>, attribute: A): V? {
|
||||||
|
val origin = structure.toEjml().origin
|
||||||
|
return when(attribute){
|
||||||
|
Inverted -> {
|
||||||
|
val res = origin.copy()
|
||||||
|
CommonOps_DDRM.invert(res)
|
||||||
|
res.wrapMatrix()
|
||||||
|
}
|
||||||
|
else->
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
override fun <F : StructureFeature> computeFeature(structure: Matrix<Double>, type: KClass<out F>): F? {
|
override fun <F : StructureFeature> computeFeature(structure: Matrix<Double>, type: KClass<out F>): F? {
|
||||||
structure.getFeature(type)?.let { return it }
|
structure.getFeature(type)?.let { return it }
|
||||||
@ -305,6 +330,8 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, Float64Field, DMatri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
import org.checkerframework.checker.guieffect.qual.SafeType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [EjmlLinearSpace] implementation based on [CommonOps_FDRM], [DecompositionFactory_FDRM] operations and
|
* [EjmlLinearSpace] implementation based on [CommonOps_FDRM], [DecompositionFactory_FDRM] operations and
|
||||||
* [FMatrixRMaj] matrices.
|
* [FMatrixRMaj] matrices.
|
||||||
@ -315,7 +342,7 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace<Float, Float32Field, FMatrix
|
|||||||
*/
|
*/
|
||||||
override val elementAlgebra: Float32Field get() = Float32Field
|
override val elementAlgebra: Float32Field get() = Float32Field
|
||||||
|
|
||||||
override val elementType: KType get() = typeOf<Float>()
|
override val type: SafeType<Float> get() = safeTypeOf()
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
override fun Matrix<Float>.toEjml(): EjmlFloatMatrix<FMatrixRMaj> = when {
|
override fun Matrix<Float>.toEjml(): EjmlFloatMatrix<FMatrixRMaj> = when {
|
||||||
@ -543,6 +570,8 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace<Float, Float32Field, FMatrix
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
import org.checkerframework.checker.guieffect.qual.SafeType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [EjmlLinearSpace] implementation based on [CommonOps_DSCC], [DecompositionFactory_DSCC] operations and
|
* [EjmlLinearSpace] implementation based on [CommonOps_DSCC], [DecompositionFactory_DSCC] operations and
|
||||||
* [DMatrixSparseCSC] matrices.
|
* [DMatrixSparseCSC] matrices.
|
||||||
@ -553,7 +582,7 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace<Double, Float64Field, DMatri
|
|||||||
*/
|
*/
|
||||||
override val elementAlgebra: Float64Field get() = Float64Field
|
override val elementAlgebra: Float64Field get() = Float64Field
|
||||||
|
|
||||||
override val elementType: KType get() = typeOf<Double>()
|
override val type: SafeType<Double> get() = safeTypeOf()
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
override fun Matrix<Double>.toEjml(): EjmlDoubleMatrix<DMatrixSparseCSC> = when {
|
override fun Matrix<Double>.toEjml(): EjmlDoubleMatrix<DMatrixSparseCSC> = when {
|
||||||
@ -776,6 +805,8 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace<Double, Float64Field, DMatri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
import org.checkerframework.checker.guieffect.qual.SafeType
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [EjmlLinearSpace] implementation based on [CommonOps_FSCC], [DecompositionFactory_FSCC] operations and
|
* [EjmlLinearSpace] implementation based on [CommonOps_FSCC], [DecompositionFactory_FSCC] operations and
|
||||||
* [FMatrixSparseCSC] matrices.
|
* [FMatrixSparseCSC] matrices.
|
||||||
@ -786,7 +817,7 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace<Float, Float32Field, FMatrix
|
|||||||
*/
|
*/
|
||||||
override val elementAlgebra: Float32Field get() = Float32Field
|
override val elementAlgebra: Float32Field get() = Float32Field
|
||||||
|
|
||||||
override val elementType: KType get() = typeOf<Float>()
|
override val type: SafeType<Float> get() = safeTypeOf()
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
override fun Matrix<Float>.toEjml(): EjmlFloatMatrix<FMatrixSparseCSC> = when {
|
override fun Matrix<Float>.toEjml(): EjmlFloatMatrix<FMatrixSparseCSC> = when {
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
*/
|
*/
|
||||||
package space.kscience.kmath.integration
|
package space.kscience.kmath.integration
|
||||||
|
|
||||||
import space.kscience.attributes.TypedAttributesBuilder
|
import space.kscience.attributes.AttributesBuilder
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.Field
|
import space.kscience.kmath.operations.Field
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
@ -92,7 +92,7 @@ public inline fun <reified T : Any> GaussIntegrator<T>.integrate(
|
|||||||
range: ClosedRange<Double>,
|
range: ClosedRange<Double>,
|
||||||
order: Int = 10,
|
order: Int = 10,
|
||||||
intervals: Int = 10,
|
intervals: Int = 10,
|
||||||
attributesBuilder: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
|
attributesBuilder: AttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
|
||||||
noinline function: (Double) -> T,
|
noinline function: (Double) -> T,
|
||||||
): UnivariateIntegrand<T> {
|
): UnivariateIntegrand<T> {
|
||||||
require(range.endInclusive > range.start) { "The range upper bound should be higher than lower bound" }
|
require(range.endInclusive > range.start) { "The range upper bound should be higher than lower bound" }
|
||||||
|
@ -30,7 +30,7 @@ public sealed class IntegrandValue<T> private constructor(): IntegrandAttribute<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T> TypedAttributesBuilder<Integrand<T>>.value(value: T) {
|
public fun <T> AttributesBuilder<Integrand<T>>.value(value: T) {
|
||||||
IntegrandValue.forType<T>().invoke(value)
|
IntegrandValue.forType<T>().invoke(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,10 +25,10 @@ public fun <T, A : Any> MultivariateIntegrand<T>.withAttribute(
|
|||||||
): MultivariateIntegrand<T> = withAttributes(attributes.withAttribute(attribute, value))
|
): MultivariateIntegrand<T> = withAttributes(attributes.withAttribute(attribute, value))
|
||||||
|
|
||||||
public fun <T> MultivariateIntegrand<T>.withAttributes(
|
public fun <T> MultivariateIntegrand<T>.withAttributes(
|
||||||
block: TypedAttributesBuilder<MultivariateIntegrand<T>>.() -> Unit,
|
block: AttributesBuilder<MultivariateIntegrand<T>>.() -> Unit,
|
||||||
): MultivariateIntegrand<T> = withAttributes(attributes.modify(block))
|
): MultivariateIntegrand<T> = withAttributes(attributes.modify(block))
|
||||||
|
|
||||||
public inline fun <reified T : Any> MultivariateIntegrand(
|
public inline fun <reified T : Any> MultivariateIntegrand(
|
||||||
attributeBuilder: TypedAttributesBuilder<MultivariateIntegrand<T>>.() -> Unit,
|
attributeBuilder: AttributesBuilder<MultivariateIntegrand<T>>.() -> Unit,
|
||||||
noinline function: (Point<T>) -> T,
|
noinline function: (Point<T>) -> T,
|
||||||
): MultivariateIntegrand<T> = MultivariateIntegrand(safeTypeOf<T>(), Attributes(attributeBuilder), function)
|
): MultivariateIntegrand<T> = MultivariateIntegrand(safeTypeOf<T>(), Attributes(attributeBuilder), function)
|
||||||
|
@ -26,11 +26,11 @@ public fun <T, A : Any> UnivariateIntegrand<T>.withAttribute(
|
|||||||
): UnivariateIntegrand<T> = withAttributes(attributes.withAttribute(attribute, value))
|
): UnivariateIntegrand<T> = withAttributes(attributes.withAttribute(attribute, value))
|
||||||
|
|
||||||
public fun <T> UnivariateIntegrand<T>.withAttributes(
|
public fun <T> UnivariateIntegrand<T>.withAttributes(
|
||||||
block: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
|
block: AttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
|
||||||
): UnivariateIntegrand<T> = withAttributes(attributes.modify(block))
|
): UnivariateIntegrand<T> = withAttributes(attributes.modify(block))
|
||||||
|
|
||||||
public inline fun <reified T : Any> UnivariateIntegrand(
|
public inline fun <reified T : Any> UnivariateIntegrand(
|
||||||
attributeBuilder: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
|
attributeBuilder: AttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
|
||||||
noinline function: (Double) -> T,
|
noinline function: (Double) -> T,
|
||||||
): UnivariateIntegrand<T> = UnivariateIntegrand(safeTypeOf(), Attributes(attributeBuilder), function)
|
): UnivariateIntegrand<T> = UnivariateIntegrand(safeTypeOf(), Attributes(attributeBuilder), function)
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ public class UnivariateIntegrandRanges(public val ranges: List<Pair<ClosedRange<
|
|||||||
|
|
||||||
public object UnivariateIntegrationNodes : IntegrandAttribute<Buffer<Double>>
|
public object UnivariateIntegrationNodes : IntegrandAttribute<Buffer<Double>>
|
||||||
|
|
||||||
public fun TypedAttributesBuilder<UnivariateIntegrand<*>>.integrationNodes(vararg nodes: Double) {
|
public fun AttributesBuilder<UnivariateIntegrand<*>>.integrationNodes(vararg nodes: Double) {
|
||||||
UnivariateIntegrationNodes(Float64Buffer(nodes))
|
UnivariateIntegrationNodes(Float64Buffer(nodes))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ public fun TypedAttributesBuilder<UnivariateIntegrand<*>>.integrationNodes(varar
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate(
|
public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate(
|
||||||
attributesBuilder: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
|
attributesBuilder: AttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
|
||||||
noinline function: (Double) -> T,
|
noinline function: (Double) -> T,
|
||||||
): UnivariateIntegrand<T> = integrate(UnivariateIntegrand(attributesBuilder, function))
|
): UnivariateIntegrand<T> = integrate(UnivariateIntegrand(attributesBuilder, function))
|
||||||
|
|
||||||
@ -79,7 +79,7 @@ public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate(
|
|||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate(
|
public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate(
|
||||||
range: ClosedRange<Double>,
|
range: ClosedRange<Double>,
|
||||||
attributeBuilder: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit = {},
|
attributeBuilder: AttributesBuilder<UnivariateIntegrand<T>>.() -> Unit = {},
|
||||||
noinline function: (Double) -> T,
|
noinline function: (Double) -> T,
|
||||||
): UnivariateIntegrand<T> {
|
): UnivariateIntegrand<T> {
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ repositories {
|
|||||||
}
|
}
|
||||||
|
|
||||||
readme {
|
readme {
|
||||||
maturity = space.kscience.gradle.Maturity.PROTOTYPE
|
maturity = space.kscience.gradle.Maturity.DEPRECATED
|
||||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||||
|
|
||||||
feature("jafama-double", "src/main/kotlin/space/kscience/kmath/jafama/") {
|
feature("jafama-double", "src/main/kotlin/space/kscience/kmath/jafama/") {
|
||||||
|
@ -7,16 +7,17 @@ package space.kscience.kmath.jafama
|
|||||||
|
|
||||||
import net.jafama.FastMath
|
import net.jafama.FastMath
|
||||||
import net.jafama.StrictFastMath
|
import net.jafama.StrictFastMath
|
||||||
import space.kscience.kmath.operations.ExtendedField
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.operations.Norm
|
import space.kscience.kmath.structures.MutableBufferFactory
|
||||||
import space.kscience.kmath.operations.PowerOperations
|
|
||||||
import space.kscience.kmath.operations.ScaleOperations
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field for [Double] (using FastMath) without boxing. Does not produce appropriate field element.
|
* A field for [Double] (using FastMath) without boxing. Does not produce appropriate field element.
|
||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> {
|
public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> {
|
||||||
|
|
||||||
|
override val bufferFactory: MutableBufferFactory<Double> get() = DoubleField.bufferFactory
|
||||||
|
|
||||||
override inline val zero: Double get() = 0.0
|
override inline val zero: Double get() = 0.0
|
||||||
override inline val one: Double get() = 1.0
|
override inline val one: Double get() = 1.0
|
||||||
|
|
||||||
@ -68,6 +69,9 @@ public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, S
|
|||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> {
|
public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> {
|
||||||
|
|
||||||
|
override val bufferFactory: MutableBufferFactory<Double> get() = DoubleField.bufferFactory
|
||||||
|
|
||||||
override inline val zero: Double get() = 0.0
|
override inline val zero: Double get() = 0.0
|
||||||
override inline val one: Double get() = 1.0
|
override inline val one: Double get() = 1.0
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ public class MultikDoubleAlgebra(
|
|||||||
) : MultikDivisionTensorAlgebra<Double, Float64Field>(multikEngine),
|
) : MultikDivisionTensorAlgebra<Double, Float64Field>(multikEngine),
|
||||||
TrigonometricOperations<StructureND<Double>>, ExponentialOperations<StructureND<Double>> {
|
TrigonometricOperations<StructureND<Double>>, ExponentialOperations<StructureND<Double>> {
|
||||||
override val elementAlgebra: Float64Field get() = Float64Field
|
override val elementAlgebra: Float64Field get() = Float64Field
|
||||||
override val type: DataType get() = DataType.DoubleDataType
|
override val dataType: DataType get() = DataType.DoubleDataType
|
||||||
|
|
||||||
override fun sin(arg: StructureND<Double>): MultikTensor<Double> = multikMath.mathEx.sin(arg.asMultik().array).wrap()
|
override fun sin(arg: StructureND<Double>): MultikTensor<Double> = multikMath.mathEx.sin(arg.asMultik().array).wrap()
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ public class MultikFloatAlgebra(
|
|||||||
multikEngine: Engine
|
multikEngine: Engine
|
||||||
) : MultikDivisionTensorAlgebra<Float, Float32Field>(multikEngine) {
|
) : MultikDivisionTensorAlgebra<Float, Float32Field>(multikEngine) {
|
||||||
override val elementAlgebra: Float32Field get() = Float32Field
|
override val elementAlgebra: Float32Field get() = Float32Field
|
||||||
override val type: DataType get() = DataType.FloatDataType
|
override val dataType: DataType get() = DataType.FloatDataType
|
||||||
|
|
||||||
override fun scalar(value: Float): MultikTensor<Float> = Multik.ndarrayOf(value).wrap()
|
override fun scalar(value: Float): MultikTensor<Float> = Multik.ndarrayOf(value).wrap()
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ public class MultikIntAlgebra(
|
|||||||
multikEngine: Engine
|
multikEngine: Engine
|
||||||
) : MultikTensorAlgebra<Int, Int32Ring>(multikEngine) {
|
) : MultikTensorAlgebra<Int, Int32Ring>(multikEngine) {
|
||||||
override val elementAlgebra: Int32Ring get() = Int32Ring
|
override val elementAlgebra: Int32Ring get() = Int32Ring
|
||||||
override val type: DataType get() = DataType.IntDataType
|
override val dataType: DataType get() = DataType.IntDataType
|
||||||
override fun scalar(value: Int): MultikTensor<Int> = Multik.ndarrayOf(value).wrap()
|
override fun scalar(value: Int): MultikTensor<Int> = Multik.ndarrayOf(value).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ public class MultikLongAlgebra(
|
|||||||
multikEngine: Engine
|
multikEngine: Engine
|
||||||
) : MultikTensorAlgebra<Long, Int64Ring>(multikEngine) {
|
) : MultikTensorAlgebra<Long, Int64Ring>(multikEngine) {
|
||||||
override val elementAlgebra: Int64Ring get() = Int64Ring
|
override val elementAlgebra: Int64Ring get() = Int64Ring
|
||||||
override val type: DataType get() = DataType.LongDataType
|
override val dataType: DataType get() = DataType.LongDataType
|
||||||
|
|
||||||
override fun scalar(value: Long): MultikTensor<Long> = Multik.ndarrayOf(value).wrap()
|
override fun scalar(value: Long): MultikTensor<Long> = Multik.ndarrayOf(value).wrap()
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ public class MultikShortAlgebra(
|
|||||||
multikEngine: Engine
|
multikEngine: Engine
|
||||||
) : MultikTensorAlgebra<Short, Int16Ring>(multikEngine) {
|
) : MultikTensorAlgebra<Short, Int16Ring>(multikEngine) {
|
||||||
override val elementAlgebra: Int16Ring get() = Int16Ring
|
override val elementAlgebra: Int16Ring get() = Int16Ring
|
||||||
override val type: DataType get() = DataType.ShortDataType
|
override val dataType: DataType get() = DataType.ShortDataType
|
||||||
override fun scalar(value: Short): MultikTensor<Short> = Multik.ndarrayOf(value).wrap()
|
override fun scalar(value: Short): MultikTensor<Short> = Multik.ndarrayOf(value).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,13 +6,33 @@
|
|||||||
package space.kscience.kmath.multik
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
|
import space.kscience.attributes.safeTypeOf
|
||||||
import space.kscience.kmath.PerformancePitfall
|
import space.kscience.kmath.PerformancePitfall
|
||||||
|
import space.kscience.kmath.complex.ComplexField
|
||||||
import space.kscience.kmath.nd.ShapeND
|
import space.kscience.kmath.nd.ShapeND
|
||||||
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
import kotlin.jvm.JvmInline
|
import kotlin.jvm.JvmInline
|
||||||
|
|
||||||
|
public val DataType.type: SafeType<*>
|
||||||
|
get() = when (this) {
|
||||||
|
DataType.ByteDataType -> ByteRing.type
|
||||||
|
DataType.ShortDataType -> ShortRing.type
|
||||||
|
DataType.IntDataType -> IntRing.type
|
||||||
|
DataType.LongDataType -> LongRing.type
|
||||||
|
DataType.FloatDataType -> Float32Field.type
|
||||||
|
DataType.DoubleDataType -> Float64Field.type
|
||||||
|
DataType.ComplexFloatDataType -> safeTypeOf<Pair<Float, Float>>()
|
||||||
|
DataType.ComplexDoubleDataType -> ComplexField.type
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
override val type: SafeType<T> get() = array.dtype.type as SafeType<T>
|
||||||
|
|
||||||
override val shape: ShapeND get() = ShapeND(array.shape)
|
override val shape: ShapeND get() = ShapeND(array.shape)
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
|
@ -26,7 +26,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
|||||||
private val multikEngine: Engine,
|
private val multikEngine: Engine,
|
||||||
) : TensorAlgebra<T, A> where T : Number, T : Comparable<T> {
|
) : TensorAlgebra<T, A> where T : Number, T : Comparable<T> {
|
||||||
|
|
||||||
public abstract val type: DataType
|
public abstract val dataType: DataType
|
||||||
|
|
||||||
protected val multikMath: Math = multikEngine.getMath()
|
protected val multikMath: Math = multikEngine.getMath()
|
||||||
protected val multikLinAl: LinAlg = multikEngine.getLinAlg()
|
protected val multikLinAl: LinAlg = multikEngine.getLinAlg()
|
||||||
@ -35,7 +35,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
|||||||
@OptIn(UnsafeKMathAPI::class)
|
@OptIn(UnsafeKMathAPI::class)
|
||||||
override fun mutableStructureND(shape: ShapeND, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
override fun mutableStructureND(shape: ShapeND, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
||||||
val strides = ColumnStrides(shape)
|
val strides = ColumnStrides(shape)
|
||||||
val memoryView = initMemoryView<T>(strides.linearSize, type)
|
val memoryView = initMemoryView<T>(strides.linearSize, dataType)
|
||||||
strides.asSequence().forEachIndexed { linearIndex, tensorIndex ->
|
strides.asSequence().forEachIndexed { linearIndex, tensorIndex ->
|
||||||
memoryView[linearIndex] = elementAlgebra.initializer(tensorIndex)
|
memoryView[linearIndex] = elementAlgebra.initializer(tensorIndex)
|
||||||
}
|
}
|
||||||
@ -44,7 +44,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
|||||||
|
|
||||||
@OptIn(PerformancePitfall::class, UnsafeKMathAPI::class)
|
@OptIn(PerformancePitfall::class, UnsafeKMathAPI::class)
|
||||||
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> = if (this is MultikTensor) {
|
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> = if (this is MultikTensor) {
|
||||||
val data = initMemoryView<T>(array.size, type)
|
val data = initMemoryView<T>(array.size, dataType)
|
||||||
var count = 0
|
var count = 0
|
||||||
for (el in array) data[count++] = elementAlgebra.transform(el)
|
for (el in array) data[count++] = elementAlgebra.transform(el)
|
||||||
NDArray(data, shape = shape.asArray(), dim = array.dim).wrap()
|
NDArray(data, shape = shape.asArray(), dim = array.dim).wrap()
|
||||||
@ -58,7 +58,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
|||||||
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> =
|
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> =
|
||||||
if (this is MultikTensor) {
|
if (this is MultikTensor) {
|
||||||
val array = asMultik().array
|
val array = asMultik().array
|
||||||
val data = initMemoryView<T>(array.size, type)
|
val data = initMemoryView<T>(array.size, dataType)
|
||||||
val indexIter = array.multiIndices.iterator()
|
val indexIter = array.multiIndices.iterator()
|
||||||
var index = 0
|
var index = 0
|
||||||
for (item in array) {
|
for (item in array) {
|
||||||
@ -95,7 +95,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
|||||||
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
||||||
val leftArray = left.asMultik().array
|
val leftArray = left.asMultik().array
|
||||||
val rightArray = right.asMultik().array
|
val rightArray = right.asMultik().array
|
||||||
val data = initMemoryView<T>(leftArray.size, type)
|
val data = initMemoryView<T>(leftArray.size, dataType)
|
||||||
var counter = 0
|
var counter = 0
|
||||||
val leftIterator = leftArray.iterator()
|
val leftIterator = leftArray.iterator()
|
||||||
val rightIterator = rightArray.iterator()
|
val rightIterator = rightArray.iterator()
|
||||||
@ -114,7 +114,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
|||||||
public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
|
public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
|
||||||
this
|
this
|
||||||
} else {
|
} else {
|
||||||
val res = mk.zeros<T, DN>(shape.asArray(), type).asDNArray()
|
val res = mk.zeros<T, DN>(shape.asArray(), dataType).asDNArray()
|
||||||
for (index in res.multiIndices) {
|
for (index in res.multiIndices) {
|
||||||
res[index] = this[index]
|
res[index] = this[index]
|
||||||
}
|
}
|
||||||
@ -296,7 +296,7 @@ public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>(
|
|||||||
|
|
||||||
@OptIn(UnsafeKMathAPI::class)
|
@OptIn(UnsafeKMathAPI::class)
|
||||||
override fun T.div(arg: StructureND<T>): MultikTensor<T> =
|
override fun T.div(arg: StructureND<T>): MultikTensor<T> =
|
||||||
Multik.ones<T, DN>(arg.shape.asArray(), type).apply { divAssign(arg.asMultik().array) }.wrap()
|
Multik.ones<T, DN>(arg.shape.asArray(), dataType).apply { divAssign(arg.asMultik().array) }.wrap()
|
||||||
|
|
||||||
override fun StructureND<T>.div(arg: T): MultikTensor<T> =
|
override fun StructureND<T>.div(arg: T): MultikTensor<T> =
|
||||||
asMultik().array.div(arg).wrap()
|
asMultik().array.div(arg).wrap()
|
||||||
|
@ -9,20 +9,21 @@ import space.kscience.attributes.*
|
|||||||
import space.kscience.kmath.expressions.DifferentiableExpression
|
import space.kscience.kmath.expressions.DifferentiableExpression
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
|
||||||
public class OptimizationValue<T>(public val value: T) : OptimizationFeature {
|
public class OptimizationValue<V>(type: SafeType<V>) : PolymorphicAttribute<V>(type)
|
||||||
override fun toString(): String = "Value($value)"
|
|
||||||
}
|
|
||||||
|
|
||||||
public enum class FunctionOptimizationTarget {
|
public enum class OptimizationDirection {
|
||||||
MAXIMIZE,
|
MAXIMIZE,
|
||||||
MINIMIZE
|
MINIMIZE
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public object FunctionOptimizationTarget: OptimizationAttribute<OptimizationDirection>
|
||||||
|
|
||||||
public class FunctionOptimization<T>(
|
public class FunctionOptimization<T>(
|
||||||
override val attributes: Attributes,
|
|
||||||
public val expression: DifferentiableExpression<T>,
|
public val expression: DifferentiableExpression<T>,
|
||||||
|
override val attributes: Attributes,
|
||||||
) : OptimizationProblem<T> {
|
) : OptimizationProblem<T> {
|
||||||
|
|
||||||
|
override val type: SafeType<T> get() = expression.type
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
@ -47,36 +48,52 @@ public class FunctionOptimization<T>(
|
|||||||
public companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun <T> FunctionOptimization(
|
||||||
|
expression: DifferentiableExpression<T>,
|
||||||
|
attributeBuilder: AttributesBuilder<FunctionOptimization<T>>.() -> Unit,
|
||||||
|
): FunctionOptimization<T> = FunctionOptimization(expression, Attributes(attributeBuilder))
|
||||||
|
|
||||||
|
public class OptimizationPrior<T> :
|
||||||
public class OptimizationPrior<T>(type: SafeType<T>):
|
|
||||||
PolymorphicAttribute<DifferentiableExpression<T>>(safeTypeOf()),
|
PolymorphicAttribute<DifferentiableExpression<T>>(safeTypeOf()),
|
||||||
Attribute<DifferentiableExpression<T>>
|
Attribute<DifferentiableExpression<T>>
|
||||||
|
|
||||||
//public val <T> FunctionOptimization.Companion.Optimization get() =
|
public fun <T> FunctionOptimization<T>.withAttributes(
|
||||||
|
modifier: AttributesBuilder<FunctionOptimization<T>>.() -> Unit,
|
||||||
|
|
||||||
public fun <T> FunctionOptimization<T>.withFeatures(
|
|
||||||
vararg newFeature: OptimizationFeature,
|
|
||||||
): FunctionOptimization<T> = FunctionOptimization(
|
): FunctionOptimization<T> = FunctionOptimization(
|
||||||
attributes.with(*newFeature),
|
|
||||||
expression,
|
expression,
|
||||||
|
attributes.modify(modifier),
|
||||||
)
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optimizes differentiable expression using specific [optimizer] form given [startingPoint].
|
* Optimizes differentiable expression using specific [optimizer] form given [startingPoint].
|
||||||
*/
|
*/
|
||||||
public suspend fun <T : Any> DifferentiableExpression<T>.optimizeWith(
|
public suspend fun <T> DifferentiableExpression<T>.optimizeWith(
|
||||||
optimizer: Optimizer<T, FunctionOptimization<T>>,
|
optimizer: Optimizer<T, FunctionOptimization<T>>,
|
||||||
startingPoint: Map<Symbol, T>,
|
startingPoint: Map<Symbol, T>,
|
||||||
vararg features: OptimizationFeature,
|
modifier: AttributesBuilder<FunctionOptimization<T>>.() -> Unit = {},
|
||||||
): FunctionOptimization<T> {
|
): FunctionOptimization<T> {
|
||||||
val problem = FunctionOptimization<T>(FeatureSet.of(OptimizationStartPoint(startingPoint), *features), this)
|
val problem = FunctionOptimization(this){
|
||||||
|
startAt(startingPoint)
|
||||||
|
modifier()
|
||||||
|
}
|
||||||
return optimizer.optimize(problem)
|
return optimizer.optimize(problem)
|
||||||
}
|
}
|
||||||
|
|
||||||
public val <T> FunctionOptimization<T>.resultValueOrNull: T?
|
public val <T> FunctionOptimization<T>.resultValueOrNull: T?
|
||||||
get() = getFeature<OptimizationResult<T>>()?.point?.let { expression(it) }
|
get() = attributes[OptimizationResult<T>()]?.let { expression(it) }
|
||||||
|
|
||||||
public val <T> FunctionOptimization<T>.resultValue: T
|
public val <T> FunctionOptimization<T>.resultValue: T
|
||||||
get() = resultValueOrNull ?: error("Result is not present in $this")
|
get() = resultValueOrNull ?: error("Result is not present in $this")
|
||||||
|
|
||||||
|
|
||||||
|
public suspend fun <T> DifferentiableExpression<T>.optimizeWith(
|
||||||
|
optimizer: Optimizer<T, FunctionOptimization<T>>,
|
||||||
|
vararg startingPoint: Pair<Symbol, T>,
|
||||||
|
builder: AttributesBuilder<FunctionOptimization<T>>.() -> Unit = {},
|
||||||
|
): FunctionOptimization<T> {
|
||||||
|
val problem = FunctionOptimization<T>(this) {
|
||||||
|
startAt(mapOf(*startingPoint))
|
||||||
|
builder()
|
||||||
|
}
|
||||||
|
return optimizer.optimize(problem)
|
||||||
|
}
|
||||||
|
@ -1,96 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2018-2022 KMath contributors.
|
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package space.kscience.kmath.optimization
|
|
||||||
|
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.data.XYColumnarData
|
|
||||||
import space.kscience.kmath.expressions.DifferentiableExpression
|
|
||||||
import space.kscience.kmath.expressions.Symbol
|
|
||||||
import space.kscience.kmath.misc.FeatureSet
|
|
||||||
|
|
||||||
public abstract class OptimizationBuilder<T, R : OptimizationProblem<T>> {
|
|
||||||
public val features: MutableList<OptimizationFeature> = ArrayList()
|
|
||||||
|
|
||||||
public fun addFeature(feature: OptimizationFeature) {
|
|
||||||
features.add(feature)
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun <reified T : OptimizationFeature> updateFeature(update: (T?) -> T) {
|
|
||||||
val existing = features.find { it.key == T::class } as? T
|
|
||||||
val new = update(existing)
|
|
||||||
if (existing != null) {
|
|
||||||
features.remove(existing)
|
|
||||||
}
|
|
||||||
addFeature(new)
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract fun build(): R
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun <T> OptimizationBuilder<T, *>.startAt(startingPoint: Map<Symbol, T>) {
|
|
||||||
addFeature(OptimizationStartPoint(startingPoint))
|
|
||||||
}
|
|
||||||
|
|
||||||
public class FunctionOptimizationBuilder<T>(
|
|
||||||
private val expression: DifferentiableExpression<T>,
|
|
||||||
) : OptimizationBuilder<T, FunctionOptimization<T>>() {
|
|
||||||
override fun build(): FunctionOptimization<T> = FunctionOptimization(FeatureSet.of(features), expression)
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun <T> FunctionOptimization(
|
|
||||||
expression: DifferentiableExpression<T>,
|
|
||||||
builder: FunctionOptimizationBuilder<T>.() -> Unit,
|
|
||||||
): FunctionOptimization<T> = FunctionOptimizationBuilder(expression).apply(builder).build()
|
|
||||||
|
|
||||||
public suspend fun <T> DifferentiableExpression<T>.optimizeWith(
|
|
||||||
optimizer: Optimizer<T, FunctionOptimization<T>>,
|
|
||||||
startingPoint: Map<Symbol, T>,
|
|
||||||
builder: FunctionOptimizationBuilder<T>.() -> Unit = {},
|
|
||||||
): FunctionOptimization<T> {
|
|
||||||
val problem = FunctionOptimization<T>(this) {
|
|
||||||
startAt(startingPoint)
|
|
||||||
builder()
|
|
||||||
}
|
|
||||||
return optimizer.optimize(problem)
|
|
||||||
}
|
|
||||||
|
|
||||||
public suspend fun <T> DifferentiableExpression<T>.optimizeWith(
|
|
||||||
optimizer: Optimizer<T, FunctionOptimization<T>>,
|
|
||||||
vararg startingPoint: Pair<Symbol, T>,
|
|
||||||
builder: FunctionOptimizationBuilder<T>.() -> Unit = {},
|
|
||||||
): FunctionOptimization<T> {
|
|
||||||
val problem = FunctionOptimization<T>(this) {
|
|
||||||
startAt(mapOf(*startingPoint))
|
|
||||||
builder()
|
|
||||||
}
|
|
||||||
return optimizer.optimize(problem)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public class XYOptimizationBuilder(
|
|
||||||
public val data: XYColumnarData<Double, Double, Double>,
|
|
||||||
public val model: DifferentiableExpression<Double>,
|
|
||||||
) : OptimizationBuilder<Double, XYFit>() {
|
|
||||||
|
|
||||||
public var pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY
|
|
||||||
public var pointWeight: PointWeight = PointWeight.byYSigma
|
|
||||||
|
|
||||||
override fun build(): XYFit = XYFit(
|
|
||||||
data,
|
|
||||||
model,
|
|
||||||
FeatureSet.of(features),
|
|
||||||
pointToCurveDistance,
|
|
||||||
pointWeight
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public fun XYOptimization(
|
|
||||||
data: XYColumnarData<Double, Double, Double>,
|
|
||||||
model: DifferentiableExpression<Double>,
|
|
||||||
builder: XYOptimizationBuilder.() -> Unit,
|
|
||||||
): XYFit = XYOptimizationBuilder(data, model).apply(builder).build()
|
|
@ -6,64 +6,53 @@
|
|||||||
package space.kscience.kmath.optimization
|
package space.kscience.kmath.optimization
|
||||||
|
|
||||||
import space.kscience.attributes.*
|
import space.kscience.attributes.*
|
||||||
import space.kscience.kmath.expressions.DifferentiableExpression
|
|
||||||
import space.kscience.kmath.expressions.NamedMatrix
|
import space.kscience.kmath.expressions.NamedMatrix
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
import space.kscience.kmath.misc.*
|
import space.kscience.kmath.misc.Loggable
|
||||||
import kotlin.reflect.KClass
|
|
||||||
|
|
||||||
public interface OptimizationAttribute<T>: Attribute<T>
|
public interface OptimizationAttribute<T> : Attribute<T>
|
||||||
|
|
||||||
public interface OptimizationProblem<T> : AttributeContainer
|
public interface OptimizationProblem<T> : AttributeContainer, WithType<T>
|
||||||
|
|
||||||
public inline fun <reified F : OptimizationFeature> OptimizationProblem<*>.getFeature(): F? = getFeature(F::class)
|
public class OptimizationStartPoint<T> : OptimizationAttribute<Map<Symbol, T>>,
|
||||||
|
PolymorphicAttribute<Map<Symbol, T>>(safeTypeOf())
|
||||||
public open class OptimizationStartPoint<T>(public val point: Map<Symbol, T>) : OptimizationFeature {
|
|
||||||
override fun toString(): String = "StartPoint($point)"
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Covariance matrix for
|
|
||||||
*/
|
|
||||||
public class OptimizationCovariance<T>(public val covariance: NamedMatrix<T>) : OptimizationFeature {
|
|
||||||
override fun toString(): String = "Covariance($covariance)"
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the starting point for optimization. Throws error if not defined.
|
* Get the starting point for optimization. Throws error if not defined.
|
||||||
*/
|
*/
|
||||||
public val <T> OptimizationProblem<T>.startPoint: Map<Symbol, T>
|
public val <T> OptimizationProblem<T>.startPoint: Map<Symbol, T>
|
||||||
get() = getFeature<OptimizationStartPoint<T>>()?.point
|
get() = attributes[OptimizationStartPoint()] ?: error("Starting point not defined in $this")
|
||||||
?: error("Starting point not defined in $this")
|
|
||||||
|
|
||||||
public open class OptimizationResult<T>(public val point: Map<Symbol, T>) : OptimizationFeature {
|
public fun <T> AttributesBuilder<OptimizationProblem<T>>.startAt(startingPoint: Map<Symbol, T>) {
|
||||||
override fun toString(): String = "Result($point)"
|
set(::OptimizationStartPoint, startingPoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
public val <T> OptimizationProblem<T>.resultPointOrNull: Map<Symbol, T>?
|
|
||||||
get() = getFeature<OptimizationResult<T>>()?.point
|
|
||||||
|
|
||||||
public val <T> OptimizationProblem<T>.resultPoint: Map<Symbol, T>
|
/**
|
||||||
get() = resultPointOrNull ?: error("Result is not present in $this")
|
* Covariance matrix for optimization
|
||||||
|
*/
|
||||||
|
public class OptimizationCovariance<T> : OptimizationAttribute<NamedMatrix<T>>,
|
||||||
|
PolymorphicAttribute<NamedMatrix<T>>(safeTypeOf())
|
||||||
|
|
||||||
public class OptimizationLog(private val loggable: Loggable) : Loggable by loggable, OptimizationFeature {
|
|
||||||
override fun toString(): String = "Log($loggable)"
|
public class OptimizationResult<T>() : OptimizationAttribute<Map<Symbol, T>>,
|
||||||
}
|
PolymorphicAttribute<Map<Symbol, T>>(safeTypeOf())
|
||||||
|
|
||||||
|
public val <T> OptimizationProblem<T>.resultOrNull: Map<Symbol, T>? get() = attributes[OptimizationResult()]
|
||||||
|
|
||||||
|
public val <T> OptimizationProblem<T>.result: Map<Symbol, T>
|
||||||
|
get() = resultOrNull ?: error("Result is not present in $this")
|
||||||
|
|
||||||
|
public object OptimizationLog : OptimizationAttribute<Loggable>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Free parameters of the optimization
|
* Free parameters of the optimization
|
||||||
*/
|
*/
|
||||||
public class OptimizationParameters(public val symbols: List<Symbol>) : OptimizationFeature {
|
public object OptimizationParameters : OptimizationAttribute<List<Symbol>>
|
||||||
public constructor(vararg symbols: Symbol) : this(listOf(*symbols))
|
|
||||||
|
|
||||||
override fun toString(): String = "Parameters($symbols)"
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maximum allowed number of iterations
|
* Maximum allowed number of iterations
|
||||||
*/
|
*/
|
||||||
public class OptimizationIterations(public val maxIterations: Int) : OptimizationFeature {
|
public object OptimizationIterations : OptimizationAttribute<Int>
|
||||||
override fun toString(): String = "Iterations($maxIterations)"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,13 +16,7 @@ import space.kscience.kmath.structures.Float64Buffer
|
|||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
|
|
||||||
|
|
||||||
public class QowRuns(public val runs: Int) : OptimizationFeature {
|
public object QowRuns: OptimizationAttribute<Int>
|
||||||
init {
|
|
||||||
require(runs >= 1) { "Number of runs must be more than zero" }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun toString(): String = "QowRuns(runs=$runs)"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -69,7 +63,7 @@ public object QowOptimizer : Optimizer<Double, XYFit> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
val prior: DifferentiableExpression<Double>?
|
val prior: DifferentiableExpression<Double>?
|
||||||
get() = problem.getFeature<OptimizationPrior<Double>>()?.withDefaultArgs(allParameters)
|
get() = problem.attributes[OptimizationPrior<Double>()]?.withDefaultArgs(allParameters)
|
||||||
|
|
||||||
override fun toString(): String = freeParameters.toString()
|
override fun toString(): String = freeParameters.toString()
|
||||||
}
|
}
|
||||||
@ -176,7 +170,7 @@ public object QowOptimizer : Optimizer<Double, XYFit> {
|
|||||||
fast: Boolean = false,
|
fast: Boolean = false,
|
||||||
): QoWeight {
|
): QoWeight {
|
||||||
|
|
||||||
val logger = problem.getFeature<OptimizationLog>()
|
val logger = problem.attributes[OptimizationLog]
|
||||||
|
|
||||||
var dis: Double //discrepancy value
|
var dis: Double //discrepancy value
|
||||||
|
|
||||||
@ -231,7 +225,7 @@ public object QowOptimizer : Optimizer<Double, XYFit> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private fun QoWeight.covariance(): NamedMatrix<Double> {
|
private fun QoWeight.covariance(): NamedMatrix<Double> {
|
||||||
val logger = problem.getFeature<OptimizationLog>()
|
val logger = problem.attributes[OptimizationLog]
|
||||||
|
|
||||||
logger?.log {
|
logger?.log {
|
||||||
"""
|
"""
|
||||||
@ -257,11 +251,11 @@ public object QowOptimizer : Optimizer<Double, XYFit> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
override suspend fun optimize(problem: XYFit): XYFit {
|
override suspend fun optimize(problem: XYFit): XYFit {
|
||||||
val qowRuns = problem.getFeature<QowRuns>()?.runs ?: 2
|
val qowRuns = problem.attributes[QowRuns] ?: 2
|
||||||
val iterations = problem.getFeature<OptimizationIterations>()?.maxIterations ?: 50
|
val iterations = problem.attributes[OptimizationIterations] ?: 50
|
||||||
|
|
||||||
val freeParameters: Map<Symbol, Double> = problem.getFeature<OptimizationParameters>()?.let { op ->
|
val freeParameters: Map<Symbol, Double> = problem.attributes[OptimizationParameters]?.let { symbols ->
|
||||||
problem.startPoint.filterKeys { it in op.symbols }
|
problem.startPoint.filterKeys { it in symbols }
|
||||||
} ?: problem.startPoint
|
} ?: problem.startPoint
|
||||||
|
|
||||||
var qow = QoWeight(problem, freeParameters)
|
var qow = QoWeight(problem, freeParameters)
|
||||||
@ -271,6 +265,9 @@ public object QowOptimizer : Optimizer<Double, XYFit> {
|
|||||||
res = qow.newtonianRun(maxSteps = iterations)
|
res = qow.newtonianRun(maxSteps = iterations)
|
||||||
}
|
}
|
||||||
val covariance = res.covariance()
|
val covariance = res.covariance()
|
||||||
return res.problem.withFeature(OptimizationResult(res.freeParameters), OptimizationCovariance(covariance))
|
return res.problem.withAttributes {
|
||||||
|
set(OptimizationResult(), res.freeParameters)
|
||||||
|
set(OptimizationCovariance(), covariance)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -2,37 +2,41 @@
|
|||||||
* Copyright 2018-2022 KMath contributors.
|
* Copyright 2018-2022 KMath contributors.
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
*/
|
*/
|
||||||
@file:OptIn(UnstableKMathAPI::class)
|
@file:OptIn(UnstableKMathAPI::class, UnstableKMathAPI::class)
|
||||||
|
|
||||||
package space.kscience.kmath.optimization
|
package space.kscience.kmath.optimization
|
||||||
|
|
||||||
|
import space.kscience.attributes.*
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.data.XYColumnarData
|
import space.kscience.kmath.data.XYColumnarData
|
||||||
import space.kscience.kmath.data.indices
|
import space.kscience.kmath.data.indices
|
||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.misc.FeatureSet
|
|
||||||
import space.kscience.kmath.misc.Loggable
|
import space.kscience.kmath.misc.Loggable
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.ExtendedField
|
import space.kscience.kmath.operations.ExtendedField
|
||||||
|
import space.kscience.kmath.operations.Float64Field
|
||||||
import space.kscience.kmath.operations.bindSymbol
|
import space.kscience.kmath.operations.bindSymbol
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Specify the way to compute distance from point to the curve as DifferentiableExpression
|
* Specify the way to compute distance from point to the curve as DifferentiableExpression
|
||||||
*/
|
*/
|
||||||
public interface PointToCurveDistance : OptimizationFeature {
|
public interface PointToCurveDistance {
|
||||||
public fun distance(problem: XYFit, index: Int): DifferentiableExpression<Double>
|
public fun distance(problem: XYFit, index: Int): DifferentiableExpression<Double>
|
||||||
|
|
||||||
public companion object {
|
public companion object : OptimizationAttribute<PointToCurveDistance> {
|
||||||
public val byY: PointToCurveDistance = object : PointToCurveDistance {
|
public val byY: PointToCurveDistance = object : PointToCurveDistance {
|
||||||
override fun distance(problem: XYFit, index: Int): DifferentiableExpression<Double> {
|
override fun distance(problem: XYFit, index: Int): DifferentiableExpression<Double> {
|
||||||
val x = problem.data.x[index]
|
val x = problem.data.x[index]
|
||||||
val y = problem.data.y[index]
|
val y = problem.data.y[index]
|
||||||
|
|
||||||
return object : DifferentiableExpression<Double> {
|
return object : DifferentiableExpression<Double> {
|
||||||
|
override val type: SafeType<Double> get() = DoubleField.type
|
||||||
|
|
||||||
override fun derivativeOrNull(
|
override fun derivativeOrNull(
|
||||||
symbols: List<Symbol>,
|
symbols: List<Symbol>,
|
||||||
): Expression<Double>? = problem.model.derivativeOrNull(symbols)?.let { derivExpression ->
|
): Expression<Double>? = problem.model.derivativeOrNull(symbols)?.let { derivExpression ->
|
||||||
Expression { arguments ->
|
Expression(DoubleField.type) { arguments ->
|
||||||
derivExpression.invoke(arguments + (Symbol.x to x))
|
derivExpression.invoke(arguments + (Symbol.x to x))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -51,18 +55,21 @@ public interface PointToCurveDistance : OptimizationFeature {
|
|||||||
* Compute a wight of the point. The more the weight, the more impact this point will have on the fit.
|
* Compute a wight of the point. The more the weight, the more impact this point will have on the fit.
|
||||||
* By default, uses Dispersion^-1
|
* By default, uses Dispersion^-1
|
||||||
*/
|
*/
|
||||||
public interface PointWeight : OptimizationFeature {
|
public interface PointWeight {
|
||||||
public fun weight(problem: XYFit, index: Int): DifferentiableExpression<Double>
|
public fun weight(problem: XYFit, index: Int): DifferentiableExpression<Double>
|
||||||
|
|
||||||
public companion object {
|
public companion object : OptimizationAttribute<PointWeight> {
|
||||||
public fun bySigma(sigmaSymbol: Symbol): PointWeight = object : PointWeight {
|
public fun bySigma(sigmaSymbol: Symbol): PointWeight = object : PointWeight {
|
||||||
override fun weight(problem: XYFit, index: Int): DifferentiableExpression<Double> =
|
override fun weight(problem: XYFit, index: Int): DifferentiableExpression<Double> =
|
||||||
object : DifferentiableExpression<Double> {
|
object : DifferentiableExpression<Double> {
|
||||||
|
override val type: SafeType<Double> get() = DoubleField.type
|
||||||
|
|
||||||
override fun invoke(arguments: Map<Symbol, Double>): Double {
|
override fun invoke(arguments: Map<Symbol, Double>): Double {
|
||||||
return problem.data[sigmaSymbol]?.get(index)?.pow(-2) ?: 1.0
|
return problem.data[sigmaSymbol]?.get(index)?.pow(-2) ?: 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { 0.0 }
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> =
|
||||||
|
Expression(DoubleField.type) { 0.0 }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun toString(): String = "PointWeightBySigma($sigmaSymbol)"
|
override fun toString(): String = "PointWeightBySigma($sigmaSymbol)"
|
||||||
@ -79,41 +86,52 @@ public interface PointWeight : OptimizationFeature {
|
|||||||
public class XYFit(
|
public class XYFit(
|
||||||
public val data: XYColumnarData<Double, Double, Double>,
|
public val data: XYColumnarData<Double, Double, Double>,
|
||||||
public val model: DifferentiableExpression<Double>,
|
public val model: DifferentiableExpression<Double>,
|
||||||
override val attributes: FeatureSet<OptimizationFeature>,
|
override val attributes: Attributes,
|
||||||
internal val pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY,
|
internal val pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY,
|
||||||
internal val pointWeight: PointWeight = PointWeight.byYSigma,
|
internal val pointWeight: PointWeight = PointWeight.byYSigma,
|
||||||
public val xSymbol: Symbol = Symbol.x,
|
public val xSymbol: Symbol = Symbol.x,
|
||||||
) : OptimizationProblem<Double> {
|
) : OptimizationProblem<Double> {
|
||||||
|
|
||||||
|
override val type: SafeType<Double> get() = Float64Field.type
|
||||||
|
|
||||||
public fun distance(index: Int): DifferentiableExpression<Double> = pointToCurveDistance.distance(this, index)
|
public fun distance(index: Int): DifferentiableExpression<Double> = pointToCurveDistance.distance(this, index)
|
||||||
|
|
||||||
public fun weight(index: Int): DifferentiableExpression<Double> = pointWeight.weight(this, index)
|
public fun weight(index: Int): DifferentiableExpression<Double> = pointWeight.weight(this, index)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun XYFit.withFeature(vararg features: OptimizationFeature): XYFit {
|
|
||||||
return XYFit(data, model, this.attributes.with(*features), pointToCurveDistance, pointWeight)
|
public fun XYOptimization(
|
||||||
}
|
data: XYColumnarData<Double, Double, Double>,
|
||||||
|
model: DifferentiableExpression<Double>,
|
||||||
|
builder: AttributesBuilder<XYFit>.() -> Unit,
|
||||||
|
): XYFit = XYFit(data, model, Attributes(builder))
|
||||||
|
|
||||||
|
public fun XYFit.withAttributes(
|
||||||
|
modifier: AttributesBuilder<XYFit>.() -> Unit,
|
||||||
|
): XYFit = XYFit(data, model, attributes.modify(modifier), pointToCurveDistance, pointWeight, xSymbol)
|
||||||
|
|
||||||
public suspend fun XYColumnarData<Double, Double, Double>.fitWith(
|
public suspend fun XYColumnarData<Double, Double, Double>.fitWith(
|
||||||
optimizer: Optimizer<Double, XYFit>,
|
optimizer: Optimizer<Double, XYFit>,
|
||||||
modelExpression: DifferentiableExpression<Double>,
|
modelExpression: DifferentiableExpression<Double>,
|
||||||
startingPoint: Map<Symbol, Double>,
|
startingPoint: Map<Symbol, Double>,
|
||||||
vararg features: OptimizationFeature = emptyArray(),
|
attributes: Attributes = Attributes.EMPTY,
|
||||||
xSymbol: Symbol = Symbol.x,
|
xSymbol: Symbol = Symbol.x,
|
||||||
pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY,
|
pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY,
|
||||||
pointWeight: PointWeight = PointWeight.byYSigma,
|
pointWeight: PointWeight = PointWeight.byYSigma,
|
||||||
): XYFit {
|
): XYFit {
|
||||||
var actualFeatures = FeatureSet.of(*features, OptimizationStartPoint(startingPoint))
|
|
||||||
|
|
||||||
if (actualFeatures.getFeature<OptimizationLog>() == null) {
|
|
||||||
actualFeatures = actualFeatures.with(OptimizationLog(Loggable.console))
|
|
||||||
}
|
|
||||||
val problem = XYFit(
|
val problem = XYFit(
|
||||||
this,
|
this,
|
||||||
modelExpression,
|
modelExpression,
|
||||||
actualFeatures,
|
attributes.modify<XYFit> {
|
||||||
|
set(::OptimizationStartPoint, startingPoint)
|
||||||
|
if (!hasAny<OptimizationLog>()) {
|
||||||
|
set(OptimizationLog, Loggable.console)
|
||||||
|
}
|
||||||
|
},
|
||||||
pointToCurveDistance,
|
pointToCurveDistance,
|
||||||
pointWeight,
|
pointWeight,
|
||||||
xSymbol
|
xSymbol,
|
||||||
)
|
)
|
||||||
return optimizer.optimize(problem)
|
return optimizer.optimize(problem)
|
||||||
}
|
}
|
||||||
@ -125,7 +143,7 @@ public suspend fun <I : Any, A> XYColumnarData<Double, Double, Double>.fitWith(
|
|||||||
optimizer: Optimizer<Double, XYFit>,
|
optimizer: Optimizer<Double, XYFit>,
|
||||||
processor: AutoDiffProcessor<Double, I, A>,
|
processor: AutoDiffProcessor<Double, I, A>,
|
||||||
startingPoint: Map<Symbol, Double>,
|
startingPoint: Map<Symbol, Double>,
|
||||||
vararg features: OptimizationFeature = emptyArray(),
|
attributes: Attributes = Attributes.EMPTY,
|
||||||
xSymbol: Symbol = Symbol.x,
|
xSymbol: Symbol = Symbol.x,
|
||||||
pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY,
|
pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY,
|
||||||
pointWeight: PointWeight = PointWeight.byYSigma,
|
pointWeight: PointWeight = PointWeight.byYSigma,
|
||||||
@ -140,7 +158,7 @@ public suspend fun <I : Any, A> XYColumnarData<Double, Double, Double>.fitWith(
|
|||||||
optimizer = optimizer,
|
optimizer = optimizer,
|
||||||
modelExpression = modelExpression,
|
modelExpression = modelExpression,
|
||||||
startingPoint = startingPoint,
|
startingPoint = startingPoint,
|
||||||
features = features,
|
attributes = attributes,
|
||||||
xSymbol = xSymbol,
|
xSymbol = xSymbol,
|
||||||
pointToCurveDistance = pointToCurveDistance,
|
pointToCurveDistance = pointToCurveDistance,
|
||||||
pointWeight = pointWeight
|
pointWeight = pointWeight
|
||||||
@ -152,7 +170,7 @@ public suspend fun <I : Any, A> XYColumnarData<Double, Double, Double>.fitWith(
|
|||||||
*/
|
*/
|
||||||
public val XYFit.chiSquaredOrNull: Double?
|
public val XYFit.chiSquaredOrNull: Double?
|
||||||
get() {
|
get() {
|
||||||
val result = startPoint + (resultPointOrNull ?: return null)
|
val result = startPoint + (resultOrNull ?: return null)
|
||||||
|
|
||||||
return data.indices.sumOf { index ->
|
return data.indices.sumOf { index ->
|
||||||
|
|
||||||
@ -167,4 +185,4 @@ public val XYFit.chiSquaredOrNull: Double?
|
|||||||
}
|
}
|
||||||
|
|
||||||
public val XYFit.dof: Int
|
public val XYFit.dof: Int
|
||||||
get() = data.size - (getFeature<OptimizationParameters>()?.symbols?.size ?: startPoint.size)
|
get() = data.size - (attributes[OptimizationParameters]?.size ?: startPoint.size)
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.optimization
|
package space.kscience.kmath.optimization
|
||||||
|
|
||||||
|
import space.kscience.attributes.AttributesBuilder
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.data.XYColumnarData
|
import space.kscience.kmath.data.XYColumnarData
|
||||||
import space.kscience.kmath.data.indices
|
import space.kscience.kmath.data.indices
|
||||||
@ -12,6 +14,7 @@ import space.kscience.kmath.expressions.DifferentiableExpression
|
|||||||
import space.kscience.kmath.expressions.Expression
|
import space.kscience.kmath.expressions.Expression
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
import space.kscience.kmath.expressions.derivative
|
import space.kscience.kmath.expressions.derivative
|
||||||
|
import space.kscience.kmath.operations.Float64Field
|
||||||
import kotlin.math.PI
|
import kotlin.math.PI
|
||||||
import kotlin.math.ln
|
import kotlin.math.ln
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
@ -22,7 +25,9 @@ private val oneOver2Pi = 1.0 / sqrt(2 * PI)
|
|||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
internal fun XYFit.logLikelihood(): DifferentiableExpression<Double> = object : DifferentiableExpression<Double> {
|
internal fun XYFit.logLikelihood(): DifferentiableExpression<Double> = object : DifferentiableExpression<Double> {
|
||||||
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression { arguments ->
|
override val type: SafeType<Double> get() = Float64Field.type
|
||||||
|
|
||||||
|
override fun derivativeOrNull(symbols: List<Symbol>): Expression<Double> = Expression(type) { arguments ->
|
||||||
data.indices.sumOf { index ->
|
data.indices.sumOf { index ->
|
||||||
val d = distance(index)(arguments)
|
val d = distance(index)(arguments)
|
||||||
val weight = weight(index)(arguments)
|
val weight = weight(index)(arguments)
|
||||||
@ -53,14 +58,18 @@ internal fun XYFit.logLikelihood(): DifferentiableExpression<Double> = object :
|
|||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public suspend fun Optimizer<Double, FunctionOptimization<Double>>.maximumLogLikelihood(problem: XYFit): XYFit {
|
public suspend fun Optimizer<Double, FunctionOptimization<Double>>.maximumLogLikelihood(problem: XYFit): XYFit {
|
||||||
val functionOptimization = FunctionOptimization(problem.attributes, problem.logLikelihood())
|
val functionOptimization = FunctionOptimization(problem.logLikelihood(), problem.attributes)
|
||||||
val result = optimize(functionOptimization.withFeatures(FunctionOptimizationTarget.MAXIMIZE))
|
val result = optimize(
|
||||||
return XYFit(problem.data, problem.model, result.attributes)
|
functionOptimization.withAttributes {
|
||||||
|
FunctionOptimizationTarget(OptimizationDirection.MAXIMIZE)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return XYFit(problem.data,problem.model, result.attributes)
|
||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public suspend fun Optimizer<Double, FunctionOptimization<Double>>.maximumLogLikelihood(
|
public suspend fun Optimizer<Double, FunctionOptimization<Double>>.maximumLogLikelihood(
|
||||||
data: XYColumnarData<Double, Double, Double>,
|
data: XYColumnarData<Double, Double, Double>,
|
||||||
model: DifferentiableExpression<Double>,
|
model: DifferentiableExpression<Double>,
|
||||||
builder: XYOptimizationBuilder.() -> Unit,
|
builder: AttributesBuilder<XYFit>.() -> Unit,
|
||||||
): XYFit = maximumLogLikelihood(XYOptimization(data, model, builder))
|
): XYFit = maximumLogLikelihood(XYOptimization(data, model, builder))
|
||||||
|
@ -5,10 +5,12 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.PerformancePitfall
|
import space.kscience.kmath.PerformancePitfall
|
||||||
import space.kscience.kmath.UnstableKMathAPI
|
import space.kscience.kmath.UnstableKMathAPI
|
||||||
import space.kscience.kmath.nd.MutableStructureNDOfDouble
|
import space.kscience.kmath.nd.MutableStructureNDOfDouble
|
||||||
import space.kscience.kmath.nd.ShapeND
|
import space.kscience.kmath.nd.ShapeND
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.*
|
||||||
import space.kscience.kmath.tensors.core.internal.toPrettyString
|
import space.kscience.kmath.tensors.core.internal.toPrettyString
|
||||||
|
|
||||||
@ -88,6 +90,8 @@ public open class DoubleTensor(
|
|||||||
final override val source: OffsetDoubleBuffer,
|
final override val source: OffsetDoubleBuffer,
|
||||||
) : BufferedTensor<Double>(shape), MutableStructureNDOfDouble {
|
) : BufferedTensor<Double>(shape), MutableStructureNDOfDouble {
|
||||||
|
|
||||||
|
override val type: SafeType<Double> get() = DoubleField.type
|
||||||
|
|
||||||
init {
|
init {
|
||||||
require(linearSize == source.size) { "Source buffer size must be equal tensor size" }
|
require(linearSize == source.size) { "Source buffer size must be equal tensor size" }
|
||||||
}
|
}
|
||||||
|
@ -5,8 +5,10 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
|
import space.kscience.attributes.SafeType
|
||||||
import space.kscience.kmath.PerformancePitfall
|
import space.kscience.kmath.PerformancePitfall
|
||||||
import space.kscience.kmath.nd.ShapeND
|
import space.kscience.kmath.nd.ShapeND
|
||||||
|
import space.kscience.kmath.operations.IntRing
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -24,6 +26,8 @@ public class OffsetIntBuffer(
|
|||||||
require(offset + size <= source.size) { "Maximum index must be inside source dimension" }
|
require(offset + size <= source.size) { "Maximum index must be inside source dimension" }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override val type: SafeType<Int> get() = IntRing.type
|
||||||
|
|
||||||
override fun set(index: Int, value: Int) {
|
override fun set(index: Int, value: Int) {
|
||||||
require(index in 0 until size) { "Index must be in [0, size)" }
|
require(index in 0 until size) { "Index must be in [0, size)" }
|
||||||
source[index + offset] = value
|
source[index + offset] = value
|
||||||
@ -83,6 +87,8 @@ public class IntTensor(
|
|||||||
require(linearSize == source.size) { "Source buffer size must be equal tensor size" }
|
require(linearSize == source.size) { "Source buffer size must be equal tensor size" }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override val type: SafeType<Int> get() = IntRing.type
|
||||||
|
|
||||||
public constructor(shape: ShapeND, buffer: Int32Buffer) : this(shape, OffsetIntBuffer(buffer, 0, buffer.size))
|
public constructor(shape: ShapeND, buffer: Int32Buffer) : this(shape, OffsetIntBuffer(buffer, 0, buffer.size))
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
|
Loading…
Reference in New Issue
Block a user