Finished with tests for Polynomial.

This commit is contained in:
Gleb Minaev 2022-03-20 23:22:39 +03:00
parent fbc21101bb
commit 25ec59b985
4 changed files with 407 additions and 37 deletions

View File

@ -386,7 +386,7 @@ public open class PolynomialSpace<C, A : Ring<C>>(
/** /**
* Instance of unit constant (unit of the underlying ring). * Instance of unit constant (unit of the underlying ring).
*/ */
override val one: Polynomial<C> = Polynomial(listOf(constantZero)) override val one: Polynomial<C> = Polynomial(listOf(constantOne))
/** /**
* Checks equality of the polynomials. * Checks equality of the polynomials.
@ -394,11 +394,8 @@ public open class PolynomialSpace<C, A : Ring<C>>(
public override infix fun Polynomial<C>.equalsTo(other: Polynomial<C>): Boolean = public override infix fun Polynomial<C>.equalsTo(other: Polynomial<C>): Boolean =
when { when {
this === other -> true this === other -> true
else -> { this.degree == other.degree -> (0..degree).all { coefficients[it] == other.coefficients[it] }
if (this.degree == other.degree) else -> false
(0..degree).all { coefficients[it] == other.coefficients[it] }
else false
}
} }
/** /**
@ -415,8 +412,8 @@ public open class PolynomialSpace<C, A : Ring<C>>(
with(coefficients) { with(coefficients) {
when { when {
isEmpty() -> constantZero isEmpty() -> constantZero
withIndex().any { (index, c) -> index == 0 || c.isZero() } -> null withIndex().all { (index, c) -> index == 0 || c.isZero() } -> first()
else -> first() else -> null
} }
} }
@ -489,8 +486,6 @@ public open class PolynomialSpace<C, A : Ring<C>>(
public class ScalablePolynomialSpace<C, A>( public class ScalablePolynomialSpace<C, A>(
ring: A, ring: A,
) : PolynomialSpace<C, A>(ring), ScaleOperations<Polynomial<C>> where A : Ring<C>, A : ScaleOperations<C> { ) : PolynomialSpace<C, A>(ring), ScaleOperations<Polynomial<C>> where A : Ring<C>, A : ScaleOperations<C> {
override fun scale(a: Polynomial<C>, value: Double): Polynomial<C> = override fun scale(a: Polynomial<C>, value: Double): Polynomial<C> =
ring { Polynomial(List(a.coefficients.size) { index -> a.coefficients[index] * value }) } ring { Polynomial(a.coefficients.map { scale(it, value) }) }
} }

View File

@ -127,9 +127,10 @@ public fun <C> Polynomial<C>.substitute(ring: Ring<C>, arg: Polynomial<C>) : Pol
if (argDegree == -1) return coefficients[0].asPolynomial() if (argDegree == -1) return coefficients[0].asPolynomial()
val constantZero = zero val constantZero = zero
val resultCoefs: MutableList<C> = MutableList(thisDegree * argDegree + 1) { constantZero } val resultCoefs: MutableList<C> = MutableList(thisDegree * argDegree + 1) { constantZero }
resultCoefs[0] = coefficients[thisDegree]
val resultCoefsUpdate: MutableList<C> = MutableList(thisDegree * argDegree + 1) { constantZero } val resultCoefsUpdate: MutableList<C> = MutableList(thisDegree * argDegree + 1) { constantZero }
var resultDegree = 0 var resultDegree = 0
for (deg in thisDegree downTo 0) { for (deg in thisDegree - 1 downTo 0) {
resultCoefsUpdate[0] = coefficients[deg] resultCoefsUpdate[0] = coefficients[deg]
multiplyAddingToUpdater( multiplyAddingToUpdater(
ring = ring, ring = ring,
@ -142,6 +143,8 @@ public fun <C> Polynomial<C>.substitute(ring: Ring<C>, arg: Polynomial<C>) : Pol
) )
resultDegree += argDegree resultDegree += argDegree
} }
with(resultCoefs) { while (isNotEmpty() && elementAt(lastIndex) == constantZero) removeAt(lastIndex) }
return Polynomial<C>(resultCoefs) return Polynomial<C>(resultCoefs)
} }
@ -162,7 +165,12 @@ public fun <C, A : Ring<C>> Polynomial<C>.asPolynomialFunctionOver(ring: A): (Po
public fun <C, A> Polynomial<C>.derivative( public fun <C, A> Polynomial<C>.derivative(
algebra: A, algebra: A,
): Polynomial<C> where A : Ring<C>, A : NumericAlgebra<C> = algebra { ): Polynomial<C> where A : Ring<C>, A : NumericAlgebra<C> = algebra {
Polynomial(coefficients.drop(1).mapIndexed { index, c -> number(index + 1) * c }) Polynomial(
buildList(max(0, coefficients.size - 1)) {
for (deg in 1 .. coefficients.lastIndex) add(number(deg) * coefficients[deg])
while (isNotEmpty() && elementAt(lastIndex) == algebra.zero) removeAt(lastIndex)
}
)
} }
/** /**
@ -171,9 +179,16 @@ public fun <C, A> Polynomial<C>.derivative(
@UnstableKMathAPI @UnstableKMathAPI
public fun <C, A> Polynomial<C>.nthDerivative( public fun <C, A> Polynomial<C>.nthDerivative(
algebra: A, algebra: A,
order: UInt, order: Int,
): Polynomial<C> where A : Ring<C>, A : NumericAlgebra<C> = algebra { ): Polynomial<C> where A : Ring<C>, A : NumericAlgebra<C> = algebra {
Polynomial(coefficients.drop(order.toInt()).mapIndexed { index, c -> (index + 1..index + order.toInt()).fold(c) { acc, i -> acc * number(i) } }) require(order >= 0) { "Order of derivative must be non-negative" }
Polynomial(
buildList(max(0, coefficients.size - order)) {
for (deg in order.. coefficients.lastIndex)
add((deg - order + 1 .. deg).fold(coefficients[deg]) { acc, d -> acc * number(d) })
while (isNotEmpty() && elementAt(lastIndex) == algebra.zero) removeAt(lastIndex)
}
)
} }
/** /**
@ -183,11 +198,13 @@ public fun <C, A> Polynomial<C>.nthDerivative(
public fun <C, A> Polynomial<C>.antiderivative( public fun <C, A> Polynomial<C>.antiderivative(
algebra: A, algebra: A,
): Polynomial<C> where A : Field<C>, A : NumericAlgebra<C> = algebra { ): Polynomial<C> where A : Field<C>, A : NumericAlgebra<C> = algebra {
val integratedCoefficients = buildList(coefficients.size + 1) { Polynomial(
add(zero) buildList(coefficients.size + 1) {
coefficients.mapIndexedTo(this) { index, t -> t / number(index + 1) } add(zero)
} coefficients.mapIndexedTo(this) { index, t -> t / number(index + 1) }
Polynomial(integratedCoefficients) while (isNotEmpty() && elementAt(lastIndex) == algebra.zero) removeAt(lastIndex)
}
)
} }
/** /**
@ -196,13 +213,16 @@ public fun <C, A> Polynomial<C>.antiderivative(
@UnstableKMathAPI @UnstableKMathAPI
public fun <C, A> Polynomial<C>.nthAntiderivative( public fun <C, A> Polynomial<C>.nthAntiderivative(
algebra: A, algebra: A,
order: UInt, order: Int,
): Polynomial<C> where A : Field<C>, A : NumericAlgebra<C> = algebra { ): Polynomial<C> where A : Field<C>, A : NumericAlgebra<C> = algebra {
val newCoefficients = buildList(coefficients.size + order.toInt()) { require(order >= 0) { "Order of antiderivative must be non-negative" }
repeat(order.toInt()) { add(zero) } Polynomial(
coefficients.mapIndexedTo(this) { index, c -> (1..order.toInt()).fold(c) { acc, i -> acc / number(index + i) } } buildList(coefficients.size + order) {
} repeat(order) { add(zero) }
return Polynomial(newCoefficients) coefficients.mapIndexedTo(this) { index, c -> (1..order).fold(c) { acc, i -> acc / number(index + i) } }
while (isNotEmpty() && elementAt(lastIndex) == algebra.zero) removeAt(lastIndex)
}
)
} }
/** /**

View File

@ -534,4 +534,158 @@ class PolynomialTest {
assertFalse("test 12") { Polynomial(Rational(-1), Rational(5, -5), Rational(0)).isMinusOne() } assertFalse("test 12") { Polynomial(Rational(-1), Rational(5, -5), Rational(0)).isMinusOne() }
} }
} }
@Test
fun test_Polynomial_equalsTo() {
RationalField.polynomial {
assertTrue("test 1") {
Polynomial(Rational(5, 9), Rational(-8, 9), Rational(-8, 7)) equalsTo
Polynomial(Rational(5, 9), Rational(-8, 9), Rational(-8, 7))
}
assertTrue("test 2") {
Polynomial(Rational(5, 9), Rational(0), Rational(-8, 7)) equalsTo
Polynomial(Rational(5, 9), Rational(0), Rational(-8, 7))
}
assertTrue("test 3") {
Polynomial(Rational(0), Rational(0), Rational(-8, 7), Rational(0), Rational(0)) equalsTo
Polynomial(Rational(0), Rational(0), Rational(-8, 7))
}
assertFalse("test 4") {
Polynomial(Rational(5, 9), Rational(5, 7), Rational(-8, 7)) equalsTo
Polynomial(Rational(5, 9), Rational(-8, 9), Rational(-8, 7))
}
assertFalse("test 5") {
Polynomial(Rational(8, 3), Rational(0), Rational(-8, 7)) equalsTo
Polynomial(Rational(5, 9), Rational(0), Rational(-8, 7))
}
assertFalse("test 6") {
Polynomial(Rational(0), Rational(0), Rational(-8, 7), Rational(0), Rational(0)) equalsTo
Polynomial(Rational(0), Rational(0), Rational(8, 7))
}
}
}
@Test
fun test_Polynomial_degree() {
RationalField.polynomial {
assertEquals(
2,
Polynomial(Rational(5, 9), Rational(-8, 9), Rational(-8, 7)).degree,
"test 1"
)
assertEquals(
-1,
Polynomial<Rational>().degree,
"test 2"
)
assertEquals(
-1,
Polynomial(Rational(0)).degree,
"test 3"
)
assertEquals(
-1,
Polynomial(Rational(0), Rational(0)).degree,
"test 4"
)
assertEquals(
-1,
Polynomial(Rational(0), Rational(0), Rational(0)).degree,
"test 5"
)
assertEquals(
0,
Polynomial(Rational(5, 9)).degree,
"test 6"
)
assertEquals(
0,
Polynomial(Rational(5, 9), Rational(0)).degree,
"test 7"
)
assertEquals(
0,
Polynomial(Rational(5, 9), Rational(0), Rational(0)).degree,
"test 8"
)
assertEquals(
2,
Polynomial(Rational(0), Rational(0), Rational(-8, 7)).degree,
"test 9"
)
assertEquals(
2,
Polynomial(Rational(5, 9), Rational(-8, 9), Rational(-8, 7), Rational(0), Rational(0)).degree,
"test 10"
)
assertEquals(
2,
Polynomial(Rational(0), Rational(0), Rational(-8, 7), Rational(0), Rational(0)).degree,
"test 11"
)
}
}
@Test
fun test_Polynomial_asConstantOrNull() {
RationalField.polynomial {
assertEquals(
Rational(0),
Polynomial<Rational>().asConstantOrNull(),
"test 1"
)
assertEquals(
Rational(0),
Polynomial(Rational(0)).asConstantOrNull(),
"test 2"
)
assertEquals(
Rational(0),
Polynomial(Rational(0), Rational(0)).asConstantOrNull(),
"test 3"
)
assertEquals(
Rational(0),
Polynomial(Rational(0), Rational(0), Rational(0)).asConstantOrNull(),
"test 4"
)
assertEquals(
Rational(-7, 9),
Polynomial(Rational(-7, 9)).asConstantOrNull(),
"test 5"
)
assertEquals(
Rational(-7, 9),
Polynomial(Rational(-7, 9), Rational(0)).asConstantOrNull(),
"test 6"
)
assertEquals(
Rational(-7, 9),
Polynomial(Rational(-7, 9), Rational(0), Rational(0)).asConstantOrNull(),
"test 7"
)
assertEquals(
null,
Polynomial(Rational(0), Rational(-7, 9)).asConstantOrNull(),
"test 8"
)
assertEquals(
null,
Polynomial(Rational(0), Rational(-7, 9), Rational(0)).asConstantOrNull(),
"test 9"
)
assertEquals(
null,
Polynomial(Rational(0), Rational(0), Rational(-7, 9)).asConstantOrNull(),
"test 10"
)
assertEquals(
null,
Polynomial(Rational(4, 15), Rational(0), Rational(-7, 9)).asConstantOrNull(),
"test 11"
)
assertEquals(
null,
Polynomial(Rational(4, 15), Rational(0), Rational(-7, 9), Rational(0)).asConstantOrNull(),
"test 12"
)
}
}
} }

View File

@ -5,23 +5,224 @@
package space.kscience.kmath.functions package space.kscience.kmath.functions
import space.kscience.kmath.operations.algebra import space.kscience.kmath.test.misc.Rational
import space.kscience.kmath.test.misc.RationalField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
class PolynomialUtilTest { class PolynomialUtilTest {
@Test @Test
fun simple_polynomial_test() { fun test_substitute_Double() {
val polynomial : Polynomial<Double>
Double.algebra.scalablePolynomial {
val x = Polynomial(listOf(0.0, 1.0))
polynomial = x * x - 2 * x + 1
}
assertEquals(0.0, polynomial.substitute(1.0), 0.001)
}
@Test
fun testIntegration() {
val polynomial = Polynomial(1.0, -2.0, 1.0) val polynomial = Polynomial(1.0, -2.0, 1.0)
assertEquals(0.0, polynomial.substitute(1.0), 0.001) assertEquals(0.0, polynomial.substitute(1.0), 0.001)
} }
@Test
fun test_substitute_Constant() {
assertEquals(
Rational(0),
Polynomial(Rational(1), Rational(-2), Rational(1)).substitute(RationalField, Rational(1)),
"test 1"
)
assertEquals(
Rational(25057, 21000),
Polynomial(Rational(5,8), Rational(8, 3), Rational(4, 7), Rational(3, 2))
.substitute(RationalField, Rational(1, 5)),
"test 2"
)
assertEquals(
Rational(2983, 5250),
Polynomial(Rational(0), Rational(8, 3), Rational(4, 7), Rational(3, 2))
.substitute(RationalField, Rational(1, 5)),
"test 3"
)
assertEquals(
Rational(4961, 4200),
Polynomial(Rational(5,8), Rational(8, 3), Rational(4, 7), Rational(0))
.substitute(RationalField, Rational(1, 5)),
"test 4"
)
assertEquals(
Rational(3511, 3000),
Polynomial(Rational(5,8), Rational(8, 3), Rational(0), Rational(3, 2))
.substitute(RationalField, Rational(1, 5)),
"test 5"
)
}
@Test
fun test_substitute_Polynomial() {
assertEquals(
Polynomial(),
Polynomial(Rational(1), Rational(-2), Rational(1)).substitute(RationalField, Polynomial(Rational(1))),
"test 1"
)
assertEquals(
Polynomial(Rational(709, 378), Rational(155, 252), Rational(19, 525), Rational(2, 875)),
Polynomial(Rational(1, 7), Rational(9, 4), Rational(1, 3), Rational(2, 7))
.substitute(RationalField, Polynomial(Rational(6, 9), Rational(1, 5))),
"test 2"
)
assertEquals(
Polynomial(Rational(655, 378), Rational(155, 252), Rational(19, 525), Rational(2, 875)),
Polynomial(Rational(0), Rational(9, 4), Rational(1, 3), Rational(2, 7))
.substitute(RationalField, Polynomial(Rational(6, 9), Rational(1, 5))),
"test 3"
)
assertEquals(
Polynomial(Rational(677, 378), Rational(97, 180), Rational(1, 75)),
Polynomial(Rational(1, 7), Rational(9, 4), Rational(1, 3), Rational(0))
.substitute(RationalField, Polynomial(Rational(6, 9), Rational(1, 5))),
"test 4"
)
assertEquals(
Polynomial(Rational(653, 378), Rational(221, 420), Rational(4, 175), Rational(2, 875)),
Polynomial(Rational(1, 7), Rational(9, 4), Rational(0), Rational(2, 7))
.substitute(RationalField, Polynomial(Rational(6, 9), Rational(1, 5))),
"test 5"
)
assertEquals(
Polynomial(Rational(89, 54)),
Polynomial(Rational(0), Rational(9, 4), Rational(1, 3), Rational(0))
.substitute(RationalField, Polynomial(Rational(6, 9), Rational(0))),
"test 6"
)
}
@Test
fun test_derivative() {
assertEquals(
Polynomial(Rational(-2), Rational(2)),
Polynomial(Rational(1), Rational(-2), Rational(1)).derivative(RationalField),
"test 1"
)
assertEquals(
Polynomial(Rational(-8, 3), Rational(8, 9), Rational(15, 7), Rational(-20, 9)),
Polynomial(Rational(1, 5), Rational(-8, 3), Rational(4, 9), Rational(5, 7), Rational(-5, 9)).derivative(RationalField),
"test 2"
)
assertEquals(
Polynomial(Rational(0), Rational(8, 9), Rational(15, 7), Rational(-20, 9)),
Polynomial(Rational(0), Rational(0), Rational(4, 9), Rational(5, 7), Rational(-5, 9)).derivative(RationalField),
"test 3"
)
assertEquals(
Polynomial(Rational(-8, 3), Rational(8, 9), Rational(15, 7)),
Polynomial(Rational(1, 5), Rational(-8, 3), Rational(4, 9), Rational(5, 7), Rational(0)).derivative(RationalField),
"test 4"
)
}
@Test
fun test_nthDerivative() {
assertEquals(
Polynomial(Rational(-2), Rational(2)),
Polynomial(Rational(1), Rational(-2), Rational(1)).nthDerivative(RationalField, 1),
"test 1"
)
assertFailsWith<IllegalArgumentException>("test2") {
Polynomial(Rational(1), Rational(-2), Rational(1)).nthDerivative(RationalField, -1)
}
assertEquals(
Polynomial(Rational(1), Rational(-2), Rational(1)),
Polynomial(Rational(1), Rational(-2), Rational(1)).nthDerivative(RationalField, 0),
"test 3"
)
assertEquals(
Polynomial(Rational(2)),
Polynomial(Rational(1), Rational(-2), Rational(1)).nthDerivative(RationalField, 2),
"test 4"
)
assertEquals(
Polynomial(),
Polynomial(Rational(1), Rational(-2), Rational(1)).nthDerivative(RationalField, 3),
"test 5"
)
assertEquals(
Polynomial(),
Polynomial(Rational(1), Rational(-2), Rational(1)).nthDerivative(RationalField, 4),
"test 6"
)
assertEquals(
Polynomial(Rational(8, 9), Rational(30, 7), Rational(-20, 3)),
Polynomial(Rational(1, 5), Rational(-8, 3), Rational(4, 9), Rational(5, 7), Rational(-5, 9)).nthDerivative(RationalField, 2),
"test 7"
)
assertEquals(
Polynomial(Rational(8, 9), Rational(30, 7), Rational(-20, 3)),
Polynomial(Rational(0), Rational(0), Rational(4, 9), Rational(5, 7), Rational(-5, 9)).nthDerivative(RationalField, 2),
"test 8"
)
assertEquals(
Polynomial(Rational(8, 9), Rational(30, 7)),
Polynomial(Rational(1, 5), Rational(-8, 3), Rational(4, 9), Rational(5, 7), Rational(0)).nthDerivative(RationalField, 2),
"test 9"
)
}
@Test
fun test_antiderivative() {
assertEquals(
Polynomial(Rational(0), Rational(1), Rational(-1), Rational(1, 3)),
Polynomial(Rational(1), Rational(-2), Rational(1)).antiderivative(RationalField),
"test 1"
)
assertEquals(
Polynomial(Rational(0), Rational(1, 5), Rational(-4, 3), Rational(4, 27), Rational(5, 28), Rational(-1, 9)),
Polynomial(Rational(1, 5), Rational(-8, 3), Rational(4, 9), Rational(5, 7), Rational(-5, 9)).antiderivative(RationalField),
"test 2"
)
assertEquals(
Polynomial(Rational(0), Rational(0), Rational(0), Rational(4, 27), Rational(5, 28), Rational(-1, 9)),
Polynomial(Rational(0), Rational(0), Rational(4, 9), Rational(5, 7), Rational(-5, 9)).antiderivative(RationalField),
"test 3"
)
assertEquals(
Polynomial(Rational(0), Rational(1, 5), Rational(-4, 3), Rational(4, 27), Rational(5, 28)),
Polynomial(Rational(1, 5), Rational(-8, 3), Rational(4, 9), Rational(5, 7), Rational(0)).antiderivative(RationalField),
"test 4"
)
}
@Test
fun test_nthAntiderivative() {
assertEquals(
Polynomial(Rational(0), Rational(1), Rational(-1), Rational(1, 3)),
Polynomial(Rational(1), Rational(-2), Rational(1)).nthAntiderivative(RationalField, 1),
"test 1"
)
assertFailsWith<IllegalArgumentException>("test2") {
Polynomial(Rational(1), Rational(-2), Rational(1)).nthAntiderivative(RationalField, -1)
}
assertEquals(
Polynomial(Rational(1), Rational(-2), Rational(1)),
Polynomial(Rational(1), Rational(-2), Rational(1)).nthAntiderivative(RationalField, 0),
"test 3"
)
assertEquals(
Polynomial(Rational(0), Rational(0), Rational(1, 2), Rational(-1, 3), Rational(1, 12)),
Polynomial(Rational(1), Rational(-2), Rational(1)).nthAntiderivative(RationalField, 2),
"test 4"
)
assertEquals(
Polynomial(Rational(0), Rational(0), Rational(0), Rational(1, 6), Rational(-1, 12), Rational(1, 60)),
Polynomial(Rational(1), Rational(-2), Rational(1)).nthAntiderivative(RationalField, 3),
"test 5"
)
assertEquals(
Polynomial(Rational(0), Rational(0), Rational(0), Rational(0), Rational(1, 24), Rational(-1, 60), Rational(1, 360)),
Polynomial(Rational(1), Rational(-2), Rational(1)).nthAntiderivative(RationalField, 4),
"test 6"
)
assertEquals(
Polynomial(Rational(0), Rational(0), Rational(1, 10), Rational(-4, 9), Rational(1, 27), Rational(1, 28), Rational(-1, 54)),
Polynomial(Rational(1, 5), Rational(-8, 3), Rational(4, 9), Rational(5, 7), Rational(-5, 9)).nthAntiderivative(RationalField, 2),
"test 7"
)
assertEquals(
Polynomial(Rational(0), Rational(0), Rational(0), Rational(0), Rational(1, 27), Rational(1, 28), Rational(-1, 54)),
Polynomial(Rational(0), Rational(0), Rational(4, 9), Rational(5, 7), Rational(-5, 9)).nthAntiderivative(RationalField, 2),
"test 8"
)
assertEquals(
Polynomial(Rational(0), Rational(0), Rational(1, 10), Rational(-4, 9), Rational(1, 27), Rational(1, 28)),
Polynomial(Rational(1, 5), Rational(-8, 3), Rational(4, 9), Rational(5, 7), Rational(0)).nthAntiderivative(RationalField, 2),
"test 9"
)
}
} }