Update integration API
This commit is contained in:
parent
65a8d8f581
commit
ef1200aad0
@ -1,7 +1,8 @@
|
||||
package space.kscience.kmath.functions
|
||||
|
||||
import space.kscience.kmath.integration.GaussIntegrator
|
||||
import space.kscience.kmath.integration.integrate
|
||||
import space.kscience.kmath.integration.value
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import kotlin.math.pow
|
||||
|
||||
fun main() {
|
||||
@ -9,7 +10,7 @@ fun main() {
|
||||
val function: UnivariateFunction<Double> = { x -> 3 * x.pow(2) + 2 * x + 1 }
|
||||
|
||||
//get the result of the integration
|
||||
val result = GaussIntegrator.legendre(0.0..10.0, function = function)
|
||||
val result = DoubleField.integrate(0.0..10.0, function = function)
|
||||
|
||||
//the value is nullable because in some cases the integration could not succeed
|
||||
println(result.value)
|
||||
|
@ -1,7 +1,6 @@
|
||||
package space.kscience.kmath.functions
|
||||
|
||||
import space.kscience.kmath.integration.GaussIntegrator
|
||||
import space.kscience.kmath.integration.UnivariateIntegrand
|
||||
import space.kscience.kmath.integration.integrate
|
||||
import space.kscience.kmath.integration.value
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.nd
|
||||
@ -10,11 +9,17 @@ import space.kscience.kmath.operations.invoke
|
||||
|
||||
fun main(): Unit = DoubleField {
|
||||
nd(2, 2) {
|
||||
|
||||
//Produce a diagonal StructureND
|
||||
fun diagonal(v: Double) = produce { (i, j) ->
|
||||
if (i == j) v else 0.0
|
||||
}
|
||||
|
||||
//Define a function in a nd space
|
||||
val function: UnivariateFunction<StructureND<Double>> = { x -> 3 * x.pow(2) + 2 * x + 1 }
|
||||
val function: (Double) -> StructureND<Double> = { x: Double -> 3 * number(x).pow(2) + 2 * diagonal(x) + 1 }
|
||||
|
||||
//get the result of the integration
|
||||
val result: UnivariateIntegrand<StructureND<Double>> = GaussIntegrator.legendre(this, 0.0..10.0, function = function)
|
||||
val result = integrate(0.0..10.0, function = function)
|
||||
|
||||
//the value is nullable because in some cases the integration could not succeed
|
||||
println(result.value)
|
||||
|
@ -5,7 +5,6 @@
|
||||
package space.kscience.kmath.integration
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.Field
|
||||
import space.kscience.kmath.structures.*
|
||||
|
||||
@ -15,12 +14,10 @@ import space.kscience.kmath.structures.*
|
||||
*/
|
||||
public class GaussIntegrator<T : Any>(
|
||||
public val algebra: Field<T>,
|
||||
public val bufferFactory: BufferFactory<T>,
|
||||
) : UnivariateIntegrator<T> {
|
||||
|
||||
private fun buildRule(integrand: UnivariateIntegrand<T>): Pair<Buffer<T>, Buffer<T>> {
|
||||
val factory = integrand.getFeature<GaussIntegratorRuleFactory<T>>()
|
||||
?: GenericGaussLegendreRuleFactory(algebra, bufferFactory)
|
||||
private fun buildRule(integrand: UnivariateIntegrand<T>): Pair<Buffer<Double>, Buffer<Double>> {
|
||||
val factory = integrand.getFeature<GaussIntegratorRuleFactory>() ?: GaussLegendreRuleFactory
|
||||
val numPoints = integrand.getFeature<IntegrandMaxCalls>()?.maxCalls ?: 100
|
||||
val range = integrand.getFeature<IntegrationRange<Double>>()?.range ?: 0.0..1.0
|
||||
return factory.build(numPoints, range)
|
||||
@ -32,9 +29,9 @@ public class GaussIntegrator<T : Any>(
|
||||
var res = zero
|
||||
var c = zero
|
||||
for (i in points.indices) {
|
||||
val x: T = points[i]
|
||||
val w: T = weights[i]
|
||||
val y: T = w * f(x) - c
|
||||
val x = points[i]
|
||||
val weight = weights[i]
|
||||
val y: T = weight * f(x) - c
|
||||
val t = res + y
|
||||
c = t - res - y
|
||||
res = t
|
||||
@ -44,68 +41,38 @@ public class GaussIntegrator<T : Any>(
|
||||
|
||||
public companion object {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Integrate [T]-valued univariate function using provided set of [IntegrandFeature]
|
||||
* Following features are evaluated:
|
||||
* * [GaussIntegratorRuleFactory] - A factory for computing the Gauss integration rule. By default uses [GenericGaussLegendreRuleFactory]
|
||||
* * [GaussIntegratorRuleFactory] - A factory for computing the Gauss integration rule. By default uses [GaussLegendreRuleFactory]
|
||||
* * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval.
|
||||
* * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 100 points.
|
||||
*/
|
||||
public fun <T : Any> integrate(
|
||||
algebra: Field<T>,
|
||||
bufferFactory: BufferFactory<T> = ::ListBuffer,
|
||||
@UnstableKMathAPI
|
||||
public fun <T : Any> Field<T>.integrate(
|
||||
vararg features: IntegrandFeature,
|
||||
function: (T) -> T,
|
||||
): UnivariateIntegrand<T> =
|
||||
GaussIntegrator(algebra, bufferFactory).integrate(UnivariateIntegrand(function, *features))
|
||||
function: (Double) -> T,
|
||||
): UnivariateIntegrand<T> = GaussIntegrator(this).integrate(UnivariateIntegrand(function, *features))
|
||||
|
||||
|
||||
/**
|
||||
* Integrate in real numbers
|
||||
*/
|
||||
public fun integrate(
|
||||
vararg features: IntegrandFeature,
|
||||
function: (Double) -> Double,
|
||||
): UnivariateIntegrand<Double> = integrate(DoubleField, ::DoubleBuffer, features = features, function)
|
||||
|
||||
/**
|
||||
* Integrate given [function] in a [range] with Gauss-Legendre quadrature with [numPoints] points.
|
||||
* The [range] is automatically transformed into [T] using [Field.number]
|
||||
* Use [GaussIntegrator.Companion.integrate] to integrate the function in the current algebra with given [range] and [numPoints]
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun <T : Any> legendre(
|
||||
algebra: Field<T>,
|
||||
public fun <T : Any> Field<T>.integrate(
|
||||
range: ClosedRange<Double>,
|
||||
numPoints: Int = 100,
|
||||
bufferFactory: BufferFactory<T> = ::ListBuffer,
|
||||
vararg features: IntegrandFeature,
|
||||
function: (T) -> T,
|
||||
): UnivariateIntegrand<T> = GaussIntegrator(algebra, bufferFactory).integrate(
|
||||
function: (Double) -> T,
|
||||
): UnivariateIntegrand<T> = GaussIntegrator(this).integrate(
|
||||
UnivariateIntegrand(
|
||||
function,
|
||||
IntegrationRange(range),
|
||||
DoubleGaussLegendreRuleFactory,
|
||||
GaussLegendreRuleFactory,
|
||||
IntegrandMaxCalls(numPoints),
|
||||
*features
|
||||
)
|
||||
)
|
||||
|
||||
/**
|
||||
* Integrate given [function] in a [range] with Gauss-Legendre quadrature with [numPoints] points.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun legendre(
|
||||
range: ClosedRange<Double>,
|
||||
numPoints: Int = 100,
|
||||
vararg features: IntegrandFeature,
|
||||
function: (Double) -> Double,
|
||||
): UnivariateIntegrand<Double> = GaussIntegrator(DoubleField, ::DoubleBuffer).integrate(
|
||||
UnivariateIntegrand(
|
||||
function,
|
||||
IntegrationRange(range),
|
||||
DoubleGaussLegendreRuleFactory,
|
||||
IntegrandMaxCalls(numPoints),
|
||||
*features
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
@ -5,22 +5,20 @@
|
||||
|
||||
package space.kscience.kmath.integration
|
||||
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.Field
|
||||
import space.kscience.kmath.structures.*
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
import space.kscience.kmath.structures.asBuffer
|
||||
import space.kscience.kmath.structures.map
|
||||
import kotlin.jvm.Synchronized
|
||||
import kotlin.math.ulp
|
||||
import kotlin.native.concurrent.ThreadLocal
|
||||
|
||||
public interface GaussIntegratorRuleFactory<T : Any> : IntegrandFeature {
|
||||
public val algebra: Field<T>
|
||||
public val bufferFactory: BufferFactory<T>
|
||||
|
||||
public fun build(numPoints: Int): Pair<Buffer<T>, Buffer<T>>
|
||||
public interface GaussIntegratorRuleFactory : IntegrandFeature {
|
||||
public fun build(numPoints: Int): Pair<Buffer<Double>, Buffer<Double>>
|
||||
|
||||
public companion object {
|
||||
public fun double(numPoints: Int, range: ClosedRange<Double>): Pair<Buffer<Double>, Buffer<Double>> =
|
||||
DoubleGaussLegendreRuleFactory.build(numPoints, range)
|
||||
GaussLegendreRuleFactory.build(numPoints, range)
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,24 +26,23 @@ public interface GaussIntegratorRuleFactory<T : Any> : IntegrandFeature {
|
||||
* Create an integration rule by scaling existing normalized rule
|
||||
*
|
||||
*/
|
||||
public fun <T : Any> GaussIntegratorRuleFactory<T>.build(
|
||||
public fun GaussIntegratorRuleFactory.build(
|
||||
numPoints: Int,
|
||||
range: ClosedRange<Double>,
|
||||
): Pair<Buffer<T>, Buffer<T>> {
|
||||
): Pair<Buffer<Double>, Buffer<Double>> {
|
||||
val normalized = build(numPoints)
|
||||
with(algebra) {
|
||||
val length = range.endInclusive - range.start
|
||||
|
||||
val points = normalized.first.map(bufferFactory) {
|
||||
number(range.start + length / 2) + number(length / 2) * it
|
||||
val points = normalized.first.map(::DoubleBuffer) {
|
||||
range.start + length / 2 + length / 2 * it
|
||||
}
|
||||
|
||||
val weights = normalized.second.map(bufferFactory) {
|
||||
val weights = normalized.second.map(::DoubleBuffer) {
|
||||
it * length / 2
|
||||
}
|
||||
|
||||
return points to weights
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -56,11 +53,7 @@ public fun <T : Any> GaussIntegratorRuleFactory<T>.build(
|
||||
*
|
||||
*/
|
||||
@ThreadLocal
|
||||
public object DoubleGaussLegendreRuleFactory : GaussIntegratorRuleFactory<Double> {
|
||||
|
||||
override val algebra: Field<Double> get() = DoubleField
|
||||
|
||||
override val bufferFactory: BufferFactory<Double> get() = ::DoubleBuffer
|
||||
public object GaussLegendreRuleFactory : GaussIntegratorRuleFactory {
|
||||
|
||||
private val cache = HashMap<Int, Pair<Buffer<Double>, Buffer<Double>>>()
|
||||
|
||||
@ -172,21 +165,3 @@ public object DoubleGaussLegendreRuleFactory : GaussIntegratorRuleFactory<Double
|
||||
|
||||
override fun build(numPoints: Int): Pair<Buffer<Double>, Buffer<Double>> = getOrBuildRule(numPoints)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A generic Gauss-Legendre rule factory that wraps [DoubleGaussLegendreRuleFactory] in a generic way.
|
||||
*/
|
||||
public class GenericGaussLegendreRuleFactory<T : Any>(
|
||||
override val algebra: Field<T>,
|
||||
override val bufferFactory: BufferFactory<T>,
|
||||
) : GaussIntegratorRuleFactory<T> {
|
||||
|
||||
override fun build(numPoints: Int): Pair<Buffer<T>, Buffer<T>> {
|
||||
val (doublePoints, doubleWeights) = DoubleGaussLegendreRuleFactory.build(numPoints)
|
||||
|
||||
val points = doublePoints.map(bufferFactory) { algebra.number(it) }
|
||||
val weights = doubleWeights.map(bufferFactory) { algebra.number(it) }
|
||||
return points to weights
|
||||
}
|
||||
}
|
||||
|
@ -5,14 +5,13 @@
|
||||
|
||||
package space.kscience.kmath.integration
|
||||
|
||||
import space.kscience.kmath.functions.UnivariateFunction
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import kotlin.jvm.JvmInline
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
public class UnivariateIntegrand<T : Any> internal constructor(
|
||||
private val features: Map<KClass<*>, IntegrandFeature>,
|
||||
public val function: UnivariateFunction<T>,
|
||||
public val function: (Double) -> T,
|
||||
) : Integrand {
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
@ -27,7 +26,7 @@ public class UnivariateIntegrand<T : Any> internal constructor(
|
||||
|
||||
@Suppress("FunctionName")
|
||||
public fun <T : Any> UnivariateIntegrand(
|
||||
function: (T) -> T,
|
||||
function: (Double) -> T,
|
||||
vararg features: IntegrandFeature,
|
||||
): UnivariateIntegrand<T> = UnivariateIntegrand(features.associateBy { it::class }, function)
|
||||
|
||||
|
@ -5,15 +5,18 @@
|
||||
|
||||
package space.kscience.kmath.integration
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import kotlin.math.PI
|
||||
import kotlin.math.sin
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
class GaussIntegralTest {
|
||||
@Test
|
||||
fun gaussSin() {
|
||||
val res = GaussIntegrator.legendre(0.0..2 * PI) { x ->
|
||||
val res = DoubleField.integrate(0.0..2 * PI) { x ->
|
||||
sin(x)
|
||||
}
|
||||
assertEquals(0.0, res.value!!, 1e-4)
|
||||
@ -21,7 +24,7 @@ class GaussIntegralTest {
|
||||
|
||||
@Test
|
||||
fun gaussUniform() {
|
||||
val res = GaussIntegrator.legendre(0.0..100.0,300) { x ->
|
||||
val res = DoubleField.integrate(0.0..100.0,300) { x ->
|
||||
if(x in 30.0..50.0){
|
||||
1.0
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user