forked from kscience/kmath
WIP Integrator tests
This commit is contained in:
parent
e2ceb64d36
commit
93bc371622
@ -0,0 +1,7 @@
|
|||||||
|
package space.kscience.kmath.functions
|
||||||
|
|
||||||
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
|
||||||
|
public typealias UnivariateFunction<T> = (T) -> T
|
||||||
|
|
||||||
|
public typealias MultivariateFunction<T> = (Buffer<T>) -> T
|
@ -23,7 +23,7 @@ import space.kscience.kmath.structures.indices
|
|||||||
/**
|
/**
|
||||||
* A simple one-pass integrator based on Gauss rule
|
* A simple one-pass integrator based on Gauss rule
|
||||||
*/
|
*/
|
||||||
public class GaussIntegrator<T : Any> internal constructor(
|
public class GaussIntegrator<T : Comparable<T>> internal constructor(
|
||||||
public val algebra: Ring<T>,
|
public val algebra: Ring<T>,
|
||||||
private val points: Buffer<T>,
|
private val points: Buffer<T>,
|
||||||
private val weights: Buffer<T>,
|
private val weights: Buffer<T>,
|
||||||
@ -31,6 +31,7 @@ public class GaussIntegrator<T : Any> internal constructor(
|
|||||||
|
|
||||||
init {
|
init {
|
||||||
require(points.size == weights.size) { "Inconsistent points and weights sizes" }
|
require(points.size == weights.size) { "Inconsistent points and weights sizes" }
|
||||||
|
require(points.indices.all { i -> i == 0 || points[i] > points[i - 1] }){"Integration nodes must be sorted"}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun integrate(integrand: UnivariateIntegrand<T>): UnivariateIntegrand<T> = with(algebra) {
|
override fun integrate(integrand: UnivariateIntegrand<T>): UnivariateIntegrand<T> = with(algebra) {
|
||||||
@ -54,12 +55,13 @@ public class GaussIntegrator<T : Any> internal constructor(
|
|||||||
range: ClosedRange<Double>,
|
range: ClosedRange<Double>,
|
||||||
numPoints: Int = 100,
|
numPoints: Int = 100,
|
||||||
ruleFactory: GaussIntegratorRuleFactory<Double> = GaussLegendreDoubleRuleFactory,
|
ruleFactory: GaussIntegratorRuleFactory<Double> = GaussLegendreDoubleRuleFactory,
|
||||||
|
features: List<IntegrandFeature> = emptyList(),
|
||||||
function: (Double) -> Double,
|
function: (Double) -> Double,
|
||||||
): Double {
|
): UnivariateIntegrand<Double> {
|
||||||
val (points, weights) = ruleFactory.build(numPoints, range)
|
val (points, weights) = ruleFactory.build(numPoints, range)
|
||||||
return GaussIntegrator(DoubleField, points, weights).integrate(
|
return GaussIntegrator(DoubleField, points, weights).integrate(
|
||||||
UnivariateIntegrand(function, IntegrationRange(range))
|
UnivariateIntegrand(function, IntegrationRange(range))
|
||||||
).value!!
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,6 +1,7 @@
|
|||||||
package space.kscience.kmath.integration
|
package space.kscience.kmath.integration
|
||||||
|
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.Field
|
||||||
import space.kscience.kmath.operations.Ring
|
import space.kscience.kmath.operations.Ring
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.*
|
||||||
import kotlin.jvm.Synchronized
|
import kotlin.jvm.Synchronized
|
||||||
@ -8,7 +9,7 @@ import kotlin.math.ulp
|
|||||||
import kotlin.native.concurrent.ThreadLocal
|
import kotlin.native.concurrent.ThreadLocal
|
||||||
|
|
||||||
public interface GaussIntegratorRuleFactory<T : Any> {
|
public interface GaussIntegratorRuleFactory<T : Any> {
|
||||||
public val algebra: Ring<T>
|
public val algebra: Field<T>
|
||||||
public val bufferFactory: BufferFactory<T>
|
public val bufferFactory: BufferFactory<T>
|
||||||
public fun build(numPoints: Int): Pair<Buffer<T>, Buffer<T>>
|
public fun build(numPoints: Int): Pair<Buffer<T>, Buffer<T>>
|
||||||
|
|
||||||
@ -29,7 +30,7 @@ public fun <T : Comparable<T>> GaussIntegratorRuleFactory<T>.build(
|
|||||||
val points = with(algebra) {
|
val points = with(algebra) {
|
||||||
val length = range.endInclusive - range.start
|
val length = range.endInclusive - range.start
|
||||||
normalized.first.map(bufferFactory) {
|
normalized.first.map(bufferFactory) {
|
||||||
range.start + length * it
|
range.start + length / 2 + length * it/2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,7 +47,7 @@ public fun <T : Comparable<T>> GaussIntegratorRuleFactory<T>.build(
|
|||||||
@ThreadLocal
|
@ThreadLocal
|
||||||
public object GaussLegendreDoubleRuleFactory : GaussIntegratorRuleFactory<Double> {
|
public object GaussLegendreDoubleRuleFactory : GaussIntegratorRuleFactory<Double> {
|
||||||
|
|
||||||
override val algebra: Ring<Double> get() = DoubleField
|
override val algebra: Field<Double> get() = DoubleField
|
||||||
|
|
||||||
override val bufferFactory: BufferFactory<Double> get() = ::DoubleBuffer
|
override val bufferFactory: BufferFactory<Double> get() = ::DoubleBuffer
|
||||||
|
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
package space.kscience.kmath.integration
|
package space.kscience.kmath.integration
|
||||||
|
|
||||||
|
import space.kscience.kmath.functions.UnivariateFunction
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import kotlin.jvm.JvmInline
|
import kotlin.jvm.JvmInline
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
public class UnivariateIntegrand<T : Any> internal constructor(
|
public class UnivariateIntegrand<T : Any> internal constructor(
|
||||||
private val features: Map<KClass<*>, IntegrandFeature>,
|
private val features: Map<KClass<*>, IntegrandFeature>,
|
||||||
public val function: (T) -> T,
|
public val function: UnivariateFunction<T>,
|
||||||
) : Integrand {
|
) : Integrand {
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
package space.kscience.kmath.integration
|
||||||
|
|
||||||
|
import kotlin.math.PI
|
||||||
|
import kotlin.math.sin
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class GaussIntegralTest {
|
||||||
|
@Test
|
||||||
|
fun gaussSin() {
|
||||||
|
val res = GaussIntegrator.integrate(0.0..2 * PI) { x ->
|
||||||
|
sin(x)
|
||||||
|
}
|
||||||
|
assertEquals(0.0, res.value!!, 1e-4)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun gaussUniform() {
|
||||||
|
val res = GaussIntegrator.integrate(0.0..100.0,300) { x ->
|
||||||
|
if(x in 30.0..50.0){
|
||||||
|
1.0
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assertEquals(20.0, res.value!!, 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user