From ef1200aad0db776f07d5f6b4c0572d581acd4470 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 16 Apr 2021 16:39:27 +0300 Subject: [PATCH] Update integration API --- .../kscience/kmath/functions/integrate.kt | 5 +- .../kmath/functions/matrixIntegration.kt | 13 +- .../kmath/integration/GaussIntegrator.kt | 111 ++++++------------ .../integration/GaussIntegratorRuleFactory.kt | 67 ++++------- .../kmath/integration/UnivariateIntegrand.kt | 5 +- .../kmath/integration/GaussIntegralTest.kt | 7 +- 6 files changed, 79 insertions(+), 129 deletions(-) diff --git a/examples/src/main/kotlin/space/kscience/kmath/functions/integrate.kt b/examples/src/main/kotlin/space/kscience/kmath/functions/integrate.kt index 761d006d3..90542adf4 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/functions/integrate.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/functions/integrate.kt @@ -1,7 +1,8 @@ package space.kscience.kmath.functions -import space.kscience.kmath.integration.GaussIntegrator +import space.kscience.kmath.integration.integrate import space.kscience.kmath.integration.value +import space.kscience.kmath.operations.DoubleField import kotlin.math.pow fun main() { @@ -9,7 +10,7 @@ fun main() { val function: UnivariateFunction = { x -> 3 * x.pow(2) + 2 * x + 1 } //get the result of the integration - val result = GaussIntegrator.legendre(0.0..10.0, function = function) + val result = DoubleField.integrate(0.0..10.0, function = function) //the value is nullable because in some cases the integration could not succeed println(result.value) diff --git a/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt b/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt index 5e92ce22a..bd431c22c 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/functions/matrixIntegration.kt @@ -1,7 +1,6 @@ package space.kscience.kmath.functions -import space.kscience.kmath.integration.GaussIntegrator -import space.kscience.kmath.integration.UnivariateIntegrand +import space.kscience.kmath.integration.integrate import space.kscience.kmath.integration.value import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.nd @@ -10,11 +9,17 @@ import space.kscience.kmath.operations.invoke fun main(): Unit = DoubleField { nd(2, 2) { + + //Produce a diagonal StructureND + fun diagonal(v: Double) = produce { (i, j) -> + if (i == j) v else 0.0 + } + //Define a function in a nd space - val function: UnivariateFunction> = { x -> 3 * x.pow(2) + 2 * x + 1 } + val function: (Double) -> StructureND = { x: Double -> 3 * number(x).pow(2) + 2 * diagonal(x) + 1 } //get the result of the integration - val result: UnivariateIntegrand> = GaussIntegrator.legendre(this, 0.0..10.0, function = function) + val result = integrate(0.0..10.0, function = function) //the value is nullable because in some cases the integration could not succeed println(result.value) diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt index c4b9c572f..bc23e2f1b 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegrator.kt @@ -5,7 +5,6 @@ package space.kscience.kmath.integration import space.kscience.kmath.misc.UnstableKMathAPI -import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.Field import space.kscience.kmath.structures.* @@ -15,12 +14,10 @@ import space.kscience.kmath.structures.* */ public class GaussIntegrator( public val algebra: Field, - public val bufferFactory: BufferFactory, ) : UnivariateIntegrator { - private fun buildRule(integrand: UnivariateIntegrand): Pair, Buffer> { - val factory = integrand.getFeature>() - ?: GenericGaussLegendreRuleFactory(algebra, bufferFactory) + private fun buildRule(integrand: UnivariateIntegrand): Pair, Buffer> { + val factory = integrand.getFeature() ?: GaussLegendreRuleFactory val numPoints = integrand.getFeature()?.maxCalls ?: 100 val range = integrand.getFeature>()?.range ?: 0.0..1.0 return factory.build(numPoints, range) @@ -32,9 +29,9 @@ public class GaussIntegrator( var res = zero var c = zero for (i in points.indices) { - val x: T = points[i] - val w: T = weights[i] - val y: T = w * f(x) - c + val x = points[i] + val weight = weights[i] + val y: T = weight * f(x) - c val t = res + y c = t - res - y res = t @@ -44,68 +41,38 @@ public class GaussIntegrator( public companion object { - /** - * Integrate [T]-valued univariate function using provided set of [IntegrandFeature] - * Following features are evaluated: - * * [GaussIntegratorRuleFactory] - A factory for computing the Gauss integration rule. By default uses [GenericGaussLegendreRuleFactory] - * * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval. - * * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 100 points. - */ - public fun integrate( - algebra: Field, - bufferFactory: BufferFactory = ::ListBuffer, - vararg features: IntegrandFeature, - function: (T) -> T, - ): UnivariateIntegrand = - GaussIntegrator(algebra, bufferFactory).integrate(UnivariateIntegrand(function, *features)) - - /** - * Integrate in real numbers - */ - public fun integrate( - vararg features: IntegrandFeature, - function: (Double) -> Double, - ): UnivariateIntegrand = integrate(DoubleField, ::DoubleBuffer, features = features, function) - - /** - * Integrate given [function] in a [range] with Gauss-Legendre quadrature with [numPoints] points. - * The [range] is automatically transformed into [T] using [Field.number] - */ - @UnstableKMathAPI - public fun legendre( - algebra: Field, - range: ClosedRange, - numPoints: Int = 100, - bufferFactory: BufferFactory = ::ListBuffer, - vararg features: IntegrandFeature, - function: (T) -> T, - ): UnivariateIntegrand = GaussIntegrator(algebra, bufferFactory).integrate( - UnivariateIntegrand( - function, - IntegrationRange(range), - DoubleGaussLegendreRuleFactory, - IntegrandMaxCalls(numPoints), - *features - ) - ) - - /** - * Integrate given [function] in a [range] with Gauss-Legendre quadrature with [numPoints] points. - */ - @UnstableKMathAPI - public fun legendre( - range: ClosedRange, - numPoints: Int = 100, - vararg features: IntegrandFeature, - function: (Double) -> Double, - ): UnivariateIntegrand = GaussIntegrator(DoubleField, ::DoubleBuffer).integrate( - UnivariateIntegrand( - function, - IntegrationRange(range), - DoubleGaussLegendreRuleFactory, - IntegrandMaxCalls(numPoints), - *features - ) - ) } -} \ No newline at end of file +} + +/** + * Integrate [T]-valued univariate function using provided set of [IntegrandFeature] + * Following features are evaluated: + * * [GaussIntegratorRuleFactory] - A factory for computing the Gauss integration rule. By default uses [GaussLegendreRuleFactory] + * * [IntegrationRange] - the univariate range of integration. By default uses 0..1 interval. + * * [IntegrandMaxCalls] - the maximum number of function calls during integration. For non-iterative rules, always uses the maximum number of points. By default uses 100 points. + */ +@UnstableKMathAPI +public fun Field.integrate( + vararg features: IntegrandFeature, + function: (Double) -> T, +): UnivariateIntegrand = GaussIntegrator(this).integrate(UnivariateIntegrand(function, *features)) + + +/** + * Use [GaussIntegrator.Companion.integrate] to integrate the function in the current algebra with given [range] and [numPoints] + */ +@UnstableKMathAPI +public fun Field.integrate( + range: ClosedRange, + numPoints: Int = 100, + vararg features: IntegrandFeature, + function: (Double) -> T, +): UnivariateIntegrand = GaussIntegrator(this).integrate( + UnivariateIntegrand( + function, + IntegrationRange(range), + GaussLegendreRuleFactory, + IntegrandMaxCalls(numPoints), + *features + ) +) \ No newline at end of file diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegratorRuleFactory.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegratorRuleFactory.kt index 8e961cc62..133f829e3 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegratorRuleFactory.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/GaussIntegratorRuleFactory.kt @@ -5,22 +5,20 @@ package space.kscience.kmath.integration -import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.Field -import space.kscience.kmath.structures.* +import space.kscience.kmath.structures.Buffer +import space.kscience.kmath.structures.DoubleBuffer +import space.kscience.kmath.structures.asBuffer +import space.kscience.kmath.structures.map import kotlin.jvm.Synchronized import kotlin.math.ulp import kotlin.native.concurrent.ThreadLocal -public interface GaussIntegratorRuleFactory : IntegrandFeature { - public val algebra: Field - public val bufferFactory: BufferFactory - - public fun build(numPoints: Int): Pair, Buffer> +public interface GaussIntegratorRuleFactory : IntegrandFeature { + public fun build(numPoints: Int): Pair, Buffer> public companion object { public fun double(numPoints: Int, range: ClosedRange): Pair, Buffer> = - DoubleGaussLegendreRuleFactory.build(numPoints, range) + GaussLegendreRuleFactory.build(numPoints, range) } } @@ -28,24 +26,23 @@ public interface GaussIntegratorRuleFactory : IntegrandFeature { * Create an integration rule by scaling existing normalized rule * */ -public fun GaussIntegratorRuleFactory.build( +public fun GaussIntegratorRuleFactory.build( numPoints: Int, range: ClosedRange, -): Pair, Buffer> { +): Pair, Buffer> { val normalized = build(numPoints) - with(algebra) { - val length = range.endInclusive - range.start + val length = range.endInclusive - range.start - val points = normalized.first.map(bufferFactory) { - number(range.start + length / 2) + number(length / 2) * it - } - - val weights = normalized.second.map(bufferFactory) { - it * length / 2 - } - - return points to weights + val points = normalized.first.map(::DoubleBuffer) { + range.start + length / 2 + length / 2 * it } + + val weights = normalized.second.map(::DoubleBuffer) { + it * length / 2 + } + + return points to weights + } @@ -56,11 +53,7 @@ public fun GaussIntegratorRuleFactory.build( * */ @ThreadLocal -public object DoubleGaussLegendreRuleFactory : GaussIntegratorRuleFactory { - - override val algebra: Field get() = DoubleField - - override val bufferFactory: BufferFactory get() = ::DoubleBuffer +public object GaussLegendreRuleFactory : GaussIntegratorRuleFactory { private val cache = HashMap, Buffer>>() @@ -171,22 +164,4 @@ public object DoubleGaussLegendreRuleFactory : GaussIntegratorRuleFactory, Buffer> = getOrBuildRule(numPoints) -} - - -/** - * A generic Gauss-Legendre rule factory that wraps [DoubleGaussLegendreRuleFactory] in a generic way. - */ -public class GenericGaussLegendreRuleFactory( - override val algebra: Field, - override val bufferFactory: BufferFactory, -) : GaussIntegratorRuleFactory { - - override fun build(numPoints: Int): Pair, Buffer> { - val (doublePoints, doubleWeights) = DoubleGaussLegendreRuleFactory.build(numPoints) - - val points = doublePoints.map(bufferFactory) { algebra.number(it) } - val weights = doubleWeights.map(bufferFactory) { algebra.number(it) } - return points to weights - } -} +} \ No newline at end of file diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/UnivariateIntegrand.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/UnivariateIntegrand.kt index e49e83845..0b41a3f8b 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/UnivariateIntegrand.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/integration/UnivariateIntegrand.kt @@ -5,14 +5,13 @@ package space.kscience.kmath.integration -import space.kscience.kmath.functions.UnivariateFunction import space.kscience.kmath.misc.UnstableKMathAPI import kotlin.jvm.JvmInline import kotlin.reflect.KClass public class UnivariateIntegrand internal constructor( private val features: Map, IntegrandFeature>, - public val function: UnivariateFunction, + public val function: (Double) -> T, ) : Integrand { @Suppress("UNCHECKED_CAST") @@ -27,7 +26,7 @@ public class UnivariateIntegrand internal constructor( @Suppress("FunctionName") public fun UnivariateIntegrand( - function: (T) -> T, + function: (Double) -> T, vararg features: IntegrandFeature, ): UnivariateIntegrand = UnivariateIntegrand(features.associateBy { it::class }, function) diff --git a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/GaussIntegralTest.kt b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/GaussIntegralTest.kt index 247318367..5ec90f42a 100644 --- a/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/GaussIntegralTest.kt +++ b/kmath-functions/src/commonTest/kotlin/space/kscience/kmath/integration/GaussIntegralTest.kt @@ -5,15 +5,18 @@ package space.kscience.kmath.integration +import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.operations.DoubleField import kotlin.math.PI import kotlin.math.sin import kotlin.test.Test import kotlin.test.assertEquals +@OptIn(UnstableKMathAPI::class) class GaussIntegralTest { @Test fun gaussSin() { - val res = GaussIntegrator.legendre(0.0..2 * PI) { x -> + val res = DoubleField.integrate(0.0..2 * PI) { x -> sin(x) } assertEquals(0.0, res.value!!, 1e-4) @@ -21,7 +24,7 @@ class GaussIntegralTest { @Test fun gaussUniform() { - val res = GaussIntegrator.legendre(0.0..100.0,300) { x -> + val res = DoubleField.integrate(0.0..100.0,300) { x -> if(x in 30.0..50.0){ 1.0 } else {