forked from kscience/kmath
Provide basic tests for complex numbers, also fix complex division
This commit is contained in:
parent
c4a2489b81
commit
b77bfeb372
@ -8,15 +8,41 @@ import scientifik.memory.MemorySpec
|
||||
import scientifik.memory.MemoryWriter
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
* This complex's conjugate.
|
||||
*/
|
||||
val Complex.conjugate: Complex
|
||||
get() = Complex(re, -im)
|
||||
|
||||
/**
|
||||
* This complex's reciprocal.
|
||||
*/
|
||||
val Complex.reciprocal: Complex
|
||||
get() {
|
||||
val scale = re * re + im * im
|
||||
return Complex(re / scale, -im / scale)
|
||||
}
|
||||
|
||||
/**
|
||||
* Absolute value of complex number.
|
||||
*/
|
||||
val Complex.r: Double
|
||||
get() = sqrt(re * re + im * im)
|
||||
|
||||
/**
|
||||
* An angle between vector represented by complex number and X axis.
|
||||
*/
|
||||
val Complex.theta: Double
|
||||
get() = atan(im / re)
|
||||
|
||||
private val PI_DIV_2 = Complex(PI / 2, 0)
|
||||
|
||||
/**
|
||||
* A field of [Complex].
|
||||
*/
|
||||
object ComplexField : ExtendedField<Complex> {
|
||||
override val zero: Complex = Complex(0.0, 0.0)
|
||||
|
||||
override val one: Complex = Complex(1.0, 0.0)
|
||||
object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
|
||||
override val zero: Complex = 0.0.toComplex()
|
||||
override val one: Complex = 1.0.toComplex()
|
||||
|
||||
/**
|
||||
* The imaginary unit.
|
||||
@ -30,19 +56,53 @@ object ComplexField : ExtendedField<Complex> {
|
||||
override fun multiply(a: Complex, b: Complex): Complex =
|
||||
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
|
||||
|
||||
override fun divide(a: Complex, b: Complex): Complex {
|
||||
val norm = b.re * b.re + b.im * b.im
|
||||
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
|
||||
override fun divide(a: Complex, b: Complex): Complex = when {
|
||||
b.re.isNaN() || b.im.isNaN() -> Complex(Double.NaN, Double.NaN)
|
||||
|
||||
(if (b.im < 0) -b.im else +b.im) < (if (b.re < 0) -b.re else +b.re) -> {
|
||||
val wr = b.im / b.re
|
||||
val wd = b.re + wr * b.im
|
||||
|
||||
if (wd.isNaN() || wd == 0.0)
|
||||
Complex(Double.NaN, Double.NaN)
|
||||
else
|
||||
Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd)
|
||||
}
|
||||
|
||||
b.im == 0.0 -> Complex(Double.NaN, Double.NaN)
|
||||
|
||||
else -> {
|
||||
val wr = b.re / b.im
|
||||
val wd = b.im + wr * b.re
|
||||
|
||||
if (wd.isNaN() || wd == 0.0)
|
||||
Complex(Double.NaN, Double.NaN)
|
||||
else
|
||||
Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd)
|
||||
}
|
||||
}
|
||||
|
||||
override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2
|
||||
override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2
|
||||
override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg)
|
||||
override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg)
|
||||
override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2
|
||||
|
||||
override fun power(arg: Complex, pow: Number): Complex =
|
||||
arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta))
|
||||
override fun tan(arg: Complex): Complex {
|
||||
val e1 = exp(-i * arg)
|
||||
val e2 = exp(i * arg)
|
||||
return i * (e1 - e2) / (e1 + e2)
|
||||
}
|
||||
|
||||
override fun asin(arg: Complex): Complex = -i * ln(sqrt(1 - (arg * arg)) + i * arg)
|
||||
override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(1 - (arg * arg)) + i * arg)
|
||||
|
||||
override fun atan(arg: Complex): Complex {
|
||||
val iArg = i * arg
|
||||
return i * (ln(1 - iArg) - ln(1 + iArg)) / 2
|
||||
}
|
||||
|
||||
override fun power(arg: Complex, pow: Number): Complex = if (arg.im == 0.0)
|
||||
arg.re.pow(pow.toDouble()).toComplex()
|
||||
else
|
||||
exp(pow * ln(arg))
|
||||
|
||||
override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im))
|
||||
|
||||
@ -93,6 +153,8 @@ object ComplexField : ExtendedField<Complex> {
|
||||
*/
|
||||
operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this)
|
||||
|
||||
override fun norm(arg: Complex): Complex = arg.conjugate * arg
|
||||
|
||||
override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value)
|
||||
}
|
||||
|
||||
@ -105,12 +167,12 @@ object ComplexField : ExtendedField<Complex> {
|
||||
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>, Comparable<Complex> {
|
||||
constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
|
||||
|
||||
override val context: ComplexField get() = ComplexField
|
||||
|
||||
override fun unwrap(): Complex = this
|
||||
|
||||
override fun Complex.wrap(): Complex = this
|
||||
|
||||
override val context: ComplexField get() = ComplexField
|
||||
|
||||
override fun compareTo(other: Complex): Int = r.compareTo(other.r)
|
||||
|
||||
companion object : MemorySpec<Complex> {
|
||||
@ -126,28 +188,13 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A complex conjugate
|
||||
*/
|
||||
val Complex.conjugate: Complex get() = Complex(re, -im)
|
||||
|
||||
/**
|
||||
* Absolute value of complex number
|
||||
*/
|
||||
val Complex.r: Double get() = sqrt(re * re + im * im)
|
||||
|
||||
/**
|
||||
* An angle between vector represented by complex number and X axis
|
||||
*/
|
||||
val Complex.theta: Double get() = atan(im / re)
|
||||
|
||||
/**
|
||||
* Creates a complex number with real part equal to this real.
|
||||
*
|
||||
* @receiver the real part.
|
||||
* @return the new complex number.
|
||||
*/
|
||||
fun Double.toComplex(): Complex = Complex(this, 0.0)
|
||||
fun Number.toComplex(): Complex = Complex(this, 0.0)
|
||||
|
||||
inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer<Complex> {
|
||||
return MemoryBuffer.create(Complex, size, init)
|
||||
|
@ -0,0 +1,46 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class ComplexFieldTest {
|
||||
@Test
|
||||
fun testAddition() {
|
||||
assertEquals(Complex(42, 42), ComplexField { Complex(16, 16) + Complex(26, 26) })
|
||||
assertEquals(Complex(42, 16), ComplexField { Complex(16, 16) + 26 })
|
||||
assertEquals(Complex(42, 16), ComplexField { 26 + Complex(16, 16) })
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSubtraction() {
|
||||
assertEquals(Complex(42, 42), ComplexField { Complex(86, 55) - Complex(44, 13) })
|
||||
assertEquals(Complex(42, 56), ComplexField { Complex(86, 56) - 44 })
|
||||
assertEquals(Complex(42, 56), ComplexField { 86 - Complex(44, -56) })
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMultiplication() {
|
||||
assertEquals(Complex(42, 42), ComplexField { Complex(4.2, 0) * Complex(10, 10) })
|
||||
assertEquals(Complex(42, 21), ComplexField { Complex(4.2, 2.1) * 10 })
|
||||
assertEquals(Complex(42, 21), ComplexField { 10 * Complex(4.2, 2.1) })
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDivision() {
|
||||
assertEquals(Complex(42, 42), ComplexField { Complex(0, 168) / Complex(2, 2) })
|
||||
assertEquals(Complex(42, 56), ComplexField { Complex(86, 56) - 44 })
|
||||
assertEquals(Complex(42, 56), ComplexField { 86 - Complex(44, -56) })
|
||||
assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(Double.NaN, Double.NaN) })
|
||||
assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(0, 0) })
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPower() {
|
||||
assertEquals(ComplexField.zero, ComplexField { zero pow 2 })
|
||||
assertEquals(ComplexField.zero, ComplexField { zero pow 2 })
|
||||
|
||||
assertEquals(
|
||||
ComplexField { i * 8 }.let { it.im.toInt() to it.re.toInt() },
|
||||
ComplexField { Complex(2, 2) pow 2 }.let { it.im.toInt() to it.re.toInt() })
|
||||
}
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class ComplexTest {
|
||||
@Test
|
||||
fun conjugate() {
|
||||
assertEquals(
|
||||
Complex(0, -42), (ComplexField.i * 42).conjugate
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun reciprocal() {
|
||||
assertEquals(Complex(0.5, -0.0), 2.toComplex().reciprocal)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun r() {
|
||||
assertEquals(kotlin.math.sqrt(2.0), (ComplexField.i + 1.0.toComplex()).r)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun theta() {
|
||||
assertEquals(0.0, 1.toComplex().theta)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun toComplex() {
|
||||
assertEquals(Complex(42, 0), 42.toComplex())
|
||||
assertEquals(Complex(42.0, 0), 42.0.toComplex())
|
||||
assertEquals(Complex(42f, 0), 42f.toComplex())
|
||||
assertEquals(Complex(42.0, 0), 42.0.toComplex())
|
||||
assertEquals(Complex(42.toByte(), 0), 42.toByte().toComplex())
|
||||
assertEquals(Complex(42.toShort(), 0), 42.toShort().toComplex())
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user