diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index 95cfc1b1d..0d0593d1b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -9,11 +9,20 @@ import scientifik.memory.MemoryWriter import kotlin.math.* /** - * A complex conjugate. + * This complex's conjugate. */ val Complex.conjugate: Complex get() = Complex(re, -im) +/** + * This complex's reciprocal. + */ +val Complex.reciprocal: Complex + get() { + val scale = re * re + im * im + return Complex(re / scale, -im / scale) + } + /** * Absolute value of complex number. */ @@ -46,16 +55,48 @@ object ComplexField : ExtendedField { override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) - override fun divide(a: Complex, b: Complex): Complex { - val scale = b.re * b.re + b.im * b.im - return a * Complex(b.re / scale, -b.im / scale) + override fun divide(a: Complex, b: Complex): Complex = when { + b.re.isNaN() || b.im.isNaN() -> Complex(Double.NaN, Double.NaN) + + (if (b.im < 0) -b.im else +b.im) < (if (b.re < 0) -b.re else +b.re) -> { + val wr = b.im / b.re + val wd = b.re + wr * b.im + + if (wd.isNaN() || wd == 0.0) + Complex(Double.NaN, Double.NaN) + else + Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd) + } + + b.im == 0.0 -> Complex(Double.NaN, Double.NaN) + + else -> { + val wr = b.re / b.im + val wd = b.im + wr * b.re + + if (wd.isNaN() || wd == 0.0) + Complex(Double.NaN, Double.NaN) + else + Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd) + } } override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2 override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2 + + override fun tan(arg: Complex): Complex { + val e1 = exp(-i * arg) + val e2 = exp(i * arg) + return i * (e1 - e2) / (e1 + e2) + } + override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg) override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg) - override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2 + + override fun atan(arg: Complex): Complex { + val iArg = i * arg + return i * (ln(one - iArg) - ln(one + iArg)) / 2 + } override fun sinh(arg: Complex): Complex = (exp(arg) - exp(-arg)) / 2 override fun cosh(arg: Complex): Complex = (exp(arg) + exp(-arg)) / 2 diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexTest.kt new file mode 100644 index 000000000..ad85fa9aa --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/ComplexTest.kt @@ -0,0 +1,49 @@ +package scientifik.kmath.operations + +import kotlin.math.PI +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class ComplexFieldTest { + @Test + fun testAddition() { + assertEquals(Complex(42, 42), ComplexField { Complex(16, 16) + Complex(26, 26) }) + assertEquals(Complex(42, 16), ComplexField { Complex(16, 16) + 26 }) + assertEquals(Complex(42, 16), ComplexField { 26 + Complex(16, 16) }) + } + + @Test + fun testSubtraction() { + assertEquals(Complex(42, 42), ComplexField { Complex(86, 55) - Complex(44, 13) }) + assertEquals(Complex(42, 56), ComplexField { Complex(86, 56) - 44 }) + assertEquals(Complex(42, 56), ComplexField { 86 - Complex(44, -56) }) + } + + @Test + fun testMultiplication() { + assertEquals(Complex(42, 42), ComplexField { Complex(4.2, 0) * Complex(10, 10) }) + assertEquals(Complex(42, 21), ComplexField { Complex(4.2, 2.1) * 10 }) + assertEquals(Complex(42, 21), ComplexField { 10 * Complex(4.2, 2.1) }) + } + + @Test + fun testDivision() { + assertEquals(Complex(42, 42), ComplexField { Complex(0, 168) / Complex(2, 2) }) + assertEquals(Complex(42, 56), ComplexField { Complex(86, 56) - 44 }) + assertEquals(Complex(42, 56), ComplexField { 86 - Complex(44, -56) }) + assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(Double.NaN, Double.NaN) }) + assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(0, 0) }) + } + + @Test + fun testSine() { + assertEquals(Complex(1.2246467991473532E-16, 0), ComplexField { sin(PI.toComplex()) }) + assertEquals(Complex(0, 11.548739357257748), ComplexField { sin(i * PI.toComplex()) }) + assertEquals(Complex(0, 1.1752011936438014), ComplexField { sin(i) }) + } + + @Test + fun testArcsine() { + assertEquals(Complex(0, -0.0), ComplexField { asin(zero) }) + } +}