Interpolation API

This commit is contained in:
Alexander Nozik 2020-02-12 21:57:21 +03:00
parent d56b4148be
commit 73f40105c4
10 changed files with 430 additions and 17 deletions

View File

@ -8,5 +8,6 @@ dependencies {
api(project(":kmath-core"))
api(project(":kmath-coroutines"))
api(project(":kmath-prob"))
api(project(":kmath-functions"))
api("org.apache.commons:commons-math3:3.6.1")
}

View File

@ -73,6 +73,8 @@ fun <T> Buffer<T>.asSequence(): Sequence<T> = Sequence(::iterator)
fun <T> Buffer<T>.asIterable(): Iterable<T> = asSequence().asIterable()
val Buffer<*>.indices: IntRange get() = IntRange(0, size - 1)
interface MutableBuffer<T> : Buffer<T> {
operator fun set(index: Int, value: T)

View File

@ -6,14 +6,16 @@ interface Piecewise<T, R> {
fun findPiece(arg: T): R?
}
interface PiecewisePolynomial<T : Any> : Piecewise<T, Polynomial<T>>
interface PiecewisePolynomial<T : Any> :
Piecewise<T, Polynomial<T>>
/**
* Ordered list of pieces in piecewise function
*/
class OrderedPiecewisePolynomial<T : Comparable<T>>(left: T) : PiecewisePolynomial<T> {
class OrderedPiecewisePolynomial<T : Comparable<T>>(delimeter: T) :
PiecewisePolynomial<T> {
private val delimiters: ArrayList<T> = arrayListOf(left)
private val delimiters: ArrayList<T> = arrayListOf(delimeter)
private val pieces: ArrayList<Polynomial<T>> = ArrayList()
/**

View File

@ -32,7 +32,8 @@ fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring.run {
/**
* Represent a polynomial as a context-dependent function
*/
fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, out C, T> = object : MathFunction<T, C, T> {
fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, out C, T> = object :
MathFunction<T, C, T> {
override fun C.invoke(arg: T): T = value(this, arg)
}
@ -61,7 +62,8 @@ class PolynomialSpace<T : Any, C : Ring<T>>(val ring: C) : Space<Polynomial<T>>
}
}
override val zero: Polynomial<T> = Polynomial(emptyList())
override val zero: Polynomial<T> =
Polynomial(emptyList())
operator fun Polynomial<T>.invoke(arg: T): T = value(ring, arg)
}

View File

@ -5,7 +5,7 @@ import scientifik.kmath.functions.value
import scientifik.kmath.operations.Ring
interface Interpolator<X, Y> {
fun interpolate(points: Collection<Pair<X, Y>>): (X) -> Y
fun interpolate(points: XYPointSet<X, Y>): (X) -> Y
}
interface PolynomialInterpolator<T : Comparable<T>> : Interpolator<T, T> {
@ -13,9 +13,9 @@ interface PolynomialInterpolator<T : Comparable<T>> : Interpolator<T, T> {
fun getDefaultValue(): T = error("Out of bounds")
fun interpolatePolynomials(points: Collection<Pair<T, T>>): PiecewisePolynomial<T>
fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T>
override fun interpolate(points: Collection<Pair<T, T>>): (T) -> T = { x ->
override fun interpolate(points: XYPointSet<T, T>): (T) -> T = { x ->
interpolatePolynomials(points).value(algebra, x) ?: getDefaultValue()
}
}

View File

@ -10,18 +10,16 @@ import scientifik.kmath.operations.Field
*/
class LinearInterpolator<T : Comparable<T>>(override val algebra: Field<T>) : PolynomialInterpolator<T> {
override fun interpolatePolynomials(points: Collection<Pair<T, T>>): PiecewisePolynomial<T> = algebra.run {
require(points.isNotEmpty()) { "Point array should not be empty" }
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra.run {
require(points.size > 0) { "Point array should not be empty" }
insureSorted(points)
//sorting points
val sorted = points.sortedBy { it.first }
return@run OrderedPiecewisePolynomial(points.first().first).apply {
OrderedPiecewisePolynomial(points.x[0]).apply {
for (i in 0 until points.size - 1) {
val slope = (sorted[i + 1].second - sorted[i].second) / (sorted[i + 1].first - sorted[i].first)
val const = sorted[i].second - slope * sorted[i].first
val slope = (points.y[i + 1] - points.y[i]) / (points.x[i + 1] - points.x[i])
val const = points.x[i] - slope * points.x[i]
val polynomial = Polynomial(const, slope)
putRight(sorted[i + 1].first, polynomial)
putRight(points.x[i + 1], polynomial)
}
}
}

View File

@ -0,0 +1,296 @@
//package scientifik.kmath.interpolation
//
//import scientifik.kmath.functions.PiecewisePolynomial
//import scientifik.kmath.operations.Ring
//import scientifik.kmath.structures.Buffer
//import kotlin.math.abs
//import kotlin.math.sqrt
//
//
///**
// * Original code: https://github.com/apache/commons-math/blob/eb57d6d457002a0bb5336d789a3381a24599affe/src/main/java/org/apache/commons/math4/analysis/interpolation/LoessInterpolator.java
// */
//class LoessInterpolator<T : Comparable<T>>(override val algebra: Ring<T>) : PolynomialInterpolator<T> {
// /**
// * The bandwidth parameter: when computing the loess fit at
// * a particular point, this fraction of source points closest
// * to the current point is taken into account for computing
// * a least-squares regression.
// *
// *
// * A sensible value is usually 0.25 to 0.5.
// */
// private var bandwidth = 0.0
//
// /**
// * The number of robustness iterations parameter: this many
// * robustness iterations are done.
// *
// *
// * A sensible value is usually 0 (just the initial fit without any
// * robustness iterations) to 4.
// */
// private var robustnessIters = 0
//
// /**
// * If the median residual at a certain robustness iteration
// * is less than this amount, no more iterations are done.
// */
// private var accuracy = 0.0
//
// /**
// * Constructs a new [LoessInterpolator]
// * with a bandwidth of [.DEFAULT_BANDWIDTH],
// * [.DEFAULT_ROBUSTNESS_ITERS] robustness iterations
// * and an accuracy of {#link #DEFAULT_ACCURACY}.
// * See [.LoessInterpolator] for an explanation of
// * the parameters.
// */
// fun LoessInterpolator() {
// bandwidth = DEFAULT_BANDWIDTH
// robustnessIters = DEFAULT_ROBUSTNESS_ITERS
// accuracy = DEFAULT_ACCURACY
// }
//
// fun LoessInterpolator(bandwidth: Double, robustnessIters: Int) {
// this(bandwidth, robustnessIters, DEFAULT_ACCURACY)
// }
//
// fun LoessInterpolator(bandwidth: Double, robustnessIters: Int, accuracy: Double) {
// if (bandwidth < 0 ||
// bandwidth > 1
// ) {
// throw OutOfRangeException(LocalizedFormats.BANDWIDTH, bandwidth, 0, 1)
// }
// this.bandwidth = bandwidth
// if (robustnessIters < 0) {
// throw NotPositiveException(LocalizedFormats.ROBUSTNESS_ITERATIONS, robustnessIters)
// }
// this.robustnessIters = robustnessIters
// this.accuracy = accuracy
// }
//
// fun interpolate(
// xval: DoubleArray,
// yval: DoubleArray
// ): PolynomialSplineFunction {
// return SplineInterpolator().interpolate(xval, smooth(xval, yval))
// }
//
// fun XYZPointSet<Double, Double, Double>.smooth(): XYPointSet<Double, Double> {
// checkAllFiniteReal(x)
// checkAllFiniteReal(y)
// checkAllFiniteReal(z)
// MathArrays.checkOrder(xval)
// if (size == 1) {
// return doubleArrayOf(y[0])
// }
// if (size == 2) {
// return doubleArrayOf(y[0], y[1])
// }
// val bandwidthInPoints = (bandwidth * size).toInt()
// if (bandwidthInPoints < 2) {
// throw NumberIsTooSmallException(
// LocalizedFormats.BANDWIDTH,
// bandwidthInPoints, 2, true
// )
// }
// val res = DoubleArray(size)
// val residuals = DoubleArray(size)
// val sortedResiduals = DoubleArray(size)
// val robustnessWeights = DoubleArray(size)
// // Do an initial fit and 'robustnessIters' robustness iterations.
// // This is equivalent to doing 'robustnessIters+1' robustness iterations
// // starting with all robustness weights set to 1.
// Arrays.fill(robustnessWeights, 1.0)
// for (iter in 0..robustnessIters) {
// val bandwidthInterval = intArrayOf(0, bandwidthInPoints - 1)
// // At each x, compute a local weighted linear regression
// for (i in 0 until size) {
//// val x = x[i]
// // Find out the interval of source points on which
// // a regression is to be made.
// if (i > 0) {
// updateBandwidthInterval(x, z, i, bandwidthInterval)
// }
// val ileft = bandwidthInterval[0]
// val iright = bandwidthInterval[1]
// // Compute the point of the bandwidth interval that is
// // farthest from x
// val edge: Int
// edge = if (x[i] - x[ileft] > x[iright] - x[i]) {
// ileft
// } else {
// iright
// }
// // Compute a least-squares linear fit weighted by
// // the product of robustness weights and the tricube
// // weight function.
// // See http://en.wikipedia.org/wiki/Linear_regression
// // (section "Univariate linear case")
// // and http://en.wikipedia.org/wiki/Weighted_least_squares
// // (section "Weighted least squares")
// var sumWeights = 0.0
// var sumX = 0.0
// var sumXSquared = 0.0
// var sumY = 0.0
// var sumXY = 0.0
// val denom: Double = abs(1.0 / (x[edge] - x[i]))
// for (k in ileft..iright) {
// val xk = x[k]
// val yk = y[k]
// val dist = if (k < i) x - xk else xk - x[i]
// val w = tricube(dist * denom) * robustnessWeights[k] * z[k]
// val xkw = xk * w
// sumWeights += w
// sumX += xkw
// sumXSquared += xk * xkw
// sumY += yk * w
// sumXY += yk * xkw
// }
// val meanX = sumX / sumWeights
// val meanY = sumY / sumWeights
// val meanXY = sumXY / sumWeights
// val meanXSquared = sumXSquared / sumWeights
// val beta: Double
// beta = if (sqrt(abs(meanXSquared - meanX * meanX)) < accuracy) {
// 0.0
// } else {
// (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX)
// }
// val alpha = meanY - beta * meanX
// res[i] = beta * x[i] + alpha
// residuals[i] = abs(y[i] - res[i])
// }
// // No need to recompute the robustness weights at the last
// // iteration, they won't be needed anymore
// if (iter == robustnessIters) {
// break
// }
// // Recompute the robustness weights.
// // Find the median residual.
// // An arraycopy and a sort are completely tractable here,
// // because the preceding loop is a lot more expensive
// java.lang.System.arraycopy(residuals, 0, sortedResiduals, 0, size)
// Arrays.sort(sortedResiduals)
// val medianResidual = sortedResiduals[size / 2]
// if (abs(medianResidual) < accuracy) {
// break
// }
// for (i in 0 until size) {
// val arg = residuals[i] / (6 * medianResidual)
// if (arg >= 1) {
// robustnessWeights[i] = 0.0
// } else {
// val w = 1 - arg * arg
// robustnessWeights[i] = w * w
// }
// }
// }
// return res
// }
//
// fun smooth(xval: DoubleArray, yval: DoubleArray): DoubleArray {
// if (xval.size != yval.size) {
// throw DimensionMismatchException(xval.size, yval.size)
// }
// val unitWeights = DoubleArray(xval.size)
// Arrays.fill(unitWeights, 1.0)
// return smooth(xval, yval, unitWeights)
// }
//
// /**
// * Given an index interval into xval that embraces a certain number of
// * points closest to `xval[i-1]`, update the interval so that it
// * embraces the same number of points closest to `xval[i]`,
// * ignoring zero weights.
// *
// * @param xval Arguments array.
// * @param weights Weights array.
// * @param i Index around which the new interval should be computed.
// * @param bandwidthInterval a two-element array {left, right} such that:
// * `(left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])`
// * and
// * `(right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])`.
// * The array will be updated.
// */
// private fun updateBandwidthInterval(
// xval: Buffer<Double>, weights: Buffer<Double>,
// i: Int,
// bandwidthInterval: IntArray
// ) {
// val left = bandwidthInterval[0]
// val right = bandwidthInterval[1]
// // The right edge should be adjusted if the next point to the right
// // is closer to xval[i] than the leftmost point of the current interval
// val nextRight = nextNonzero(weights, right)
// if (nextRight < xval.size && xval[nextRight] - xval[i] < xval[i] - xval[left]) {
// val nextLeft = nextNonzero(weights, bandwidthInterval[0])
// bandwidthInterval[0] = nextLeft
// bandwidthInterval[1] = nextRight
// }
// }
//
// /**
// * Return the smallest index `j` such that
// * `j > i && (j == weights.length || weights[j] != 0)`.
// *
// * @param weights Weights array.
// * @param i Index from which to start search.
// * @return the smallest compliant index.
// */
// private fun nextNonzero(weights: Buffer<Double>, i: Int): Int {
// var j = i + 1
// while (j < weights.size && weights[j] == 0.0) {
// ++j
// }
// return j
// }
//
// /**
// * Compute the
// * [tricube](http://en.wikipedia.org/wiki/Local_regression#Weight_function)
// * weight function
// *
// * @param x Argument.
// * @return `(1 - |x|<sup>3</sup>)<sup>3</sup>` for |x| &lt; 1, 0 otherwise.
// */
// private fun tricube(x: Double): Double {
// val absX: Double = FastMath.abs(x)
// if (absX >= 1.0) {
// return 0.0
// }
// val tmp = 1 - absX * absX * absX
// return tmp * tmp * tmp
// }
//
// /**
// * Check that all elements of an array are finite real numbers.
// *
// * @param values Values array.
// * @throws org.apache.commons.math4.exception.NotFiniteNumberException
// * if one of the values is not a finite real number.
// */
// private fun checkAllFiniteReal(values: DoubleArray) {
// for (i in values.indices) {
// MathUtils.checkFinite(values[i])
// }
// }
//
// override fun interpolatePolynomials(points: Collection<Pair<T, T>>): PiecewisePolynomial<T> {
// TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
// }
//
// companion object {
// /** Default value of the bandwidth parameter. */
// const val DEFAULT_BANDWIDTH = 0.3
//
// /** Default value of the number of robustness iterations. */
// const val DEFAULT_ROBUSTNESS_ITERS = 2
//
// /**
// * Default value for accuracy.
// */
// const val DEFAULT_ACCURACY = 1e-12
// }
//}

View File

@ -0,0 +1,58 @@
package scientifik.kmath.interpolation
import scientifik.kmath.functions.OrderedPiecewisePolynomial
import scientifik.kmath.functions.PiecewisePolynomial
import scientifik.kmath.functions.Polynomial
import scientifik.kmath.operations.Field
import scientifik.kmath.structures.MutableBufferFactory
/**
* Generic spline interpolator. Not recommended for performance critical places, use platform-specific and type specific ones.
* Based on https://github.com/apache/commons-math/blob/eb57d6d457002a0bb5336d789a3381a24599affe/src/main/java/org/apache/commons/math4/analysis/interpolation/SplineInterpolator.java
*/
class SplineInterpolator<T : Comparable<T>>(
override val algebra: Field<T>,
val bufferFactory: MutableBufferFactory<T>
) : PolynomialInterpolator<T> {
//TODO possibly optimize zeroed buffers
override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra.run {
if (points.size < 3) {
error("Can't use spline interpolator with less than 3 points")
}
insureSorted(points)
// Number of intervals. The number of data points is n + 1.
val n = points.size - 1
// Differences between knot points
val h = bufferFactory(points.size) { i -> points.x[i + 1] - points.x[i] }
val mu = bufferFactory(points.size - 1) { zero }
val z = bufferFactory(points.size) { zero }
for (i in 1 until n) {
val g = 2.0 * (points.x[i + 1] - points.x[i - 1]) - h[i - 1] * mu[i - 1]
mu[i] = h[i] / g
z[i] =
(3.0 * (points.y[i + 1] * h[i - 1] - points.x[i] * (points.x[i + 1] - points.x[i - 1]) + points.y[i - 1] * h[i]) / (h[i - 1] * h[i])
- h[i - 1] * z[i - 1]) / g
}
// cubic spline coefficients -- b is linear, c quadratic, d is cubic (original y's are constants)
OrderedPiecewisePolynomial<T>(points.x[points.size - 1]).apply {
var cOld = zero
for (j in n - 1 downTo 0) {
val c = z[j] - mu[j] * cOld
val a = points.y[j]
val b = (points.y[j + 1] - points.y[j]) / h[j] - h[j] * (cOld + 2.0 * c) / 3.0
val d = (cOld - c) / (3.0 * h[j])
val polynomial = Polynomial(a, b, c, d)
cOld = c
putLeft(points.x[j], polynomial)
}
}
}
}

View File

@ -0,0 +1,54 @@
package scientifik.kmath.interpolation
import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.Structure2D
interface XYPointSet<X, Y> {
val size: Int
val x: Buffer<X>
val y: Buffer<Y>
}
interface XYZPointSet<X, Y, Z> : XYPointSet<X, Y> {
val z: Buffer<Z>
}
internal fun <T : Comparable<T>> insureSorted(points: XYPointSet<T, *>) {
for (i in 0 until points.size - 1) {
if (points.x[i + 1] <= points.x[i]) error("Input data is not sorted at index $i")
}
}
class NDStructureColumn<T>(val structure: Structure2D<T>, val column: Int) : Buffer<T> {
init {
require(column < structure.colNum) { "Column index is outside of structure column range" }
}
override val size: Int get() = structure.rowNum
override fun get(index: Int): T = structure[index, column]
override fun iterator(): Iterator<T> = sequence {
repeat(size) {
yield(get(it))
}
}.iterator()
}
class BufferXYPointSet<X, Y>(override val x: Buffer<X>, override val y: Buffer<Y>) : XYPointSet<X, Y> {
init {
require(x.size == y.size) { "Sizes of x and y buffers should be the same" }
}
override val size: Int
get() = x.size
}
fun <T> Structure2D<T>.asXYPointSet(): XYPointSet<T, T> {
require(shape[1] == 2) { "Structure second dimension should be of size 2" }
return object : XYPointSet<T, T> {
override val size: Int get() = this@asXYPointSet.shape[0]
override val x: Buffer<T> get() = NDStructureColumn(this@asXYPointSet, 0)
override val y: Buffer<T> get() = NDStructureColumn(this@asXYPointSet, 1)
}
}