Refactor integrator API.

This commit is contained in:
Alexander Nozik 2021-05-16 13:59:37 +03:00
parent 4964fb2642
commit 6f39b38a72
6 changed files with 28 additions and 34 deletions

View File

@ -5,8 +5,6 @@
package space.kscience.kmath.functions package space.kscience.kmath.functions
import space.kscience.kmath.integration.integrate
import space.kscience.kmath.integration.value
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import kotlin.math.pow import kotlin.math.pow
@ -18,5 +16,5 @@ fun main() {
val result = DoubleField.integrate(0.0..10.0, function = function) val result = DoubleField.integrate(0.0..10.0, function = function)
//the value is nullable because in some cases the integration could not succeed //the value is nullable because in some cases the integration could not succeed
println(result.value) println(result.valueOrNull)
} }

View File

@ -5,8 +5,6 @@
package space.kscience.kmath.functions package space.kscience.kmath.functions
import space.kscience.kmath.integration.integrate
import space.kscience.kmath.integration.value
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.nd import space.kscience.kmath.nd.nd
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
@ -27,6 +25,6 @@ fun main(): Unit = DoubleField {
val result = integrate(0.0..10.0, function = function) val result = integrate(0.0..10.0, function = function)
//the value is nullable because in some cases the integration could not succeed //the value is nullable because in some cases the integration could not succeed
println(result.value) println(result.valueOrNull)
} }
} }

View File

@ -78,6 +78,6 @@ public class GaussRuleIntegrator(
function: (Double) -> Double, function: (Double) -> Double,
): Double = GaussRuleIntegrator(numPoints, type).integrate( ): Double = GaussRuleIntegrator(numPoints, type).integrate(
UnivariateIntegrand(function, IntegrationRange(range)) UnivariateIntegrand(function, IntegrationRange(range))
).value!! ).valueOrNull!!
} }
} }

View File

