diff --git a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt index 220bde8ff..862ee6a60 100644 --- a/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt +++ b/kmath-functions/src/commonMain/kotlin/space/kscience/kmath/functions/Polynomial.kt @@ -6,6 +6,10 @@ package space.kscience.kmath.functions import space.kscience.kmath.operations.* +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract +import kotlin.experimental.ExperimentalTypeInference +import kotlin.jvm.JvmName import kotlin.math.max import kotlin.math.min @@ -124,8 +128,9 @@ public open class PolynomialSpace>( if (other == 0) zero else Polynomial( coefficients - .subList(0, degree + 1) - .map { it * other } + .applyAndRemoveZeros { + map { it * other } + } ) /** @@ -183,8 +188,9 @@ public open class PolynomialSpace>( if (this == 0) zero else Polynomial( other.coefficients - .subList(0, other.degree + 1) - .map { it * this } + .applyAndRemoveZeros { + map { it * this@times } + } ) /** @@ -238,8 +244,9 @@ public open class PolynomialSpace>( if (this.isZero()) other else Polynomial( other.coefficients - .subList(0, other.degree + 1) - .map { it * this } + .applyAndRemoveZeros { + map { it * this@times } + } ) /** @@ -291,8 +298,9 @@ public open class PolynomialSpace>( if (other.isZero()) this else Polynomial( coefficients - .subList(0, degree + 1) - .map { it * other } + .applyAndRemoveZeros { + map { it * other } + } ) /** @@ -303,33 +311,35 @@ public open class PolynomialSpace>( /** * Returns sum of the polynomials. */ - public override operator fun Polynomial.plus(other: Polynomial): Polynomial = - Polynomial( - (0..max(degree, other.degree)) - .map { - when { - it > degree -> other.coefficients[it] - it > other.degree -> coefficients[it] - else -> coefficients[it] + other.coefficients[it] - } + public override operator fun Polynomial.plus(other: Polynomial): Polynomial { + val thisDegree = degree + val otherDegree = other.degree + return Polynomial( + Coefficients(max(thisDegree, otherDegree) + 1) { + when { + it > thisDegree -> other.coefficients[it] + it > otherDegree -> coefficients[it] + else -> coefficients[it] + other.coefficients[it] } - .ifEmpty { listOf(constantZero) } + } ) + } /** * Returns difference of the polynomials. */ - public override operator fun Polynomial.minus(other: Polynomial): Polynomial = - Polynomial( - (0..max(degree, other.degree)) - .map { - when { - it > degree -> -other.coefficients[it] - it > other.degree -> coefficients[it] - else -> coefficients[it] - other.coefficients[it] - } + public override operator fun Polynomial.minus(other: Polynomial): Polynomial { + val thisDegree = degree + val otherDegree = other.degree + return Polynomial( + Coefficients(max(thisDegree, otherDegree) + 1) { + when { + it > thisDegree -> -other.coefficients[it] + it > otherDegree -> coefficients[it] + else -> coefficients[it] - other.coefficients[it] } - .ifEmpty { listOf(constantZero) } + } ) + } /** * Returns product of the polynomials. */ @@ -341,13 +351,11 @@ public open class PolynomialSpace>( otherDegree == -1 -> zero else -> Polynomial( - (0..(thisDegree + otherDegree)) - .map { d -> - (max(0, d - otherDegree)..min(thisDegree, d)) - .map { coefficients[it] * other.coefficients[d - it] } - .reduce { acc, rational -> acc + rational } - } - .run { subList(0, indexOfLast { it.isNotZero() } + 1) } + Coefficients(thisDegree + otherDegree + 1) { d -> + (max(0, d - otherDegree)..min(thisDegree, d)) + .map { coefficients[it] * other.coefficients[d - it] } + .reduce { acc, rational -> acc + rational } + } ) } } @@ -431,6 +439,42 @@ public open class PolynomialSpace>( public inline operator fun Polynomial.invoke(argument: C): C = this.substitute(ring, argument) @Suppress("NOTHING_TO_INLINE") public inline operator fun Polynomial.invoke(argument: Polynomial): Polynomial = this.substitute(ring, argument) + + // TODO: Move to other internal utilities with context receiver + @JvmName("applyAndRemoveZerosInternal") + internal inline fun MutableList.applyAndRemoveZeros(block: MutableList.() -> Unit) : MutableList { + contract { + callsInPlace(block, InvocationKind.EXACTLY_ONCE) + } + block() + while (elementAt(lastIndex).isZero()) removeAt(lastIndex) + return this + } + internal inline fun List.applyAndRemoveZeros(block: MutableList.() -> Unit) : List = + toMutableList().applyAndRemoveZeros(block) + internal inline fun MutableCoefficients(size: Int, init: (index: Int) -> C): MutableList { + val list = ArrayList(size) + repeat(size) { index -> list.add(init(index)) } + with(list) { while (elementAt(lastIndex).isZero()) removeAt(lastIndex) } + return list + } + internal inline fun Coefficients(size: Int, init: (index: Int) -> C): List = MutableCoefficients(size, init) + @OptIn(ExperimentalTypeInference::class) + internal inline fun buildCoefficients(@BuilderInference builderAction: MutableList.() -> Unit): List { + contract { callsInPlace(builderAction, InvocationKind.EXACTLY_ONCE) } + return buildList { + builderAction() + while (elementAt(lastIndex).isZero()) removeAt(lastIndex) + } + } + @OptIn(ExperimentalTypeInference::class) + internal inline fun buildCoefficients(capacity: Int, @BuilderInference builderAction: MutableList.() -> Unit): List { + contract { callsInPlace(builderAction, InvocationKind.EXACTLY_ONCE) } + return buildList(capacity) { + builderAction() + while (elementAt(lastIndex).isZero()) removeAt(lastIndex) + } + } } /**