forked from kscience/kmath
Update integration API
This commit is contained in:
parent
65a8d8f581
commit
ef1200aad0
@ -1,7 +1,8 @@
|
|||||||
package space.kscience.kmath.functions
|
package space.kscience.kmath.functions
|
||||||
|
|
||||||
import space.kscience.kmath.integration.GaussIntegrator
|
import space.kscience.kmath.integration.integrate
|
||||||
import space.kscience.kmath.integration.value
|
import space.kscience.kmath.integration.value
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
@ -9,7 +10,7 @@ fun main() {
|
|||||||
val function: UnivariateFunction<Double> = { x -> 3 * x.pow(2) + 2 * x + 1 }
|
val function: UnivariateFunction<Double> = { x -> 3 * x.pow(2) + 2 * x + 1 }
|
||||||
|
|
||||||
//get the result of the integration
|
//get the result of the integration
|
||||||
val result = GaussIntegrator.legendre(0.0..10.0, function = function)
|
val result = DoubleField.integrate(0.0..10.0, function = function)
|
||||||
|
|
||||||
//the value is nullable because in some cases the integration could not succeed
|
//the value is nullable because in some cases the integration could not succeed
|
||||||
println(result.value)
|
println(result.value)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package space.kscience.kmath.functions
|
package space.kscience.kmath.functions
|
||||||
|
|
||||||
import space.kscience.kmath.integration.GaussIntegrator
|
import space.kscience.kmath.integration.integrate
|
||||||
import space.kscience.kmath.integration.UnivariateIntegrand
|
|
||||||
import space.kscience.kmath.integration.value
|
import space.kscience.kmath.integration.value
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.nd.nd
|
import space.kscience.kmath.nd.nd
|
||||||
@ -10,11 +9,17 @@ import space.kscience.kmath.operations.invoke
|
|||||||
|
|
||||||
fun main(): Unit = DoubleField {
|
fun main(): Unit = DoubleField {
|
||||||
nd(2, 2) {
|
nd(2, 2) {
|
||||||
|
|
||||||
|
//Produce a diagonal StructureND
|
||||||
|
fun diagonal(v: Double) = produce { (i, j) ->
|
||||||
|
if (i == j) v else 0.0
|
||||||
|
}
|
||||||
|
|
||||||
//Define a function in a nd space
|
//Define a function in a nd space
|
||||||
val function: UnivariateFunction<StructureND<Double>> = { x -> 3 * x.pow(2) + 2 * x + 1 }
|
val function: (Double) -> StructureND<Double> = { x: Double -> 3 * number(x).pow(2) + 2 * diagonal(x) + 1 }
|
||||||
|
|
||||||
//get the result of the integration
|
//get the result of the integration
|
||||||
val result: UnivariateIntegrand<StructureND<Double>> = GaussIntegrator.legendre(this, 0.0..10.0, function = function)
|
val result = integrate(0.0..10.0, function = function)
|
||||||
|
|
||||||
//the value is nullable because in some cases the integration could not succeed
|
//the value is nullable because in some cases the integration could not succeed
|
||||||
println(result.value)
|
println(result.value)
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
package space.kscience.kmath.integration
|
package space.kscience.kmath.integration
|
||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.DoubleField
|
|
||||||
import space.kscience.kmath.operations.Field
|
import space.kscience.kmath.operations.Field
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.*
|
||||||
|
|
||||||
@ -15,12 +14,10 @@ import space.kscience.kmath.structures.*
|
|||||||
*/
|
*/
|
||||||
public class GaussIntegrator<T : Any>(
|
public class GaussIntegrator<T : Any>(
|
||||||
public val algebra: Field<T>,
|
public val algebra: Field<T>,
|
||||||
public val bufferFactory: BufferFactory<T>,
|
|
||||||
) : UnivariateIntegrator<T> {
|
) : UnivariateIntegrator<T> {
|
||||||
|
|
||||||
private fun buildRule(integrand: UnivariateIntegrand<T>): Pair<Buffer<T>, Buffer<T>> {
|
private fun buildRule(integrand: UnivariateIntegrand<T>): Pair<Buffer<Double>, Buffer<Double>> {
|
||||||
val factory = integrand.getFeature<GaussIntegratorRuleFactory<T>>()
|
val factory = integrand.getFeature<GaussIntegratorRuleFactory>() ?: GaussLegendreRuleFactory
|
||||||
?: GenericGaussLegendreRuleFactory(algebra, bufferFactory)
|
|
||||||
val numPoints = integrand.getFeature<IntegrandMaxCalls>()?.maxCalls ?: 100
|
val numPoints = integrand.getFeature<IntegrandMaxCalls>()?.maxCalls ?: 100
|
||||||
val range = integrand.getFeature<IntegrationRange<Double>>()?.range ?: 0.0..1.0
|
val range = integrand.getFeature<IntegrationRange<Double>>()?.range ?: 0.0..1.0
|
||||||
return factory.build(numPoints, range)
|
return factory.build(numPoints, range)
|
||||||
@ -32,9 +29,9 @@ public class GaussIntegrator<T : Any>(
|
|||||||
var res = zero
|
var res = zero
|
||||||
var c = zero
|
var c = zero
|
||||||
for (i in points.indices) {
|
for (i in points.indices) {
|
||||||
val x: T = points[i]
|
val x = points[i]
|
||||||
val w: T = weights[i]
|
val weight = weights[i]
|
||||||
val y: T = w * f(x) - c
|
val y: T = weight * f(x) - c
|
||||||
val t = res + y
|
val t = res + y
|
||||||
c = t - res - y
|
c = t - res - y
|
||||||
res = t
|
res = t
|
||||||
@ -44,68 +41,38 @@ public class GaussIntegrator<T : Any>(
|
|||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
|
|
||||||
/**
|
|
||||||
* Integrate [T]-valued univariate function using provided set of [IntegrandFeature]
|
|
||||||
* Following features are evaluated:
|
|
||||||
* * [GaussIntegratorRuleFactory] - A factory for computing the Gauss integration rule. By default uses [GenericGaussLegendreRuleFactory]
|
|
||||||
* * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval.
|
|
||||||
* * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 100 points.
|
|
||||||
*/
|
|
||||||
public fun <T : Any> integrate(
|
|
||||||
algebra: Field<T>,
|
|
||||||
bufferFactory: BufferFactory<T> = ::ListBuffer,
|
|
||||||
vararg features: IntegrandFeature,
|
|
||||||
function: (T) -> T,
|
|
||||||
): UnivariateIntegrand<T> =
|
|
||||||
GaussIntegrator(algebra, bufferFactory).integrate(UnivariateIntegrand(function, *features))
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Integrate in real numbers
|
|
||||||
*/
|
|
||||||
public fun integrate(
|
|
||||||
vararg features: IntegrandFeature,
|
|
||||||
function: (Double) -> Double,
|
|
||||||
): UnivariateIntegrand<Double> = integrate(DoubleField, ::DoubleBuffer, features = features, function)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Integrate given [function] in a [range] with Gauss-Legendre quadrature with [numPoints] points.
|
|
||||||
* The [range] is automatically transformed into [T] using [Field.number]
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T : Any> legendre(
|
|
||||||
algebra: Field<T>,
|
|
||||||
range: ClosedRange<Double>,
|
|
||||||
numPoints: Int = 100,
|
|
||||||
bufferFactory: BufferFactory<T> = ::ListBuffer,
|
|
||||||
vararg features: IntegrandFeature,
|
|
||||||
function: (T) -> T,
|
|
||||||
): UnivariateIntegrand<T> = GaussIntegrator(algebra, bufferFactory).integrate(
|
|
||||||
UnivariateIntegrand(
|
|
||||||
function,
|
|
||||||
IntegrationRange(range),
|
|
||||||
DoubleGaussLegendreRuleFactory,
|
|
||||||
IntegrandMaxCalls(numPoints),
|
|
||||||
*features
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Integrate given [function] in a [range] with Gauss-Legendre quadrature with [numPoints] points.
|
|
||||||
*/
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun legendre(
|
|
||||||
range: ClosedRange<Double>,
|
|
||||||
numPoints: Int = 100,
|
|
||||||
vararg features: IntegrandFeature,
|
|
||||||
function: (Double) -> Double,
|
|
||||||
): UnivariateIntegrand<Double> = GaussIntegrator(DoubleField, ::DoubleBuffer).integrate(
|
|
||||||
UnivariateIntegrand(
|
|
||||||
function,
|
|
||||||
IntegrationRange(range),
|
|
||||||
DoubleGaussLegendreRuleFactory,
|
|
||||||
IntegrandMaxCalls(numPoints),
|
|
||||||
*features
|
|
||||||
)
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Integrate [T]-valued univariate function using provided set of [IntegrandFeature]
|
||||||
|
* Following features are evaluated:
|
||||||
|
* * [GaussIntegratorRuleFactory] - A factory for computing the Gauss integration rule. By default uses [GaussLegendreRuleFactory]
|
||||||
|
* * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval.
|
||||||
|
* * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 100 points.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T : Any> Field<T>.integrate(
|
||||||
|
vararg features: IntegrandFeature,
|
||||||
|
function: (Double) -> T,
|
||||||
|
): UnivariateIntegrand<T> = GaussIntegrator(this).integrate(UnivariateIntegrand(function, *features))
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use [GaussIntegrator.Companion.integrate] to integrate the function in the current algebra with given [range] and [numPoints]
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T : Any> Field<T>.integrate(
|
||||||
|
range: ClosedRange<Double>,
|
||||||
|
numPoints: Int = 100,
|
||||||
|
vararg features: IntegrandFeature,
|
||||||
|
function: (Double) -> T,
|
||||||
|
): UnivariateIntegrand<T> = GaussIntegrator(this).integrate(
|
||||||
|
UnivariateIntegrand(
|
||||||
|
function,
|
||||||
|
IntegrationRange(range),
|
||||||
|
GaussLegendreRuleFactory,
|
||||||
|
IntegrandMaxCalls(numPoints),
|
||||||
|
*features
|
||||||
|
)
|
||||||
|
)
|
@ -5,22 +5,20 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.integration
|
package space.kscience.kmath.integration
|
||||||
|
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.operations.Field
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.asBuffer
|
||||||
|
import space.kscience.kmath.structures.map
|
||||||
import kotlin.jvm.Synchronized
|
import kotlin.jvm.Synchronized
|
||||||
import kotlin.math.ulp
|
import kotlin.math.ulp
|
||||||
import kotlin.native.concurrent.ThreadLocal
|
import kotlin.native.concurrent.ThreadLocal
|
||||||
|
|
||||||
public interface GaussIntegratorRuleFactory<T : Any> : IntegrandFeature {
|
public interface GaussIntegratorRuleFactory : IntegrandFeature {
|
||||||
public val algebra: Field<T>
|
public fun build(numPoints: Int): Pair<Buffer<Double>, Buffer<Double>>
|
||||||
public val bufferFactory: BufferFactory<T>
|
|
||||||
|
|
||||||
public fun build(numPoints: Int): Pair<Buffer<T>, Buffer<T>>
|
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public fun double(numPoints: Int, range: ClosedRange<Double>): Pair<Buffer<Double>, Buffer<Double>> =
|
public fun double(numPoints: Int, range: ClosedRange<Double>): Pair<Buffer<Double>, Buffer<Double>> =
|
||||||
DoubleGaussLegendreRuleFactory.build(numPoints, range)
|
GaussLegendreRuleFactory.build(numPoints, range)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,24 +26,23 @@ public interface GaussIntegratorRuleFactory<T : Any> : IntegrandFeature {
|
|||||||
* Create an integration rule by scaling existing normalized rule
|
* Create an integration rule by scaling existing normalized rule
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public fun <T : Any> GaussIntegratorRuleFactory<T>.build(
|
public fun GaussIntegratorRuleFactory.build(
|
||||||
numPoints: Int,
|
numPoints: Int,
|
||||||
range: ClosedRange<Double>,
|
range: ClosedRange<Double>,
|
||||||
): Pair<Buffer<T>, Buffer<T>> {
|
): Pair<Buffer<Double>, Buffer<Double>> {
|
||||||
val normalized = build(numPoints)
|
val normalized = build(numPoints)
|
||||||
with(algebra) {
|
val length = range.endInclusive - range.start
|
||||||
val length = range.endInclusive - range.start
|
|
||||||
|
|
||||||
val points = normalized.first.map(bufferFactory) {
|
val points = normalized.first.map(::DoubleBuffer) {
|
||||||
number(range.start + length / 2) + number(length / 2) * it
|
range.start + length / 2 + length / 2 * it
|
||||||
}
|
|
||||||
|
|
||||||
val weights = normalized.second.map(bufferFactory) {
|
|
||||||
it * length / 2
|
|
||||||
}
|
|
||||||
|
|
||||||
return points to weights
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val weights = normalized.second.map(::DoubleBuffer) {
|
||||||
|
it * length / 2
|
||||||
|
}
|
||||||
|
|
||||||
|
return points to weights
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -56,11 +53,7 @@ public fun <T : Any> GaussIntegratorRuleFactory<T>.build(
|
|||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@ThreadLocal
|
@ThreadLocal
|
||||||
public object DoubleGaussLegendreRuleFactory : GaussIntegratorRuleFactory<Double> {
|
public object GaussLegendreRuleFactory : GaussIntegratorRuleFactory {
|
||||||
|
|
||||||
override val algebra: Field<Double> get() = DoubleField
|
|
||||||
|
|
||||||
override val bufferFactory: BufferFactory<Double> get() = ::DoubleBuffer
|
|
||||||
|
|
||||||
private val cache = HashMap<Int, Pair<Buffer<Double>, Buffer<Double>>>()
|
private val cache = HashMap<Int, Pair<Buffer<Double>, Buffer<Double>>>()
|
||||||
|
|
||||||
@ -172,21 +165,3 @@ public object DoubleGaussLegendreRuleFactory : GaussIntegratorRuleFactory<Double
|
|||||||
|
|
||||||
override fun build(numPoints: Int): Pair<Buffer<Double>, Buffer<Double>> = getOrBuildRule(numPoints)
|
override fun build(numPoints: Int): Pair<Buffer<Double>, Buffer<Double>> = getOrBuildRule(numPoints)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A generic Gauss-Legendre rule factory that wraps [DoubleGaussLegendreRuleFactory] in a generic way.
|
|
||||||
*/
|
|
||||||
public class GenericGaussLegendreRuleFactory<T : Any>(
|
|
||||||
override val algebra: Field<T>,
|
|
||||||
override val bufferFactory: BufferFactory<T>,
|
|
||||||
) : GaussIntegratorRuleFactory<T> {
|
|
||||||
|
|
||||||
override fun build(numPoints: Int): Pair<Buffer<T>, Buffer<T>> {
|
|
||||||
val (doublePoints, doubleWeights) = DoubleGaussLegendreRuleFactory.build(numPoints)
|
|
||||||
|
|
||||||
val points = doublePoints.map(bufferFactory) { algebra.number(it) }
|
|
||||||
val weights = doubleWeights.map(bufferFactory) { algebra.number(it) }
|
|
||||||
return points to weights
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -5,14 +5,13 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.integration
|
package space.kscience.kmath.integration
|
||||||
|
|
||||||
import space.kscience.kmath.functions.UnivariateFunction
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import kotlin.jvm.JvmInline
|
import kotlin.jvm.JvmInline
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
public class UnivariateIntegrand<T : Any> internal constructor(
|
public class UnivariateIntegrand<T : Any> internal constructor(
|
||||||
private val features: Map<KClass<*>, IntegrandFeature>,
|
private val features: Map<KClass<*>, IntegrandFeature>,
|
||||||
public val function: UnivariateFunction<T>,
|
public val function: (Double) -> T,
|
||||||
) : Integrand {
|
) : Integrand {
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
@ -27,7 +26,7 @@ public class UnivariateIntegrand<T : Any> internal constructor(
|
|||||||
|
|
||||||
@Suppress("FunctionName")
|
@Suppress("FunctionName")
|
||||||
public fun <T : Any> UnivariateIntegrand(
|
public fun <T : Any> UnivariateIntegrand(
|
||||||
function: (T) -> T,
|
function: (Double) -> T,
|
||||||
vararg features: IntegrandFeature,
|
vararg features: IntegrandFeature,
|
||||||
): UnivariateIntegrand<T> = UnivariateIntegrand(features.associateBy { it::class }, function)
|
): UnivariateIntegrand<T> = UnivariateIntegrand(features.associateBy { it::class }, function)
|
||||||
|
|
||||||
|
@ -5,15 +5,18 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.integration
|
package space.kscience.kmath.integration
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import kotlin.math.PI
|
import kotlin.math.PI
|
||||||
import kotlin.math.sin
|
import kotlin.math.sin
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
class GaussIntegralTest {
|
class GaussIntegralTest {
|
||||||
@Test
|
@Test
|
||||||
fun gaussSin() {
|
fun gaussSin() {
|
||||||
val res = GaussIntegrator.legendre(0.0..2 * PI) { x ->
|
val res = DoubleField.integrate(0.0..2 * PI) { x ->
|
||||||
sin(x)
|
sin(x)
|
||||||
}
|
}
|
||||||
assertEquals(0.0, res.value!!, 1e-4)
|
assertEquals(0.0, res.value!!, 1e-4)
|
||||||
@ -21,7 +24,7 @@ class GaussIntegralTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun gaussUniform() {
|
fun gaussUniform() {
|
||||||
val res = GaussIntegrator.legendre(0.0..100.0,300) { x ->
|
val res = DoubleField.integrate(0.0..100.0,300) { x ->
|
||||||
if(x in 30.0..50.0){
|
if(x in 30.0..50.0){
|
||||||
1.0
|
1.0
|
||||||
} else {
|
} else {
|
||||||
|
Loading…
Reference in New Issue
Block a user