forked from kscience/kmath
Interpolation API
This commit is contained in:
parent
d56b4148be
commit
73f40105c4
@ -8,5 +8,6 @@ dependencies {
|
|||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
api(project(":kmath-prob"))
|
api(project(":kmath-prob"))
|
||||||
|
api(project(":kmath-functions"))
|
||||||
api("org.apache.commons:commons-math3:3.6.1")
|
api("org.apache.commons:commons-math3:3.6.1")
|
||||||
}
|
}
|
@ -73,6 +73,8 @@ fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator)
|
|||||||
|
|
||||||
fun <T> Buffer<T>.asIterable(): Iterable<T> = asSequence().asIterable()
|
fun <T> Buffer<T>.asIterable(): Iterable<T> = asSequence().asIterable()
|
||||||
|
|
||||||
|
val Buffer<*>.indices: IntRange get() = IntRange(0, size - 1)
|
||||||
|
|
||||||
interface MutableBuffer<T> : Buffer<T> {
|
interface MutableBuffer<T> : Buffer<T> {
|
||||||
operator fun set(index: Int, value: T)
|
operator fun set(index: Int, value: T)
|
||||||
|
|
||||||
|
@ -6,14 +6,16 @@ interface Piecewise<T, R> {
|
|||||||
fun findPiece(arg: T): R?
|
fun findPiece(arg: T): R?
|
||||||
}
|
}
|
||||||
|
|
||||||
interface PiecewisePolynomial<T : Any> : Piecewise<T, Polynomial<T>>
|
interface PiecewisePolynomial<T : Any> :
|
||||||
|
Piecewise<T, Polynomial<T>>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Ordered list of pieces in piecewise function
|
* Ordered list of pieces in piecewise function
|
||||||
*/
|
*/
|
||||||
class OrderedPiecewisePolynomial<T : Comparable<T>>(left: T) : PiecewisePolynomial<T> {
|
class OrderedPiecewisePolynomial<T : Comparable<T>>(delimeter: T) :
|
||||||
|
PiecewisePolynomial<T> {
|
||||||
|
|
||||||
private val delimiters: ArrayList<T> = arrayListOf(left)
|
private val delimiters: ArrayList<T> = arrayListOf(delimeter)
|
||||||
private val pieces: ArrayList<Polynomial<T>> = ArrayList()
|
private val pieces: ArrayList<Polynomial<T>> = ArrayList()
|
||||||
|
|
||||||
/**
|
/**
|
@ -32,7 +32,8 @@ fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring.run {
|
|||||||
/**
|
/**
|
||||||
* Represent a polynomial as a context-dependent function
|
* Represent a polynomial as a context-dependent function
|
||||||
*/
|
*/
|
||||||
fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, out C, T> = object : MathFunction<T, C, T> {
|
fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, out C, T> = object :
|
||||||
|
MathFunction<T, C, T> {
|
||||||
override fun C.invoke(arg: T): T = value(this, arg)
|
override fun C.invoke(arg: T): T = value(this, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -61,7 +62,8 @@ class PolynomialSpace<T : Any, C : Ring<T>>(val ring: C) : Space<Polynomial<T>>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override val zero: Polynomial<T> = Polynomial(emptyList())
|
override val zero: Polynomial<T> =
|
||||||
|
Polynomial(emptyList())
|
||||||
|
|
||||||
operator fun Polynomial<T>.invoke(arg: T): T = value(ring, arg)
|
operator fun Polynomial<T>.invoke(arg: T): T = value(ring, arg)
|
||||||
}
|
}
|
@ -5,7 +5,7 @@ import scientifik.kmath.functions.value
|
|||||||
import scientifik.kmath.operations.Ring
|
import scientifik.kmath.operations.Ring
|
||||||
|
|
||||||
interface Interpolator<X, Y> {
|
interface Interpolator<X, Y> {
|
||||||
fun interpolate(points: Collection<Pair<X, Y>>): (X) -> Y
|
fun interpolate(points: XYPointSet<X, Y>): (X) -> Y
|
||||||
}
|
}
|
||||||
|
|
||||||
interface PolynomialInterpolator<T : Comparable<T>> : Interpolator<T, T> {
|
interface PolynomialInterpolator<T : Comparable<T>> : Interpolator<T, T> {
|
||||||
@ -13,9 +13,9 @@ interface PolynomialInterpolator<T : Comparable<T>> : Interpolator<T, T> {
|
|||||||
|
|
||||||
fun getDefaultValue(): T = error("Out of bounds")
|
fun getDefaultValue(): T = error("Out of bounds")
|
||||||
|
|
||||||
fun interpolatePolynomials(points: Collection<Pair<T, T>>): PiecewisePolynomial<T>
|
fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T>
|
||||||
|
|
||||||
override fun interpolate(points: Collection<Pair<T, T>>): (T) -> T = { x ->
|
override fun interpolate(points: XYPointSet<T, T>): (T) -> T = { x ->
|
||||||
interpolatePolynomials(points).value(algebra, x) ?: getDefaultValue()
|
interpolatePolynomials(points).value(algebra, x) ?: getDefaultValue()
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -10,18 +10,16 @@ import scientifik.kmath.operations.Field
|
|||||||
*/
|
*/
|
||||||
class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> {
|
class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> {
|
||||||
|
|
||||||
override fun interpolatePolynomials(points: Collection<Pair<T, T>>): PiecewisePolynomial<T> = algebra.run {
|
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra.run {
|
||||||
require(points.isNotEmpty()) { "Point array should not be empty" }
|
require(points.size > 0) { "Point array should not be empty" }
|
||||||
|
insureSorted(points)
|
||||||
|
|
||||||
//sorting points
|
OrderedPiecewisePolynomial(points.x[0]).apply {
|
||||||
val sorted = points.sortedBy { it.first }
|
|
||||||
|
|
||||||
return@run OrderedPiecewisePolynomial(points.first().first).apply {
|
|
||||||
for (i in 0 until points.size - 1) {
|
for (i in 0 until points.size - 1) {
|
||||||
val slope = (sorted[i + 1].second - sorted[i].second) / (sorted[i + 1].first - sorted[i].first)
|
val slope = (points.y[i + 1] - points.y[i]) / (points.x[i + 1] - points.x[i])
|
||||||
val const = sorted[i].second - slope * sorted[i].first
|
val const = points.x[i] - slope * points.x[i]
|
||||||
val polynomial = Polynomial(const, slope)
|
val polynomial = Polynomial(const, slope)
|
||||||
putRight(sorted[i + 1].first, polynomial)
|
putRight(points.x[i + 1], polynomial)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,296 @@
|
|||||||
|
//package scientifik.kmath.interpolation
|
||||||
|
//
|
||||||
|
//import scientifik.kmath.functions.PiecewisePolynomial
|
||||||
|
//import scientifik.kmath.operations.Ring
|
||||||
|
//import scientifik.kmath.structures.Buffer
|
||||||
|
//import kotlin.math.abs
|
||||||
|
//import kotlin.math.sqrt
|
||||||
|
//
|
||||||
|
//
|
||||||
|
///**
|
||||||
|
// * Original code: https://github.com/apache/commons-math/blob/eb57d6d457002a0bb5336d789a3381a24599affe/src/main/java/org/apache/commons/math4/analysis/interpolation/LoessInterpolator.java
|
||||||
|
// */
|
||||||
|
//class LoessInterpolator<T : Comparable<T>>(override val algebra: Ring<T>) : PolynomialInterpolator<T> {
|
||||||
|
// /**
|
||||||
|
// * The bandwidth parameter: when computing the loess fit at
|
||||||
|
// * a particular point, this fraction of source points closest
|
||||||
|
// * to the current point is taken into account for computing
|
||||||
|
// * a least-squares regression.
|
||||||
|
// *
|
||||||
|
// *
|
||||||
|
// * A sensible value is usually 0.25 to 0.5.
|
||||||
|
// */
|
||||||
|
// private var bandwidth = 0.0
|
||||||
|
//
|
||||||
|
// /**
|
||||||
|
// * The number of robustness iterations parameter: this many
|
||||||
|
// * robustness iterations are done.
|
||||||
|
// *
|
||||||
|
// *
|
||||||
|
// * A sensible value is usually 0 (just the initial fit without any
|
||||||
|
// * robustness iterations) to 4.
|
||||||
|
// */
|
||||||
|
// private var robustnessIters = 0
|
||||||
|
//
|
||||||
|
// /**
|
||||||
|
// * If the median residual at a certain robustness iteration
|
||||||
|
// * is less than this amount, no more iterations are done.
|
||||||
|
// */
|
||||||
|
// private var accuracy = 0.0
|
||||||
|
//
|
||||||
|
// /**
|
||||||
|
// * Constructs a new [LoessInterpolator]
|
||||||
|
// * with a bandwidth of [.DEFAULT_BANDWIDTH],
|
||||||
|
// * [.DEFAULT_ROBUSTNESS_ITERS] robustness iterations
|
||||||
|
// * and an accuracy of {#link #DEFAULT_ACCURACY}.
|
||||||
|
// * See [.LoessInterpolator] for an explanation of
|
||||||
|
// * the parameters.
|
||||||
|
// */
|
||||||
|
// fun LoessInterpolator() {
|
||||||
|
// bandwidth = DEFAULT_BANDWIDTH
|
||||||
|
// robustnessIters = DEFAULT_ROBUSTNESS_ITERS
|
||||||
|
// accuracy = DEFAULT_ACCURACY
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// fun LoessInterpolator(bandwidth: Double, robustnessIters: Int) {
|
||||||
|
// this(bandwidth, robustnessIters, DEFAULT_ACCURACY)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// fun LoessInterpolator(bandwidth: Double, robustnessIters: Int, accuracy: Double) {
|
||||||
|
// if (bandwidth < 0 ||
|
||||||
|
// bandwidth > 1
|
||||||
|
// ) {
|
||||||
|
// throw OutOfRangeException(LocalizedFormats.BANDWIDTH, bandwidth, 0, 1)
|
||||||
|
// }
|
||||||
|
// this.bandwidth = bandwidth
|
||||||
|
// if (robustnessIters < 0) {
|
||||||
|
// throw NotPositiveException(LocalizedFormats.ROBUSTNESS_ITERATIONS, robustnessIters)
|
||||||
|
// }
|
||||||
|
// this.robustnessIters = robustnessIters
|
||||||
|
// this.accuracy = accuracy
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// fun interpolate(
|
||||||
|
// xval: DoubleArray,
|
||||||
|
// yval: DoubleArray
|
||||||
|
// ): PolynomialSplineFunction {
|
||||||
|
// return SplineInterpolator().interpolate(xval, smooth(xval, yval))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// fun XYZPointSet<Double, Double, Double>.smooth(): XYPointSet<Double, Double> {
|
||||||
|
// checkAllFiniteReal(x)
|
||||||
|
// checkAllFiniteReal(y)
|
||||||
|
// checkAllFiniteReal(z)
|
||||||
|
// MathArrays.checkOrder(xval)
|
||||||
|
// if (size == 1) {
|
||||||
|
// return doubleArrayOf(y[0])
|
||||||
|
// }
|
||||||
|
// if (size == 2) {
|
||||||
|
// return doubleArrayOf(y[0], y[1])
|
||||||
|
// }
|
||||||
|
// val bandwidthInPoints = (bandwidth * size).toInt()
|
||||||
|
// if (bandwidthInPoints < 2) {
|
||||||
|
// throw NumberIsTooSmallException(
|
||||||
|
// LocalizedFormats.BANDWIDTH,
|
||||||
|
// bandwidthInPoints, 2, true
|
||||||
|
// )
|
||||||
|
// }
|
||||||
|
// val res = DoubleArray(size)
|
||||||
|
// val residuals = DoubleArray(size)
|
||||||
|
// val sortedResiduals = DoubleArray(size)
|
||||||
|
// val robustnessWeights = DoubleArray(size)
|
||||||
|
// // Do an initial fit and 'robustnessIters' robustness iterations.
|
||||||
|
// // This is equivalent to doing 'robustnessIters+1' robustness iterations
|
||||||
|
// // starting with all robustness weights set to 1.
|
||||||
|
// Arrays.fill(robustnessWeights, 1.0)
|
||||||
|
// for (iter in 0..robustnessIters) {
|
||||||
|
// val bandwidthInterval = intArrayOf(0, bandwidthInPoints - 1)
|
||||||
|
// // At each x, compute a local weighted linear regression
|
||||||
|
// for (i in 0 until size) {
|
||||||
|
//// val x = x[i]
|
||||||
|
// // Find out the interval of source points on which
|
||||||
|
// // a regression is to be made.
|
||||||
|
// if (i > 0) {
|
||||||
|
// updateBandwidthInterval(x, z, i, bandwidthInterval)
|
||||||
|
// }
|
||||||
|
// val ileft = bandwidthInterval[0]
|
||||||
|
// val iright = bandwidthInterval[1]
|
||||||
|
// // Compute the point of the bandwidth interval that is
|
||||||
|
// // farthest from x
|
||||||
|
// val edge: Int
|
||||||
|
// edge = if (x[i] - x[ileft] > x[iright] - x[i]) {
|
||||||
|
// ileft
|
||||||
|
// } else {
|
||||||
|
// iright
|
||||||
|
// }
|
||||||
|
// // Compute a least-squares linear fit weighted by
|
||||||
|
// // the product of robustness weights and the tricube
|
||||||
|
// // weight function.
|
||||||
|
// // See http://en.wikipedia.org/wiki/Linear_regression
|
||||||
|
// // (section "Univariate linear case")
|
||||||
|
// // and http://en.wikipedia.org/wiki/Weighted_least_squares
|
||||||
|
// // (section "Weighted least squares")
|
||||||
|
// var sumWeights = 0.0
|
||||||
|
// var sumX = 0.0
|
||||||
|
// var sumXSquared = 0.0
|
||||||
|
// var sumY = 0.0
|
||||||
|
// var sumXY = 0.0
|
||||||
|
// val denom: Double = abs(1.0 / (x[edge] - x[i]))
|
||||||
|
// for (k in ileft..iright) {
|
||||||
|
// val xk = x[k]
|
||||||
|
// val yk = y[k]
|
||||||
|
// val dist = if (k < i) x - xk else xk - x[i]
|
||||||
|
// val w = tricube(dist * denom) * robustnessWeights[k] * z[k]
|
||||||
|
// val xkw = xk * w
|
||||||
|
// sumWeights += w
|
||||||
|
// sumX += xkw
|
||||||
|
// sumXSquared += xk * xkw
|
||||||
|
// sumY += yk * w
|
||||||
|
// sumXY += yk * xkw
|
||||||
|
// }
|
||||||
|
// val meanX = sumX / sumWeights
|
||||||
|
// val meanY = sumY / sumWeights
|
||||||
|
// val meanXY = sumXY / sumWeights
|
||||||
|
// val meanXSquared = sumXSquared / sumWeights
|
||||||
|
// val beta: Double
|
||||||
|
// beta = if (sqrt(abs(meanXSquared - meanX * meanX)) < accuracy) {
|
||||||
|
// 0.0
|
||||||
|
// } else {
|
||||||
|
// (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX)
|
||||||
|
// }
|
||||||
|
// val alpha = meanY - beta * meanX
|
||||||
|
// res[i] = beta * x[i] + alpha
|
||||||
|
// residuals[i] = abs(y[i] - res[i])
|
||||||
|
// }
|
||||||
|
// // No need to recompute the robustness weights at the last
|
||||||
|
// // iteration, they won't be needed anymore
|
||||||
|
// if (iter == robustnessIters) {
|
||||||
|
// break
|
||||||
|
// }
|
||||||
|
// // Recompute the robustness weights.
|
||||||
|
// // Find the median residual.
|
||||||
|
// // An arraycopy and a sort are completely tractable here,
|
||||||
|
// // because the preceding loop is a lot more expensive
|
||||||
|
// java.lang.System.arraycopy(residuals, 0, sortedResiduals, 0, size)
|
||||||
|
// Arrays.sort(sortedResiduals)
|
||||||
|
// val medianResidual = sortedResiduals[size / 2]
|
||||||
|
// if (abs(medianResidual) < accuracy) {
|
||||||
|
// break
|
||||||
|
// }
|
||||||
|
// for (i in 0 until size) {
|
||||||
|
// val arg = residuals[i] / (6 * medianResidual)
|
||||||
|
// if (arg >= 1) {
|
||||||
|
// robustnessWeights[i] = 0.0
|
||||||
|
// } else {
|
||||||
|
// val w = 1 - arg * arg
|
||||||
|
// robustnessWeights[i] = w * w
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// return res
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// fun smooth(xval: DoubleArray, yval: DoubleArray): DoubleArray {
|
||||||
|
// if (xval.size != yval.size) {
|
||||||
|
// throw DimensionMismatchException(xval.size, yval.size)
|
||||||
|
// }
|
||||||
|
// val unitWeights = DoubleArray(xval.size)
|
||||||
|
// Arrays.fill(unitWeights, 1.0)
|
||||||
|
// return smooth(xval, yval, unitWeights)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// /**
|
||||||
|
// * Given an index interval into xval that embraces a certain number of
|
||||||
|
// * points closest to `xval[i-1]`, update the interval so that it
|
||||||
|
// * embraces the same number of points closest to `xval[i]`,
|
||||||
|
// * ignoring zero weights.
|
||||||
|
// *
|
||||||
|
// * @param xval Arguments array.
|
||||||
|
// * @param weights Weights array.
|
||||||
|
// * @param i Index around which the new interval should be computed.
|
||||||
|
// * @param bandwidthInterval a two-element array {left, right} such that:
|
||||||
|
// * `(left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])`
|
||||||
|
// * and
|
||||||
|
// * `(right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])`.
|
||||||
|
// * The array will be updated.
|
||||||
|
// */
|
||||||
|
// private fun updateBandwidthInterval(
|
||||||
|
// xval: Buffer<Double>, weights: Buffer<Double>,
|
||||||
|
// i: Int,
|
||||||
|
// bandwidthInterval: IntArray
|
||||||
|
// ) {
|
||||||
|
// val left = bandwidthInterval[0]
|
||||||
|
// val right = bandwidthInterval[1]
|
||||||
|
// // The right edge should be adjusted if the next point to the right
|
||||||
|
// // is closer to xval[i] than the leftmost point of the current interval
|
||||||
|
// val nextRight = nextNonzero(weights, right)
|
||||||
|
// if (nextRight < xval.size && xval[nextRight] - xval[i] < xval[i] - xval[left]) {
|
||||||
|
// val nextLeft = nextNonzero(weights, bandwidthInterval[0])
|
||||||
|
// bandwidthInterval[0] = nextLeft
|
||||||
|
// bandwidthInterval[1] = nextRight
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// /**
|
||||||
|
// * Return the smallest index `j` such that
|
||||||
|
// * `j > i && (j == weights.length || weights[j] != 0)`.
|
||||||
|
// *
|
||||||
|
// * @param weights Weights array.
|
||||||
|
// * @param i Index from which to start search.
|
||||||
|
// * @return the smallest compliant index.
|
||||||
|
// */
|
||||||
|
// private fun nextNonzero(weights: Buffer<Double>, i: Int): Int {
|
||||||
|
// var j = i + 1
|
||||||
|
// while (j < weights.size && weights[j] == 0.0) {
|
||||||
|
// ++j
|
||||||
|
// }
|
||||||
|
// return j
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// /**
|
||||||
|
// * Compute the
|
||||||
|
// * [tricube](http://en.wikipedia.org/wiki/Local_regression#Weight_function)
|
||||||
|
// * weight function
|
||||||
|
// *
|
||||||
|
// * @param x Argument.
|
||||||
|
// * @return `(1 - |x|<sup>3</sup>)<sup>3</sup>` for |x| < 1, 0 otherwise.
|
||||||
|
// */
|
||||||
|
// private fun tricube(x: Double): Double {
|
||||||
|
// val absX: Double = FastMath.abs(x)
|
||||||
|
// if (absX >= 1.0) {
|
||||||
|
// return 0.0
|
||||||
|
// }
|
||||||
|
// val tmp = 1 - absX * absX * absX
|
||||||
|
// return tmp * tmp * tmp
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// /**
|
||||||
|
// * Check that all elements of an array are finite real numbers.
|
||||||
|
// *
|
||||||
|
// * @param values Values array.
|
||||||
|
// * @throws org.apache.commons.math4.exception.NotFiniteNumberException
|
||||||
|
// * if one of the values is not a finite real number.
|
||||||
|
// */
|
||||||
|
// private fun checkAllFiniteReal(values: DoubleArray) {
|
||||||
|
// for (i in values.indices) {
|
||||||
|
// MathUtils.checkFinite(values[i])
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// override fun interpolatePolynomials(points: Collection<Pair<T, T>>): PiecewisePolynomial<T> {
|
||||||
|
// TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// companion object {
|
||||||
|
// /** Default value of the bandwidth parameter. */
|
||||||
|
// const val DEFAULT_BANDWIDTH = 0.3
|
||||||
|
//
|
||||||
|
// /** Default value of the number of robustness iterations. */
|
||||||
|
// const val DEFAULT_ROBUSTNESS_ITERS = 2
|
||||||
|
//
|
||||||
|
// /**
|
||||||
|
// * Default value for accuracy.
|
||||||
|
// */
|
||||||
|
// const val DEFAULT_ACCURACY = 1e-12
|
||||||
|
// }
|
||||||
|
//}
|
@ -0,0 +1,58 @@
|
|||||||
|
package scientifik.kmath.interpolation
|
||||||
|
|
||||||
|
import scientifik.kmath.functions.OrderedPiecewisePolynomial
|
||||||
|
import scientifik.kmath.functions.PiecewisePolynomial
|
||||||
|
import scientifik.kmath.functions.Polynomial
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.structures.MutableBufferFactory
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generic spline interpolator. Not recommended for performance critical places, use platform-specific and type specific ones.
|
||||||
|
* Based on https://github.com/apache/commons-math/blob/eb57d6d457002a0bb5336d789a3381a24599affe/src/main/java/org/apache/commons/math4/analysis/interpolation/SplineInterpolator.java
|
||||||
|
*/
|
||||||
|
class SplineInterpolator<T : Comparable<T>>(
|
||||||
|
override val algebra: Field<T>,
|
||||||
|
val bufferFactory: MutableBufferFactory<T>
|
||||||
|
) : PolynomialInterpolator<T> {
|
||||||
|
|
||||||
|
//TODO possibly optimize zeroed buffers
|
||||||
|
|
||||||
|
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra.run {
|
||||||
|
if (points.size < 3) {
|
||||||
|
error("Can't use spline interpolator with less than 3 points")
|
||||||
|
}
|
||||||
|
insureSorted(points)
|
||||||
|
|
||||||
|
// Number of intervals. The number of data points is n + 1.
|
||||||
|
val n = points.size - 1
|
||||||
|
// Differences between knot points
|
||||||
|
val h = bufferFactory(points.size) { i -> points.x[i + 1] - points.x[i] }
|
||||||
|
val mu = bufferFactory(points.size - 1) { zero }
|
||||||
|
val z = bufferFactory(points.size) { zero }
|
||||||
|
|
||||||
|
for (i in 1 until n) {
|
||||||
|
val g = 2.0 * (points.x[i + 1] - points.x[i - 1]) - h[i - 1] * mu[i - 1]
|
||||||
|
mu[i] = h[i] / g
|
||||||
|
z[i] =
|
||||||
|
(3.0 * (points.y[i + 1] * h[i - 1] - points.x[i] * (points.x[i + 1] - points.x[i - 1]) + points.y[i - 1] * h[i]) / (h[i - 1] * h[i])
|
||||||
|
- h[i - 1] * z[i - 1]) / g
|
||||||
|
}
|
||||||
|
|
||||||
|
// cubic spline coefficients -- b is linear, c quadratic, d is cubic (original y's are constants)
|
||||||
|
|
||||||
|
OrderedPiecewisePolynomial<T>(points.x[points.size - 1]).apply {
|
||||||
|
var cOld = zero
|
||||||
|
for (j in n - 1 downTo 0) {
|
||||||
|
val c = z[j] - mu[j] * cOld
|
||||||
|
val a = points.y[j]
|
||||||
|
val b = (points.y[j + 1] - points.y[j]) / h[j] - h[j] * (cOld + 2.0 * c) / 3.0
|
||||||
|
val d = (cOld - c) / (3.0 * h[j])
|
||||||
|
val polynomial = Polynomial(a, b, c, d)
|
||||||
|
cOld = c
|
||||||
|
putLeft(points.x[j], polynomial)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,54 @@
|
|||||||
|
package scientifik.kmath.interpolation
|
||||||
|
|
||||||
|
import scientifik.kmath.structures.Buffer
|
||||||
|
import scientifik.kmath.structures.Structure2D
|
||||||
|
|
||||||
|
interface XYPointSet<X, Y> {
|
||||||
|
val size: Int
|
||||||
|
val x: Buffer<X>
|
||||||
|
val y: Buffer<Y>
|
||||||
|
}
|
||||||
|
|
||||||
|
interface XYZPointSet<X, Y, Z> : XYPointSet<X, Y> {
|
||||||
|
val z: Buffer<Z>
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<T, *>) {
|
||||||
|
for (i in 0 until points.size - 1) {
|
||||||
|
if (points.x[i + 1] <= points.x[i]) error("Input data is not sorted at index $i")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class NDStructureColumn<T>(val structure: Structure2D<T>, val column: Int) : Buffer<T> {
|
||||||
|
init {
|
||||||
|
require(column < structure.colNum) { "Column index is outside of structure column range" }
|
||||||
|
}
|
||||||
|
|
||||||
|
override val size: Int get() = structure.rowNum
|
||||||
|
|
||||||
|
override fun get(index: Int): T = structure[index, column]
|
||||||
|
|
||||||
|
override fun iterator(): Iterator<T> = sequence {
|
||||||
|
repeat(size) {
|
||||||
|
yield(get(it))
|
||||||
|
}
|
||||||
|
}.iterator()
|
||||||
|
}
|
||||||
|
|
||||||
|
class BufferXYPointSet<X, Y>(override val x: Buffer<X>, override val y: Buffer<Y>) : XYPointSet<X, Y> {
|
||||||
|
init {
|
||||||
|
require(x.size == y.size) { "Sizes of x and y buffers should be the same" }
|
||||||
|
}
|
||||||
|
|
||||||
|
override val size: Int
|
||||||
|
get() = x.size
|
||||||
|
}
|
||||||
|
|
||||||
|
fun <T> Structure2D<T>.asXYPointSet(): XYPointSet<T, T> {
|
||||||
|
require(shape[1] == 2) { "Structure second dimension should be of size 2" }
|
||||||
|
return object : XYPointSet<T, T> {
|
||||||
|
override val size: Int get() = this@asXYPointSet.shape[0]
|
||||||
|
override val x: Buffer<T> get() = NDStructureColumn(this@asXYPointSet, 0)
|
||||||
|
override val y: Buffer<T> get() = NDStructureColumn(this@asXYPointSet, 1)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user