@ -66,31 +66,25 @@ public class GaussIntegrator<T : Any>(
return integrand + IntegrandValue(res) + IntegrandCallsPerformed(integrand.calls + points.size) return integrand + IntegrandValue(res) + IntegrandCallsPerformed(integrand.calls + points.size)
} }
public companion object { public companion object
}
} }
/** /**
* Integrate [T]-valued univariate function using provided set of [IntegrandFeature] * Create a Gauss-Legendre integrator for this field
* Following features are evaluated: * Following integrand features are accepted:
* * [GaussIntegratorRuleFactory] - A factory for computing the Gauss integration rule. By default uses [GaussLegendreRuleFactory] * * [GaussIntegratorRuleFactory] - A factory for computing the Gauss integration rule. By default uses [GaussLegendreRuleFactory]
* * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval. * * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval.
* * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 10 points. * * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 10 points.
* * [UnivariateIntegrandRanges] - Set of ranges and number of points per range. Defaults to given [IntegrationRange] and [IntegrandMaxCalls] * * [UnivariateIntegrandRanges] - Set of ranges and number of points per range. Defaults to given [IntegrationRange] and [IntegrandMaxCalls]
*/ */
@UnstableKMathAPI public val <T:Any> Field<T>.integrator: GaussIntegrator<T> get() = GaussIntegrator(this)
public fun <T : Any> Field<T>.integrate(
vararg features: IntegrandFeature,
function: (Double) -> T,
): UnivariateIntegrand<T> = GaussIntegrator(this).integrate(UnivariateIntegrand(function, *features))
/** /**
* Use [GaussIntegrator.Companion.integrate] to integrate the function in the current algebra with given [range] and [numPoints] * Use [integrate] to integrate the function in the current algebra with given [range] and [numPoints]
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun <T : Any> Field<T>.integrate( public fun <T : Any> GaussIntegrator<T>.integrate(
range: ClosedRange<Double>, range: ClosedRange<Double>,
order: Int = 10, order: Int = 10,
intervals: Int = 10, intervals: Int = 10,
@ -104,7 +98,7 @@ public fun <T : Any> Field<T>.integrate(
val ranges = UnivariateIntegrandRanges( val ranges = UnivariateIntegrandRanges(
(0 until intervals).map { i -> (rangeSize * i)..(rangeSize * (i + 1)) to order } (0 until intervals).map { i -> (rangeSize * i)..(rangeSize * (i + 1)) to order }
) )
return GaussIntegrator(this).integrate( return integrate(
UnivariateIntegrand( UnivariateIntegrand(
function, function,
IntegrationRange(range), IntegrationRange(range),

View File

@ -35,37 +35,41 @@ public typealias UnivariateIntegrator<T> = Integrator<UnivariateIntegrand<T>>
@JvmInline @JvmInline
public value class IntegrationRange(public val range: ClosedRange<Double>) : IntegrandFeature public value class IntegrationRange(public val range: ClosedRange<Double>) : IntegrandFeature
public val <T : Any> UnivariateIntegrand<T>.value: T? get() = getFeature<IntegrandValue<T>>()?.value /**
* Value of the integrand if it is present or null
*/
public val <T : Any> UnivariateIntegrand<T>.valueOrNull: T? get() = getFeature<IntegrandValue<T>>()?.value
/**
* Value of the integrand or error
*/
public val <T : Any> UnivariateIntegrand<T>.value: T get() = valueOrNull ?: error("No value in the integrand")
/** /**
* A shortcut method to integrate a [function] in [range] with additional [features]. * A shortcut method to integrate a [function] in [range] with additional [features].
* The [function] is placed in the end position to allow passing a lambda. * The [function] is placed in the end position to allow passing a lambda.
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun UnivariateIntegrator<Double>.integrate( public fun <T: Any> UnivariateIntegrator<T>.integrate(
range: ClosedRange<Double>, range: ClosedRange<Double>,
vararg features: IntegrandFeature, vararg features: IntegrandFeature,
function: (Double) -> Double, function: (Double) -> T,
): Double = integrate( ): UnivariateIntegrand<T> = integrate(UnivariateIntegrand(function, IntegrationRange(range), *features))
UnivariateIntegrand(function, IntegrationRange(range), *features)
).value ?: error("Unexpected: no value after integration.")
/** /**
* A shortcut method to integrate a [function] in [range] with additional [features]. * A shortcut method to integrate a [function] in [range] with additional [features].
* The [function] is placed in the end position to allow passing a lambda. * The [function] is placed in the end position to allow passing a lambda.
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun UnivariateIntegrator<Double>.integrate( public fun <T: Any> UnivariateIntegrator<T>.integrate(
range: ClosedRange<Double>, range: ClosedRange<Double>,
featureBuilder: MutableList<IntegrandFeature>.() -> Unit = {}, featureBuilder: MutableList<IntegrandFeature>.() -> Unit = {},
function: (Double) -> Double, function: (Double) -> T,
): Double { ): UnivariateIntegrand<T> {
//TODO use dedicated feature builder class instead or add extensions to MutableList<IntegrandFeature> //TODO use dedicated feature builder class instead or add extensions to MutableList<IntegrandFeature>
val features = buildList { val features = buildList {
featureBuilder() featureBuilder()
add(IntegrationRange(range)) add(IntegrationRange(range))
} }
return integrate( return integrate(UnivariateIntegrand(function, *features.toTypedArray()))
UnivariateIntegrand(function, *features.toTypedArray())
).value ?: error("Unexpected: no value after integration.")
} }

View File

@ -19,7 +19,7 @@ class GaussIntegralTest {
val res = DoubleField.integrate(0.0..2 * PI) { x -> val res = DoubleField.integrate(0.0..2 * PI) { x ->
sin(x) sin(x)
} }
assertEquals(0.0, res.value!!, 1e-2) assertEquals(0.0, res.valueOrNull!!, 1e-2)
} }
@Test @Test
@ -31,7 +31,7 @@ class GaussIntegralTest {
0.0 0.0
} }
} }
assertEquals(20.0, res.value!!, 0.5) assertEquals(20.0, res.valueOrNull!!, 0.5)
} }