Make complex and quaternion NaN-hostile

This commit is contained in:
Iaroslav Postovalov 2020-11-10 19:01:26 +07:00
parent b1ccca1019
commit 1539113e72
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
3 changed files with 30 additions and 31 deletions

View File

@ -57,26 +57,24 @@ public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
public override fun divide(a: Complex, b: Complex): Complex = when {
b.re.isNaN() || b.im.isNaN() -> Complex(Double.NaN, Double.NaN)
abs(b.im) < abs(b.re) -> {
val wr = b.im / b.re
val wd = b.re + wr * b.im
if (wd.isNaN() || wd == 0.0)
Complex(Double.NaN, Double.NaN)
throw ArithmeticException("Division by zero or infinity")
else
Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd)
}
b.im == 0.0 -> Complex(Double.NaN, Double.NaN)
b.im == 0.0 -> throw ArithmeticException("Division by zero")
else -> {
val wr = b.re / b.im
val wd = b.im + wr * b.re
if (wd.isNaN() || wd == 0.0)
Complex(Double.NaN, Double.NaN)
throw ArithmeticException("Division by zero or infinity")
else
Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd)
}
@ -130,6 +128,11 @@ public data class Complex(val re: Double, val im: Double) : FieldElement<Complex
public override val context: ComplexField
get() = ComplexField
init {
require(!re.isNaN()) { "Real component of complex is not-a-number" }
require(!im.isNaN()) { "Imaginary component of complex is not-a-number" }
}
public override fun unwrap(): Complex = this
public override fun Complex.wrap(): Complex = this
public override fun compareTo(other: Complex): Int = r.compareTo(other.r)

View File

@ -41,12 +41,12 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
/**
* The `i` quaternion unit.
*/
public val i: Quaternion by lazy { Quaternion(0, 1, 0, 0) }
public val i: Quaternion by lazy { Quaternion(0, 1) }
/**
* The `j` quaternion unit.
*/
public val j: Quaternion by lazy { Quaternion(0, 0, 1, 0) }
public val j: Quaternion by lazy { Quaternion(0, 0, 1) }
/**
* The `k` quaternion unit.
@ -104,41 +104,23 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
}
}
private inline fun pwr2(x: Quaternion): Quaternion {
private fun pwr2(x: Quaternion): Quaternion {
val aa = 2 * x.w
return Quaternion(
x.w * x.w - (x.x * x.x + x.y * x.y + x.z * x.z),
aa * x.x,
aa * x.y,
aa * x.z
)
return Quaternion(x.w * x.w - (x.x * x.x + x.y * x.y + x.z * x.z), aa * x.x, aa * x.y, aa * x.z)
}
private inline fun pwr3(x: Quaternion): Quaternion {
private fun pwr3(x: Quaternion): Quaternion {
val a2 = x.w * x.w
val n1 = x.x * x.x + x.y * x.y + x.z * x.z
val n2 = 3.0 * a2 - n1
return Quaternion(
x.w * (a2 - 3 * n1),
x.x * n2,
x.y * n2,
x.z * n2
)
return Quaternion(x.w * (a2 - 3 * n1), x.x * n2, x.y * n2, x.z * n2)
}
private inline fun pwr4(x: Quaternion): Quaternion {
private fun pwr4(x: Quaternion): Quaternion {
val a2 = x.w * x.w
val n1 = x.x * x.x + x.y * x.y + x.z * x.z
val n2 = 4 * x.w * (a2 - n1)
return Quaternion(
a2 * a2 - 6 * a2 * n1 + n1 * n1,
x.x * n2,
x.y * n2,
x.z * n2
)
return Quaternion(a2 * a2 - 6 * a2 * n1 + n1 * n1, x.x * n2, x.y * n2, x.z * n2)
}
public override fun exp(arg: Quaternion): Quaternion {
@ -213,6 +195,13 @@ public data class Quaternion(val w: Double, val x: Double, val y: Double, val z:
public constructor(wx: Complex, yz: Complex) : this(wx.re, wx.im, yz.re, yz.im)
public constructor(wx: Complex) : this(wx.re, wx.im, 0, 0)
init {
require(!w.isNaN()) { "w-component of quaternion is not-a-number" }
require(!x.isNaN()) { "x-component of quaternion is not-a-number" }
require(!y.isNaN()) { "x-component of quaternion is not-a-number" }
require(!z.isNaN()) { "x-component of quaternion is not-a-number" }
}
public override val context: QuaternionField
get() = QuaternionField

View File

@ -3,6 +3,13 @@ plugins {
}
kotlin.sourceSets {
all {
languageSettings.apply {
useExperimentalAnnotation("kotlinx.coroutines.FlowPreview")
useExperimentalAnnotation("kotlinx.coroutines.ExperimentalCoroutinesApi")
}
}
commonMain {
dependencies {
api(project(":kmath-coroutines"))