Provide basic tests for complex numbers, also fix complex division

This commit is contained in:
Iaroslav Postovalov 2020-08-11 14:58:28 +07:00
parent c4a2489b81
commit b77bfeb372
No known key found for this signature in database
GPG Key ID: 70D5F4DCB0972F1B
3 changed files with 161 additions and 30 deletions

View File

@ -8,15 +8,41 @@ import scientifik.memory.MemorySpec
import scientifik.memory.MemoryWriter
import kotlin.math.*
/**
* This complex's conjugate.
*/
val Complex.conjugate: Complex
get() = Complex(re, -im)
/**
* This complex's reciprocal.
*/
val Complex.reciprocal: Complex
get() {
val scale = re * re + im * im
return Complex(re / scale, -im / scale)
}
/**
* Absolute value of complex number.
*/
val Complex.r: Double
get() = sqrt(re * re + im * im)
/**
* An angle between vector represented by complex number and X axis.
*/
val Complex.theta: Double
get() = atan(im / re)
private val PI_DIV_2 = Complex(PI / 2, 0)
/**
* A field of [Complex].
*/
object ComplexField : ExtendedField<Complex> {
override val zero: Complex = Complex(0.0, 0.0)
override val one: Complex = Complex(1.0, 0.0)
object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
override val zero: Complex = 0.0.toComplex()
override val one: Complex = 1.0.toComplex()
/**
* The imaginary unit.
@ -30,19 +56,53 @@ object ComplexField : ExtendedField<Complex> {
override fun multiply(a: Complex, b: Complex): Complex =
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
override fun divide(a: Complex, b: Complex): Complex {
val norm = b.re * b.re + b.im * b.im
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
override fun divide(a: Complex, b: Complex): Complex = when {
b.re.isNaN() || b.im.isNaN() -> Complex(Double.NaN, Double.NaN)
(if (b.im < 0) -b.im else +b.im) < (if (b.re < 0) -b.re else +b.re) -> {
val wr = b.im / b.re
val wd = b.re + wr * b.im
if (wd.isNaN() || wd == 0.0)
Complex(Double.NaN, Double.NaN)
else
Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd)
}
b.im == 0.0 -> Complex(Double.NaN, Double.NaN)
else -> {
val wr = b.re / b.im
val wd = b.im + wr * b.re
if (wd.isNaN() || wd == 0.0)
Complex(Double.NaN, Double.NaN)
else
Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd)
}
}
override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2
override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2
override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg)
override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg)
override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2
override fun power(arg: Complex, pow: Number): Complex =
arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta))
override fun tan(arg: Complex): Complex {
val e1 = exp(-i * arg)
val e2 = exp(i * arg)
return i * (e1 - e2) / (e1 + e2)
}
override fun asin(arg: Complex): Complex = -i * ln(sqrt(1 - (arg * arg)) + i * arg)
override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(1 - (arg * arg)) + i * arg)
override fun atan(arg: Complex): Complex {
val iArg = i * arg
return i * (ln(1 - iArg) - ln(1 + iArg)) / 2
}
override fun power(arg: Complex, pow: Number): Complex = if (arg.im == 0.0)
arg.re.pow(pow.toDouble()).toComplex()
else
exp(pow * ln(arg))
override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im))
@ -93,6 +153,8 @@ object ComplexField : ExtendedField<Complex> {
*/
operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this)
override fun norm(arg: Complex): Complex = arg.conjugate * arg
override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value)
}
@ -105,12 +167,12 @@ object ComplexField : ExtendedField<Complex> {
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
override val context: ComplexField get() = ComplexField
override fun unwrap(): Complex = this
override fun Complex.wrap(): Complex = this
override val context: ComplexField get() = ComplexField
override fun compareTo(other: Complex): Int = r.compareTo(other.r)
companion object : MemorySpec<Complex> {
@ -126,28 +188,13 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
}
}
/**
* A complex conjugate
*/
val Complex.conjugate: Complex get() = Complex(re, -im)
/**
* Absolute value of complex number
*/
val Complex.r: Double get() = sqrt(re * re + im * im)
/**
* An angle between vector represented by complex number and X axis
*/
val Complex.theta: Double get() = atan(im / re)
/**
* Creates a complex number with real part equal to this real.
*
* @receiver the real part.
* @return the new complex number.
*/
fun Double.toComplex(): Complex = Complex(this, 0.0)
fun Number.toComplex(): Complex = Complex(this, 0.0)
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
return MemoryBuffer.create(Complex, size, init)

View File

@ -0,0 +1,46 @@
package scientifik.kmath.operations
import kotlin.test.Test
import kotlin.test.assertEquals
internal class ComplexFieldTest {
@Test
fun testAddition() {
assertEquals(Complex(42, 42), ComplexField { Complex(16, 16) + Complex(26, 26) })
assertEquals(Complex(42, 16), ComplexField { Complex(16, 16) + 26 })
assertEquals(Complex(42, 16), ComplexField { 26 + Complex(16, 16) })
}
@Test
fun testSubtraction() {
assertEquals(Complex(42, 42), ComplexField { Complex(86, 55) - Complex(44, 13) })
assertEquals(Complex(42, 56), ComplexField { Complex(86, 56) - 44 })
assertEquals(Complex(42, 56), ComplexField { 86 - Complex(44, -56) })
}
@Test
fun testMultiplication() {
assertEquals(Complex(42, 42), ComplexField { Complex(4.2, 0) * Complex(10, 10) })
assertEquals(Complex(42, 21), ComplexField { Complex(4.2, 2.1) * 10 })
assertEquals(Complex(42, 21), ComplexField { 10 * Complex(4.2, 2.1) })
}
@Test
fun testDivision() {
assertEquals(Complex(42, 42), ComplexField { Complex(0, 168) / Complex(2, 2) })
assertEquals(Complex(42, 56), ComplexField { Complex(86, 56) - 44 })
assertEquals(Complex(42, 56), ComplexField { 86 - Complex(44, -56) })
assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(Double.NaN, Double.NaN) })
assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(0, 0) })
}
@Test
fun testPower() {
assertEquals(ComplexField.zero, ComplexField { zero pow 2 })
assertEquals(ComplexField.zero, ComplexField { zero pow 2 })
assertEquals(
ComplexField { i * 8 }.let { it.im.toInt() to it.re.toInt() },
ComplexField { Complex(2, 2) pow 2 }.let { it.im.toInt() to it.re.toInt() })
}
}

View File

@ -0,0 +1,38 @@
package scientifik.kmath.operations
import kotlin.test.Test
import kotlin.test.assertEquals
internal class ComplexTest {
@Test
fun conjugate() {
assertEquals(
Complex(0, -42), (ComplexField.i * 42).conjugate
)
}
@Test
fun reciprocal() {
assertEquals(Complex(0.5, -0.0), 2.toComplex().reciprocal)
}
@Test
fun r() {
assertEquals(kotlin.math.sqrt(2.0), (ComplexField.i + 1.0.toComplex()).r)
}
@Test
fun theta() {
assertEquals(0.0, 1.toComplex().theta)
}
@Test
fun toComplex() {
assertEquals(Complex(42, 0), 42.toComplex())
assertEquals(Complex(42.0, 0), 42.0.toComplex())
assertEquals(Complex(42f, 0), 42f.toComplex())
assertEquals(Complex(42.0, 0), 42.0.toComplex())
assertEquals(Complex(42.toByte(), 0), 42.toByte().toComplex())
assertEquals(Complex(42.toShort(), 0), 42.toShort().toComplex())
}
}