WIP Integrator tests

This commit is contained in:
Alexander Nozik 2021-04-15 09:53:29 +03:00
parent e2ceb64d36
commit 93bc371622
5 changed files with 48 additions and 7 deletions

View File

@ -0,0 +1,7 @@
package space.kscience.kmath.functions
import space.kscience.kmath.structures.Buffer
public typealias UnivariateFunction<T> = (T) -> T
public typealias MultivariateFunction<T> = (Buffer<T>) -> T

View File

@ -23,7 +23,7 @@ import space.kscience.kmath.structures.indices
/** /**
* A simple one-pass integrator based on Gauss rule * A simple one-pass integrator based on Gauss rule
*/ */
public class GaussIntegrator<T : Any> internal constructor( public class GaussIntegrator<T : Comparable<T>> internal constructor(
public val algebra: Ring<T>, public val algebra: Ring<T>,
private val points: Buffer<T>, private val points: Buffer<T>,
private val weights: Buffer<T>, private val weights: Buffer<T>,
@ -31,6 +31,7 @@ public class GaussIntegrator<T : Any> internal constructor(
init { init {
require(points.size == weights.size) { "Inconsistent points and weights sizes" } require(points.size == weights.size) { "Inconsistent points and weights sizes" }
require(points.indices.all { i -> i == 0 || points[i] > points[i - 1] }){"Integration nodes must be sorted"}
} }
override fun integrate(integrand: UnivariateIntegrand<T>): UnivariateIntegrand<T> = with(algebra) { override fun integrate(integrand: UnivariateIntegrand<T>): UnivariateIntegrand<T> = with(algebra) {
@ -54,12 +55,13 @@ public class GaussIntegrator<T : Any> internal constructor(
range: ClosedRange<Double>, range: ClosedRange<Double>,
numPoints: Int = 100, numPoints: Int = 100,
ruleFactory: GaussIntegratorRuleFactory<Double> = GaussLegendreDoubleRuleFactory, ruleFactory: GaussIntegratorRuleFactory<Double> = GaussLegendreDoubleRuleFactory,
features: List<IntegrandFeature> = emptyList(),
function: (Double) -> Double, function: (Double) -> Double,
): Double { ): UnivariateIntegrand<Double> {
val (points, weights) = ruleFactory.build(numPoints, range) val (points, weights) = ruleFactory.build(numPoints, range)
return GaussIntegrator(DoubleField, points, weights).integrate( return GaussIntegrator(DoubleField, points, weights).integrate(
UnivariateIntegrand(function, IntegrationRange(range)) UnivariateIntegrand(function, IntegrationRange(range))
).value!! )
} }
} }
} }

View File

@ -1,6 +1,7 @@
package space.kscience.kmath.integration package space.kscience.kmath.integration
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Field
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.Ring
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.*
import kotlin.jvm.Synchronized import kotlin.jvm.Synchronized
@ -8,7 +9,7 @@ import kotlin.math.ulp
import kotlin.native.concurrent.ThreadLocal import kotlin.native.concurrent.ThreadLocal
public interface GaussIntegratorRuleFactory<T : Any> { public interface GaussIntegratorRuleFactory<T : Any> {
public val algebra: Ring<T> public val algebra: Field<T>
public val bufferFactory: BufferFactory<T> public val bufferFactory: BufferFactory<T>
public fun build(numPoints: Int): Pair<Buffer<T>, Buffer<T>> public fun build(numPoints: Int): Pair<Buffer<T>, Buffer<T>>
@ -29,7 +30,7 @@ public fun <T : Comparable<T>> GaussIntegratorRuleFactory<T>.build(
val points = with(algebra) { val points = with(algebra) {
val length = range.endInclusive - range.start val length = range.endInclusive - range.start
normalized.first.map(bufferFactory) { normalized.first.map(bufferFactory) {
range.start + length * it range.start + length / 2 + length * it/2
} }
} }
@ -46,7 +47,7 @@ public fun <T : Comparable<T>> GaussIntegratorRuleFactory<T>.build(
@ThreadLocal @ThreadLocal
public object GaussLegendreDoubleRuleFactory : GaussIntegratorRuleFactory<Double> { public object GaussLegendreDoubleRuleFactory : GaussIntegratorRuleFactory<Double> {
override val algebra: Ring<Double> get() = DoubleField override val algebra: Field<Double> get() = DoubleField
override val bufferFactory: BufferFactory<Double> get() = ::DoubleBuffer override val bufferFactory: BufferFactory<Double> get() = ::DoubleBuffer

View File

@ -1,12 +1,13 @@
package space.kscience.kmath.integration package space.kscience.kmath.integration
import space.kscience.kmath.functions.UnivariateFunction
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import kotlin.jvm.JvmInline import kotlin.jvm.JvmInline
import kotlin.reflect.KClass import kotlin.reflect.KClass
public class UnivariateIntegrand<T : Any> internal constructor( public class UnivariateIntegrand<T : Any> internal constructor(
private val features: Map<KClass<*>, IntegrandFeature>, private val features: Map<KClass<*>, IntegrandFeature>,
public val function: (T) -> T, public val function: UnivariateFunction<T>,
) : Integrand { ) : Integrand {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")

View File

@ -0,0 +1,30 @@
package space.kscience.kmath.integration
import kotlin.math.PI
import kotlin.math.sin
import kotlin.test.Test
import kotlin.test.assertEquals
class GaussIntegralTest {
@Test
fun gaussSin() {
val res = GaussIntegrator.integrate(0.0..2 * PI) { x ->
sin(x)
}
assertEquals(0.0, res.value!!, 1e-4)
}
@Test
fun gaussUniform() {
val res = GaussIntegrator.integrate(0.0..100.0,300) { x ->
if(x in 30.0..50.0){
1.0
} else {
0.0
}
}
assertEquals(20.0, res.value!!, 0.1)
}
}