Update integration API

This commit is contained in:
Alexander Nozik 2021-04-16 16:39:27 +03:00
parent 65a8d8f581
commit ef1200aad0
6 changed files with 79 additions and 129 deletions

View File

@ -1,7 +1,8 @@
package space.kscience.kmath.functions package space.kscience.kmath.functions
import space.kscience.kmath.integration.GaussIntegrator import space.kscience.kmath.integration.integrate
import space.kscience.kmath.integration.value import space.kscience.kmath.integration.value
import space.kscience.kmath.operations.DoubleField
import kotlin.math.pow import kotlin.math.pow
fun main() { fun main() {
@ -9,7 +10,7 @@ fun main() {
val function: UnivariateFunction<Double> = { x -> 3 * x.pow(2) + 2 * x + 1 } val function: UnivariateFunction<Double> = { x -> 3 * x.pow(2) + 2 * x + 1 }
//get the result of the integration //get the result of the integration
val result = GaussIntegrator.legendre(0.0..10.0, function = function) val result = DoubleField.integrate(0.0..10.0, function = function)
//the value is nullable because in some cases the integration could not succeed //the value is nullable because in some cases the integration could not succeed
println(result.value) println(result.value)

View File

@ -1,7 +1,6 @@
package space.kscience.kmath.functions package space.kscience.kmath.functions
import space.kscience.kmath.integration.GaussIntegrator import space.kscience.kmath.integration.integrate
import space.kscience.kmath.integration.UnivariateIntegrand
import space.kscience.kmath.integration.value import space.kscience.kmath.integration.value
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.nd import space.kscience.kmath.nd.nd
@ -10,11 +9,17 @@ import space.kscience.kmath.operations.invoke
fun main(): Unit = DoubleField { fun main(): Unit = DoubleField {
nd(2, 2) { nd(2, 2) {
//Produce a diagonal StructureND
fun diagonal(v: Double) = produce { (i, j) ->
if (i == j) v else 0.0
}
//Define a function in a nd space //Define a function in a nd space
val function: UnivariateFunction<StructureND<Double>> = { x -> 3 * x.pow(2) + 2 * x + 1 } val function: (Double) -> StructureND<Double> = { x: Double -> 3 * number(x).pow(2) + 2 * diagonal(x) + 1 }
//get the result of the integration //get the result of the integration
val result: UnivariateIntegrand<StructureND<Double>> = GaussIntegrator.legendre(this, 0.0..10.0, function = function) val result = integrate(0.0..10.0, function = function)
//the value is nullable because in some cases the integration could not succeed //the value is nullable because in some cases the integration could not succeed
println(result.value) println(result.value)

View File

@ -5,7 +5,6 @@
package space.kscience.kmath.integration package space.kscience.kmath.integration
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Field import space.kscience.kmath.operations.Field
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.*
@ -15,12 +14,10 @@ import space.kscience.kmath.structures.*
*/ */
public class GaussIntegrator<T : Any>( public class GaussIntegrator<T : Any>(
public val algebra: Field<T>, public val algebra: Field<T>,
public val bufferFactory: BufferFactory<T>,
) : UnivariateIntegrator<T> { ) : UnivariateIntegrator<T> {
private fun buildRule(integrand: UnivariateIntegrand<T>): Pair<Buffer<T>, Buffer<T>> { private fun buildRule(integrand: UnivariateIntegrand<T>): Pair<Buffer<Double>, Buffer<Double>> {
val factory = integrand.getFeature<GaussIntegratorRuleFactory<T>>() val factory = integrand.getFeature<GaussIntegratorRuleFactory>() ?: GaussLegendreRuleFactory
?: GenericGaussLegendreRuleFactory(algebra, bufferFactory)
val numPoints = integrand.getFeature<IntegrandMaxCalls>()?.maxCalls ?: 100 val numPoints = integrand.getFeature<IntegrandMaxCalls>()?.maxCalls ?: 100
val range = integrand.getFeature<IntegrationRange<Double>>()?.range ?: 0.0..1.0 val range = integrand.getFeature<IntegrationRange<Double>>()?.range ?: 0.0..1.0
return factory.build(numPoints, range) return factory.build(numPoints, range)
@ -32,9 +29,9 @@ public class GaussIntegrator<T : Any>(
var res = zero var res = zero
var c = zero var c = zero
for (i in points.indices) { for (i in points.indices) {
val x: T = points[i] val x = points[i]
val w: T = weights[i] val weight = weights[i]
val y: T = w * f(x) - c val y: T = weight * f(x) - c
val t = res + y val t = res + y
c = t - res - y c = t - res - y
res = t res = t
@ -44,68 +41,38 @@ public class GaussIntegrator<T : Any>(
public companion object { public companion object {
}
}
/** /**
* Integrate [T]-valued univariate function using provided set of [IntegrandFeature] * Integrate [T]-valued univariate function using provided set of [IntegrandFeature]
* Following features are evaluated: * Following features are evaluated:
* * [GaussIntegratorRuleFactory] - A factory for computing the Gauss integration rule. By default uses [GenericGaussLegendreRuleFactory] * * [GaussIntegratorRuleFactory] - A factory for computing the Gauss integration rule. By default uses [GaussLegendreRuleFactory]
* * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval. * * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval.
* * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 100 points. * * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 100 points.
*/ */
public fun <T : Any> integrate( @UnstableKMathAPI
algebra: Field<T>, public fun <T : Any> Field<T>.integrate(
bufferFactory: BufferFactory<T> = ::ListBuffer,
vararg features: IntegrandFeature, vararg features: IntegrandFeature,
function: (T) -> T, function: (Double) -> T,
): UnivariateIntegrand<T> = ): UnivariateIntegrand<T> = GaussIntegrator(this).integrate(UnivariateIntegrand(function, *features))
GaussIntegrator(algebra, bufferFactory).integrate(UnivariateIntegrand(function, *features))
/** /**
* Integrate in real numbers * Use [GaussIntegrator.Companion.integrate] to integrate the function in the current algebra with given [range] and [numPoints]
*/
public fun integrate(
vararg features: IntegrandFeature,
function: (Double) -> Double,
): UnivariateIntegrand<Double> = integrate(DoubleField, ::DoubleBuffer, features = features, function)
/**
* Integrate given [function] in a [range] with Gauss-Legendre quadrature with [numPoints] points.
* The [range] is automatically transformed into [T] using [Field.number]
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public fun <T : Any> legendre( public fun <T : Any> Field<T>.integrate(
algebra: Field<T>,
range: ClosedRange<Double>, range: ClosedRange<Double>,
numPoints: Int = 100, numPoints: Int = 100,
bufferFactory: BufferFactory<T> = ::ListBuffer,
vararg features: IntegrandFeature, vararg features: IntegrandFeature,
function: (T) -> T, function: (Double) -> T,
): UnivariateIntegrand<T> = GaussIntegrator(algebra, bufferFactory).integrate( ): UnivariateIntegrand<T> = GaussIntegrator(this).integrate(
UnivariateIntegrand( UnivariateIntegrand(
function, function,
IntegrationRange(range), IntegrationRange(range),
DoubleGaussLegendreRuleFactory, GaussLegendreRuleFactory,
IntegrandMaxCalls(numPoints), IntegrandMaxCalls(numPoints),
*features *features
) )
) )
/**
* Integrate given [function] in a [range] with Gauss-Legendre quadrature with [numPoints] points.
*/
@UnstableKMathAPI
public fun legendre(
range: ClosedRange<Double>,
numPoints: Int = 100,
vararg features: IntegrandFeature,
function: (Double) -> Double,
): UnivariateIntegrand<Double> = GaussIntegrator(DoubleField, ::DoubleBuffer).integrate(
UnivariateIntegrand(
function,
IntegrationRange(range),
DoubleGaussLegendreRuleFactory,
IntegrandMaxCalls(numPoints),
*features
)
)
}
}

View File

@ -5,22 +5,20 @@
package space.kscience.kmath.integration package space.kscience.kmath.integration
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.operations.Field import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.asBuffer
import space.kscience.kmath.structures.map
import kotlin.jvm.Synchronized import kotlin.jvm.Synchronized
import kotlin.math.ulp import kotlin.math.ulp
import kotlin.native.concurrent.ThreadLocal import kotlin.native.concurrent.ThreadLocal
public interface GaussIntegratorRuleFactory<T : Any> : IntegrandFeature { public interface GaussIntegratorRuleFactory : IntegrandFeature {
public val algebra: Field<T> public fun build(numPoints: Int): Pair<Buffer<Double>, Buffer<Double>>
public val bufferFactory: BufferFactory<T>
public fun build(numPoints: Int): Pair<Buffer<T>, Buffer<T>>
public companion object { public companion object {
public fun double(numPoints: Int, range: ClosedRange<Double>): Pair<Buffer<Double>, Buffer<Double>> = public fun double(numPoints: Int, range: ClosedRange<Double>): Pair<Buffer<Double>, Buffer<Double>> =
DoubleGaussLegendreRuleFactory.build(numPoints, range) GaussLegendreRuleFactory.build(numPoints, range)
} }
} }
@ -28,24 +26,23 @@ public interface GaussIntegratorRuleFactory<T : Any> : IntegrandFeature {
* Create an integration rule by scaling existing normalized rule * Create an integration rule by scaling existing normalized rule
* *
*/ */
public fun <T : Any> GaussIntegratorRuleFactory<T>.build( public fun GaussIntegratorRuleFactory.build(
numPoints: Int, numPoints: Int,
range: ClosedRange<Double>, range: ClosedRange<Double>,
): Pair<Buffer<T>, Buffer<T>> { ): Pair<Buffer<Double>, Buffer<Double>> {
val normalized = build(numPoints) val normalized = build(numPoints)
with(algebra) {
val length = range.endInclusive - range.start val length = range.endInclusive - range.start
val points = normalized.first.map(bufferFactory) { val points = normalized.first.map(::DoubleBuffer) {
number(range.start + length / 2) + number(length / 2) * it range.start + length / 2 + length / 2 * it
} }
val weights = normalized.second.map(bufferFactory) { val weights = normalized.second.map(::DoubleBuffer) {
it * length / 2 it * length / 2
} }
return points to weights return points to weights
}
} }
@ -56,11 +53,7 @@ public fun <T : Any> GaussIntegratorRuleFactory<T>.build(
* *
*/ */
@ThreadLocal @ThreadLocal
public object DoubleGaussLegendreRuleFactory : GaussIntegratorRuleFactory<Double> { public object GaussLegendreRuleFactory : GaussIntegratorRuleFactory {
override val algebra: Field<Double> get() = DoubleField
override val bufferFactory: BufferFactory<Double> get() = ::DoubleBuffer
private val cache = HashMap<Int, Pair<Buffer<Double>, Buffer<Double>>>() private val cache = HashMap<Int, Pair<Buffer<Double>, Buffer<Double>>>()
@ -172,21 +165,3 @@ public object DoubleGaussLegendreRuleFactory : GaussIntegratorRuleFactory<Double
override fun build(numPoints: Int): Pair<Buffer<Double>, Buffer<Double>> = getOrBuildRule(numPoints) override fun build(numPoints: Int): Pair<Buffer<Double>, Buffer<Double>> = getOrBuildRule(numPoints)
} }
/**
* A generic Gauss-Legendre rule factory that wraps [DoubleGaussLegendreRuleFactory] in a generic way.
*/
public class GenericGaussLegendreRuleFactory<T : Any>(
override val algebra: Field<T>,
override val bufferFactory: BufferFactory<T>,
) : GaussIntegratorRuleFactory<T> {
override fun build(numPoints: Int): Pair<Buffer<T>, Buffer<T>> {
val (doublePoints, doubleWeights) = DoubleGaussLegendreRuleFactory.build(numPoints)
val points = doublePoints.map(bufferFactory) { algebra.number(it) }
val weights = doubleWeights.map(bufferFactory) { algebra.number(it) }
return points to weights
}
}

View File

@ -5,14 +5,13 @@
package space.kscience.kmath.integration package space.kscience.kmath.integration
import space.kscience.kmath.functions.UnivariateFunction
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import kotlin.jvm.JvmInline import kotlin.jvm.JvmInline
import kotlin.reflect.KClass import kotlin.reflect.KClass
public class UnivariateIntegrand<T : Any> internal constructor( public class UnivariateIntegrand<T : Any> internal constructor(
private val features: Map<KClass<*>, IntegrandFeature>, private val features: Map<KClass<*>, IntegrandFeature>,
public val function: UnivariateFunction<T>, public val function: (Double) -> T,
) : Integrand { ) : Integrand {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
@ -27,7 +26,7 @@ public class UnivariateIntegrand<T : Any> internal constructor(
@Suppress("FunctionName") @Suppress("FunctionName")
public fun <T : Any> UnivariateIntegrand( public fun <T : Any> UnivariateIntegrand(
function: (T) -> T, function: (Double) -> T,
vararg features: IntegrandFeature, vararg features: IntegrandFeature,
): UnivariateIntegrand<T> = UnivariateIntegrand(features.associateBy { it::class }, function) ): UnivariateIntegrand<T> = UnivariateIntegrand(features.associateBy { it::class }, function)

View File

@ -5,15 +5,18 @@
package space.kscience.kmath.integration package space.kscience.kmath.integration
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField
import kotlin.math.PI import kotlin.math.PI
import kotlin.math.sin import kotlin.math.sin
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@OptIn(UnstableKMathAPI::class)
class GaussIntegralTest { class GaussIntegralTest {
@Test @Test
fun gaussSin() { fun gaussSin() {
val res = GaussIntegrator.legendre(0.0..2 * PI) { x -> val res = DoubleField.integrate(0.0..2 * PI) { x ->
sin(x) sin(x)
} }
assertEquals(0.0, res.value!!, 1e-4) assertEquals(0.0, res.value!!, 1e-4)
@ -21,7 +24,7 @@ class GaussIntegralTest {
@Test @Test
fun gaussUniform() { fun gaussUniform() {
val res = GaussIntegrator.legendre(0.0..100.0,300) { x -> val res = DoubleField.integrate(0.0..100.0,300) { x ->
if(x in 30.0..50.0){ if(x in 30.0..50.0){
1.0 1.0
} else { } else {