[WIP] refactor features to attributes

This commit is contained in:
Alexander Nozik 2023-09-13 09:00:56 +03:00
parent 9da14089e0
commit dd3d38490a
13 changed files with 57 additions and 42 deletions

View File

@ -7,8 +7,10 @@ package space.kscience.attributes
/** /**
* A safe builder for [Attributes] * A safe builder for [Attributes]
*
* @param O type marker of an owner object, for which these attributes are made
*/ */
public class AttributesBuilder internal constructor(private val map: MutableMap<Attribute<*>, Any>) { public class TypedAttributesBuilder<in O> internal constructor(private val map: MutableMap<Attribute<*>, Any>) {
public constructor() : this(mutableMapOf()) public constructor() : this(mutableMapOf())
@ -47,6 +49,8 @@ public class AttributesBuilder internal constructor(private val map: MutableMap<
public fun build(): Attributes = Attributes(map) public fun build(): Attributes = Attributes(map)
} }
public typealias AttributesBuilder = TypedAttributesBuilder<Any?>
public fun AttributesBuilder( public fun AttributesBuilder(
attributes: Attributes, attributes: Attributes,
): AttributesBuilder = AttributesBuilder(attributes.content.toMutableMap()) ): AttributesBuilder = AttributesBuilder(attributes.content.toMutableMap())

View File

@ -6,7 +6,7 @@ kotlin.code.style=official
kotlin.mpp.stability.nowarn=true kotlin.mpp.stability.nowarn=true
kotlin.native.ignoreDisabledTargets=true kotlin.native.ignoreDisabledTargets=true
toolsVersion=0.14.9-kotlin-1.8.20 toolsVersion=0.14.9-kotlin-1.9.0
org.gradle.parallel=true org.gradle.parallel=true
org.gradle.workers.max=4 org.gradle.workers.max=4

View File

@ -1,5 +1,5 @@
distributionBase=GRADLE_USER_HOME distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip distributionUrl=https\://services.gradle.org/distributions/gradle-8.3-bin.zip
zipStoreBase=GRADLE_USER_HOME zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists zipStorePath=wrapper/dists

View File

@ -22,7 +22,7 @@ public class CMGaussRuleIntegrator(
val integrator: GaussIntegrator = getIntegrator(range) val integrator: GaussIntegrator = getIntegrator(range)
//TODO check performance //TODO check performance
val res: Double = integrator.integrate(integrand.function) val res: Double = integrator.integrate(integrand.function)
return integrand.modify { return integrand.withAttributes {
IntegrandValue(res) IntegrandValue(res)
IntegrandCallsPerformed(integrand.calls + numpoints) IntegrandCallsPerformed(integrand.calls + numpoints)
} }

View File

@ -26,7 +26,7 @@ public class CMIntegrator(
val range = integrand[IntegrationRange] ?: error("Integration range is not provided") val range = integrand[IntegrationRange] ?: error("Integration range is not provided")
val res = integrator.integrate(remainingCalls, integrand.function, range.start, range.endInclusive) val res = integrator.integrate(remainingCalls, integrand.function, range.start, range.endInclusive)
return integrand.modify { return integrand.withAttributes {
value(res) value(res)
IntegrandAbsoluteAccuracy(integrator.absoluteAccuracy) IntegrandAbsoluteAccuracy(integrator.absoluteAccuracy)
IntegrandRelativeAccuracy(integrator.relativeAccuracy) IntegrandRelativeAccuracy(integrator.relativeAccuracy)

View File

@ -4,7 +4,7 @@
*/ */
package space.kscience.kmath.integration package space.kscience.kmath.integration
import space.kscience.attributes.AttributesBuilder import space.kscience.attributes.TypedAttributesBuilder
import space.kscience.kmath.UnstableKMathAPI import space.kscience.kmath.UnstableKMathAPI
import space.kscience.kmath.operations.Field import space.kscience.kmath.operations.Field
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
@ -57,7 +57,7 @@ public class GaussIntegrator<T : Any>(
override fun process(integrand: UnivariateIntegrand<T>): UnivariateIntegrand<T> = with(algebra) { override fun process(integrand: UnivariateIntegrand<T>): UnivariateIntegrand<T> = with(algebra) {
val f = integrand.function val f = integrand.function
val (points, weights) = buildRule(integrand) val (points, weights) = buildRule(integrand)
var res = zero var res: T = zero
var c = zero var c = zero
for (i in points.indices) { for (i in points.indices) {
val x = points[i] val x = points[i]
@ -67,7 +67,7 @@ public class GaussIntegrator<T : Any>(
c = t - res - y c = t - res - y
res = t res = t
} }
return integrand.modify { return integrand.withAttributes {
value(res) value(res)
IntegrandCallsPerformed(integrand.calls + points.size) IntegrandCallsPerformed(integrand.calls + points.size)
} }
@ -88,12 +88,12 @@ public val <T : Any> Field<T>.gaussIntegrator: GaussIntegrator<T> get() = GaussI
* Integrate using [intervals] segments with Gauss-Legendre rule of [order] order. * Integrate using [intervals] segments with Gauss-Legendre rule of [order] order.
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun <T : Any> GaussIntegrator<T>.integrate( public inline fun <reified T : Any> GaussIntegrator<T>.integrate(
range: ClosedRange<Double>, range: ClosedRange<Double>,
order: Int = 10, order: Int = 10,
intervals: Int = 10, intervals: Int = 10,
attributesBuilder: AttributesBuilder.() -> Unit, attributesBuilder: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
function: (Double) -> T, noinline function: (Double) -> T,
): UnivariateIntegrand<T> { ): UnivariateIntegrand<T> {
require(range.endInclusive > range.start) { "The range upper bound should be higher than lower bound" } require(range.endInclusive > range.start) { "The range upper bound should be higher than lower bound" }
require(order > 1) { "The order of polynomial must be more than 1" } require(order > 1) { "The order of polynomial must be more than 1" }
@ -103,7 +103,7 @@ public fun <T : Any> GaussIntegrator<T>.integrate(
(0 until intervals).map { i -> (range.start + rangeSize * i)..(range.start + rangeSize * (i + 1)) to order } (0 until intervals).map { i -> (range.start + rangeSize * i)..(range.start + rangeSize * (i + 1)) to order }
) )
return process( return process(
UnivariateIntegrand( UnivariateIntegrand<T>(
attributeBuilder = { attributeBuilder = {
IntegrationRange(range) IntegrationRange(range)
GaussIntegratorRuleFactory(GaussLegendreRuleFactory) GaussIntegratorRuleFactory(GaussLegendreRuleFactory)

View File

@ -5,10 +5,7 @@
package space.kscience.kmath.integration package space.kscience.kmath.integration
import space.kscience.attributes.Attribute import space.kscience.attributes.*
import space.kscience.attributes.AttributeContainer
import space.kscience.attributes.AttributesBuilder
import space.kscience.attributes.SafeType
public interface IntegrandAttribute<T> : Attribute<T> public interface IntegrandAttribute<T> : Attribute<T>
@ -16,9 +13,10 @@ public interface Integrand<T> : AttributeContainer {
public val type: SafeType<T> public val type: SafeType<T>
public fun modify(block: AttributesBuilder.() -> Unit): Integrand<T> /**
* Create a copy of this integrand with a new set of attributes
public fun <A : Any> withAttribute(attribute: Attribute<A>, value: A): Integrand<T> */
public fun withAttributes(attributes: Attributes): Integrand<T>
public companion object public companion object
} }
@ -32,7 +30,7 @@ public sealed class IntegrandValue<T> private constructor(): IntegrandAttribute<
} }
} }
public fun <T> AttributesBuilder.value(value: T) { public fun <T> TypedAttributesBuilder<Integrand<T>>.value(value: T) {
IntegrandValue.forType<T>().invoke(value) IntegrandValue.forType<T>().invoke(value)
} }

View File

@ -14,14 +14,21 @@ public class MultivariateIntegrand<T>(
public val function: (Point<T>) -> T, public val function: (Point<T>) -> T,
) : Integrand<T> { ) : Integrand<T> {
override fun modify(block: AttributesBuilder.() -> Unit): MultivariateIntegrand<T> = override fun withAttributes(attributes: Attributes): MultivariateIntegrand<T> =
MultivariateIntegrand(type, attributes.modify(block), function) MultivariateIntegrand(type, attributes, function)
override fun <A : Any> withAttribute(attribute: Attribute<A>, value: A): MultivariateIntegrand<T> =
MultivariateIntegrand(type, attributes.withAttribute(attribute, value), function)
} }
public fun <T, A : Any> MultivariateIntegrand<T>.withAttribute(
attribute: Attribute<A>,
value: A,
): MultivariateIntegrand<T> = withAttributes(attributes.withAttribute(attribute, value))
public fun <T> MultivariateIntegrand<T>.withAttributes(
block: TypedAttributesBuilder<MultivariateIntegrand<T>>.() -> Unit,
): MultivariateIntegrand<T> = withAttributes(attributes.modify(block))
public inline fun <reified T : Any> MultivariateIntegrand( public inline fun <reified T : Any> MultivariateIntegrand(
attributeBuilder: AttributesBuilder.() -> Unit, attributeBuilder: TypedAttributesBuilder<MultivariateIntegrand<T>>.() -> Unit,
noinline function: (Point<T>) -> T, noinline function: (Point<T>) -> T,
): MultivariateIntegrand<T> = MultivariateIntegrand(safeTypeOf<T>(), Attributes(attributeBuilder), function) ): MultivariateIntegrand<T> = MultivariateIntegrand(safeTypeOf<T>(), Attributes(attributeBuilder), function)

View File

@ -48,7 +48,7 @@ public class SimpsonIntegrator<T : Any>(
val ranges = integrand[UnivariateIntegrandRanges] val ranges = integrand[UnivariateIntegrandRanges]
return if (ranges != null) { return if (ranges != null) {
val res = algebra.sum(ranges.ranges.map { integrateRange(integrand, it.first, it.second) }) val res = algebra.sum(ranges.ranges.map { integrateRange(integrand, it.first, it.second) })
integrand.modify { integrand.withAttributes {
value(res) value(res)
IntegrandCallsPerformed(integrand.calls + ranges.ranges.sumOf { it.second }) IntegrandCallsPerformed(integrand.calls + ranges.ranges.sumOf { it.second })
} }
@ -57,7 +57,7 @@ public class SimpsonIntegrator<T : Any>(
require(numPoints >= 4) { "Simpson integrator requires at least 4 nodes" } require(numPoints >= 4) { "Simpson integrator requires at least 4 nodes" }
val range = integrand[IntegrationRange] ?: 0.0..1.0 val range = integrand[IntegrationRange] ?: 0.0..1.0
val res = integrateRange(integrand, range, numPoints) val res = integrateRange(integrand, range, numPoints)
integrand.modify { integrand.withAttributes {
value(res) value(res)
IntegrandCallsPerformed(integrand.calls + numPoints) IntegrandCallsPerformed(integrand.calls + numPoints)
} }
@ -100,7 +100,7 @@ public object DoubleSimpsonIntegrator : UnivariateIntegrator<Double> {
val ranges = integrand[UnivariateIntegrandRanges] val ranges = integrand[UnivariateIntegrandRanges]
return if (ranges != null) { return if (ranges != null) {
val res = ranges.ranges.sumOf { integrateRange(integrand, it.first, it.second) } val res = ranges.ranges.sumOf { integrateRange(integrand, it.first, it.second) }
integrand.modify { integrand.withAttributes {
value(res) value(res)
IntegrandCallsPerformed(integrand.calls + ranges.ranges.sumOf { it.second }) IntegrandCallsPerformed(integrand.calls + ranges.ranges.sumOf { it.second })
} }
@ -109,7 +109,7 @@ public object DoubleSimpsonIntegrator : UnivariateIntegrator<Double> {
require(numPoints >= 4) { "Simpson integrator requires at least 4 nodes" } require(numPoints >= 4) { "Simpson integrator requires at least 4 nodes" }
val range = integrand[IntegrationRange] ?: 0.0..1.0 val range = integrand[IntegrationRange] ?: 0.0..1.0
val res = integrateRange(integrand, range, numPoints) val res = integrateRange(integrand, range, numPoints)
integrand.modify { integrand.withAttributes {
value(res) value(res)
IntegrandCallsPerformed(integrand.calls + numPoints) IntegrandCallsPerformed(integrand.calls + numPoints)
} }

View File

@ -71,7 +71,7 @@ public class SplineIntegrator<T : Comparable<T>>(
values values
) )
val res = polynomials.integrate(algebra, number(range.start)..number(range.endInclusive)) val res = polynomials.integrate(algebra, number(range.start)..number(range.endInclusive))
integrand.modify { integrand.withAttributes {
value(res) value(res)
IntegrandCallsPerformed(integrand.calls + nodes.size) IntegrandCallsPerformed(integrand.calls + nodes.size)
} }
@ -99,7 +99,7 @@ public object DoubleSplineIntegrator : UnivariateIntegrator<Double> {
val values = nodes.mapToBuffer(::Float64Buffer) { integrand.function(it) } val values = nodes.mapToBuffer(::Float64Buffer) { integrand.function(it) }
val polynomials = interpolator.interpolatePolynomials(nodes, values) val polynomials = interpolator.interpolatePolynomials(nodes, values)
val res = polynomials.integrate(Float64Field, range) val res = polynomials.integrate(Float64Field, range)
return integrand.modify { return integrand.withAttributes {
value(res) value(res)
IntegrandCallsPerformed(integrand.calls + nodes.size) IntegrandCallsPerformed(integrand.calls + nodes.size)
} }

View File

@ -16,15 +16,21 @@ public class UnivariateIntegrand<T>(
public val function: (Double) -> T, public val function: (Double) -> T,
) : Integrand<T> { ) : Integrand<T> {
override fun <A : Any> withAttribute(attribute: Attribute<A>, value: A): UnivariateIntegrand<T> = override fun withAttributes(attributes: Attributes): UnivariateIntegrand<T> =
UnivariateIntegrand(type, attributes.withAttribute(attribute, value), function) UnivariateIntegrand(type, attributes, function)
override fun modify(block: AttributesBuilder.() -> Unit): UnivariateIntegrand<T> =
UnivariateIntegrand(type, attributes.modify(block), function)
} }
public fun <T, A : Any> UnivariateIntegrand<T>.withAttribute(
attribute: Attribute<A>,
value: A,
): UnivariateIntegrand<T> = withAttributes(attributes.withAttribute(attribute, value))
public fun <T> UnivariateIntegrand<T>.withAttributes(
block: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
): UnivariateIntegrand<T> = withAttributes(attributes.modify(block))
public inline fun <reified T : Any> UnivariateIntegrand( public inline fun <reified T : Any> UnivariateIntegrand(
attributeBuilder: AttributesBuilder.() -> Unit, attributeBuilder: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
noinline function: (Double) -> T, noinline function: (Double) -> T,
): UnivariateIntegrand<T> = UnivariateIntegrand(safeTypeOf(), Attributes(attributeBuilder), function) ): UnivariateIntegrand<T> = UnivariateIntegrand(safeTypeOf(), Attributes(attributeBuilder), function)
@ -52,7 +58,7 @@ public class UnivariateIntegrandRanges(public val ranges: List<Pair<ClosedRange<
public object UnivariateIntegrationNodes : IntegrandAttribute<Buffer<Double>> public object UnivariateIntegrationNodes : IntegrandAttribute<Buffer<Double>>
public fun AttributesBuilder.integrationNodes(vararg nodes: Double) { public fun TypedAttributesBuilder<UnivariateIntegrand<*>>.integrationNodes(vararg nodes: Double) {
UnivariateIntegrationNodes(Float64Buffer(nodes)) UnivariateIntegrationNodes(Float64Buffer(nodes))
} }
@ -62,7 +68,7 @@ public fun AttributesBuilder.integrationNodes(vararg nodes: Double) {
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate( public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate(
attributesBuilder: AttributesBuilder.() -> Unit, attributesBuilder: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit,
noinline function: (Double) -> T, noinline function: (Double) -> T,
): UnivariateIntegrand<T> = process(UnivariateIntegrand(attributesBuilder, function)) ): UnivariateIntegrand<T> = process(UnivariateIntegrand(attributesBuilder, function))
@ -73,7 +79,7 @@ public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate(
@UnstableKMathAPI @UnstableKMathAPI
public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate( public inline fun <reified T : Any> UnivariateIntegrator<T>.integrate(
range: ClosedRange<Double>, range: ClosedRange<Double>,
attributeBuilder: AttributesBuilder.() -> Unit = {}, attributeBuilder: TypedAttributesBuilder<UnivariateIntegrand<T>>.() -> Unit = {},
noinline function: (Double) -> T, noinline function: (Double) -> T,
): UnivariateIntegrand<T> { ): UnivariateIntegrand<T> {