From 76717c49b1f8083292dd1db7bea8bdbe27fea873 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Tue, 27 Oct 2020 18:06:27 +0700 Subject: [PATCH] Implement fast quaternion implementation, minor changes to complex --- CHANGELOG.md | 1 + .../kscience/kmath/operations/Complex.kt | 70 ++--- .../kscience/kmath/operations/Quaternion.kt | 294 ++++++++++++++++++ 3 files changed, 328 insertions(+), 37 deletions(-) create mode 100644 kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Quaternion.kt diff --git a/CHANGELOG.md b/CHANGELOG.md index 214730ecc..2673b121e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ - Automatic README generation for features (#139) - Native support for `memory`, `core` and `dimensions` - `kmath-ejml` to supply EJML SimpleMatrix wrapper. +- Basic Quaternion vector support. ### Changed - Package changed from `scientifik` to `kscience.kmath`. diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt index 37055a5c8..e9e6fb0ce 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt @@ -42,22 +42,21 @@ private val PI_DIV_2 = Complex(PI / 2, 0) * A field of [Complex]. */ public object ComplexField : ExtendedField, Norm { - override val zero: Complex = 0.0.toComplex() - override val one: Complex = 1.0.toComplex() + override val zero: Complex by lazy { 0.toComplex() } + override val one: Complex by lazy { 1.toComplex() } /** * The imaginary unit. */ - public val i: Complex = Complex(0.0, 1.0) + public val i: Complex by lazy { Complex(0, 1) } - override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) + public override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) + public override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble()) - override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble()) - - override fun multiply(a: Complex, b: Complex): Complex = + public override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) - override fun divide(a: Complex, b: Complex): Complex = when { + public override fun divide(a: Complex, b: Complex): Complex = when { b.re.isNaN() || b.im.isNaN() -> Complex(Double.NaN, Double.NaN) (if (b.im < 0) -b.im else +b.im) < (if (b.re < 0) -b.re else +b.re) -> { @@ -83,31 +82,31 @@ public object ComplexField : ExtendedField, Norm { } } - override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2 - override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2 + public override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2 + public override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2 - override fun tan(arg: Complex): Complex { + public override fun tan(arg: Complex): Complex { val e1 = exp(-i * arg) val e2 = exp(i * arg) return i * (e1 - e2) / (e1 + e2) } - override fun asin(arg: Complex): Complex = -i * ln(sqrt(1 - (arg * arg)) + i * arg) - override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(1 - (arg * arg)) + i * arg) + public override fun asin(arg: Complex): Complex = -i * ln(sqrt(1 - (arg * arg)) + i * arg) + public override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(1 - (arg * arg)) + i * arg) - override fun atan(arg: Complex): Complex { + public override fun atan(arg: Complex): Complex { val iArg = i * arg return i * (ln(1 - iArg) - ln(1 + iArg)) / 2 } - override fun power(arg: Complex, pow: Number): Complex = if (arg.im == 0.0) + public override fun power(arg: Complex, pow: Number): Complex = if (arg.im == 0.0) arg.re.pow(pow.toDouble()).toComplex() else exp(pow * ln(arg)) - override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im)) + public override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im)) - override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re) + public override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re) /** * Adds complex number to real one. @@ -116,7 +115,7 @@ public object ComplexField : ExtendedField, Norm { * @param c the augend. * @return the sum. */ - public operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c) + public operator fun Double.plus(c: Complex): Complex = add(toComplex(), c) /** * Subtracts complex number from real one. @@ -125,7 +124,7 @@ public object ComplexField : ExtendedField, Norm { * @param c the subtrahend. * @return the difference. */ - public operator fun Double.minus(c: Complex): Complex = add(this.toComplex(), -c) + public operator fun Double.minus(c: Complex): Complex = add(toComplex(), -c) /** * Adds real number to complex one. @@ -154,9 +153,9 @@ public object ComplexField : ExtendedField, Norm { */ public operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this) - override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg) - - override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) + public override fun Complex.unaryMinus(): Complex = Complex(-re, -im) + public override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg) + public override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) } /** @@ -169,26 +168,23 @@ public data class Complex(val re: Double, val im: Double) : FieldElement { public constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) - override val context: ComplexField get() = ComplexField - - override fun unwrap(): Complex = this - - override fun Complex.wrap(): Complex = this - - override fun compareTo(other: Complex): Int = r.compareTo(other.r) - - override fun toString(): String { - return "($re + i*$im)" - } + public override val context: ComplexField + get() = ComplexField + public override fun unwrap(): Complex = this + public override fun Complex.wrap(): Complex = this + public override fun compareTo(other: Complex): Int = r.compareTo(other.r) + public override fun toString(): String = "($re + $im * i)" + public override fun minus(b: Complex): Complex = Complex(re - b.re, im - b.im) public companion object : MemorySpec { - override val objectSize: Int + public override val objectSize: Int get() = 16 - override fun MemoryReader.read(offset: Int): Complex = Complex(readDouble(offset), readDouble(offset + 8)) + public override fun MemoryReader.read(offset: Int): Complex = + Complex(readDouble(offset), readDouble(offset + 8)) - override fun MemoryWriter.write(offset: Int, value: Complex) { + public override fun MemoryWriter.write(offset: Int, value: Complex) { writeDouble(offset, value.re) writeDouble(offset + 8, value.im) } @@ -201,7 +197,7 @@ public data class Complex(val re: Double, val im: Double) : FieldElement, Norm, PowerOperations, + ExponentialOperations { + override val zero: Quaternion by lazy { 0.toQuaternion() } + override val one: Quaternion by lazy { 1.toQuaternion() } + + /** + * The `i` quaternion unit. + */ + public val i: Quaternion by lazy { Quaternion(0, 1, 0, 0) } + + /** + * The `j` quaternion unit. + */ + public val j: Quaternion by lazy { Quaternion(0, 0, 1, 0) } + + /** + * The `k` quaternion unit. + */ + public val k: Quaternion by lazy { Quaternion(0, 0, 0, 1) } + + public override fun add(a: Quaternion, b: Quaternion): Quaternion = + Quaternion(a.w + b.w, a.x + b.x, a.y + b.y, a.z + b.z) + + public override fun multiply(a: Quaternion, k: Number): Quaternion { + val d = k.toDouble() + return Quaternion(a.w * d, a.x * d, a.y * d, a.z * d) + } + + public override fun multiply(a: Quaternion, b: Quaternion): Quaternion = Quaternion( + a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z, + a.w * b.x + a.x * b.w + a.y * b.z - a.z * b.y, + a.w * b.y - a.x * b.z + a.y * b.w + a.z * b.x, + a.w * b.z + a.x * b.y - a.y * b.x + a.z * b.w, + ) + + override fun divide(a: Quaternion, b: Quaternion): Quaternion { + val s = b.w * b.w + b.x * b.x + b.y * b.y + b.z * b.z + + return Quaternion( + (b.w * a.w + b.x * a.x + b.y * a.y + b.z * a.z) / s, + (b.w * a.x - b.x * a.w - b.y * a.z + b.z * a.y) / s, + (b.w * a.y + b.x * a.z - b.y * a.w - b.z * a.x) / s, + (b.w * a.z - b.x * a.y + b.y * a.x - b.z * a.w) / s, + ) + } + + public override fun power(arg: Quaternion, pow: Number): Quaternion { + if (pow is Int) return pwr(arg, pow) + if (floor(pow.toDouble()) == pow.toDouble()) return pwr(arg, pow.toInt()) + return exp(pow * ln(arg)) + } + + private fun pwr(x: Quaternion, a: Int): Quaternion { + if (a < 0) return -(pwr(x, -a)) + if (a == 0) return one + if (a == 1) return x + if (a == 2) return pwr2(x) + if (a == 3) return pwr3(x) + if (a == 4) return pwr4(x) + val x4 = pwr4(x) + var y = x4 + repeat((1 until a / 4).count()) { y *= x4 } + if (a % 4 == 3) y *= pwr3(x) + if (a % 4 == 2) y *= pwr2(x) + if (a % 4 == 1) y *= x + return y + } + + private inline fun pwr2(x: Quaternion): Quaternion { + val aa = 2 * x.w + + return Quaternion( + x.w * x.w - (x.x * x.x + x.y * x.y + x.z * x.z), + aa * x.x, + aa * x.y, + aa * x.z + ) + } + + private inline fun pwr3(x: Quaternion): Quaternion { + val a2 = x.w * x.w + val n1 = x.x * x.x + x.y * x.y + x.z * x.z + val n2 = 3.0 * a2 - n1 + + return Quaternion( + x.w * (a2 - 3 * n1), + x.x * n2, + x.y * n2, + x.z * n2 + ) + } + + private inline fun pwr4(x: Quaternion): Quaternion { + val a2 = x.w * x.w + val n1 = x.x * x.x + x.y * x.y + x.z * x.z + val n2 = 4 * x.w * (a2 - n1) + + return Quaternion( + a2 * a2 - 6 * a2 * n1 + n1 * n1, + x.x * n2, + x.y * n2, + x.z * n2 + ) + } + + public override fun exp(arg: Quaternion): Quaternion { + val un = arg.x * arg.x + arg.y * arg.y + arg.z * arg.z + if (un == 0.0) return exp(arg.w).toQuaternion() + val n1 = sqrt(un) + val ea = exp(arg.w) + val n2 = ea * sin(n1) / n1 + return Quaternion(ea * cos(n1), n2 * arg.x, n2 * arg.y, n2 * arg.z) + } + + public override fun ln(arg: Quaternion): Quaternion { + val nu2 = arg.x * arg.x + arg.y * arg.y + arg.z * arg.z + + if (nu2 == 0.0) + return if (arg.w > 0) + Quaternion(ln(arg.w), 0, 0, 0) + else { + val l = ComplexField { ln(arg.w.toComplex()) } + Quaternion(l.re, l.im, 0, 0) + } + + val a = arg.w + check(nu2 > 0) + val n = sqrt(a * a + nu2) + val th = acos(a / n) / sqrt(nu2) + return Quaternion(ln(n), th * arg.x, th * arg.y, th * arg.z) + } + + /** + * Adds quaternion to real one. + * + * @receiver the addend. + * @param c the augend. + * @return the sum. + */ + public operator fun Double.plus(c: Quaternion): Quaternion = Quaternion(this + c.w, c.x, c.y, c.z) + + /** + * Subtracts quaternion from real one. + * + * @receiver the minuend. + * @param c the subtrahend. + * @return the difference. + */ + public operator fun Double.minus(c: Quaternion): Quaternion = Quaternion(this - c.w, -c.x, -c.y, -c.z) + + /** + * Adds real number to quaternion. + * + * @receiver the addend. + * @param d the augend. + * @return the sum. + */ + public operator fun Quaternion.plus(d: Double): Quaternion = Quaternion(w + d, x, y, z) + + /** + * Subtracts real number from quaternion. + * + * @receiver the minuend. + * @param d the subtrahend. + * @return the difference. + */ + public operator fun Quaternion.minus(d: Double): Quaternion = Quaternion(w - d, x, y, z) + + /** + * Multiplies real number by quaternion. + * + * @receiver the multiplier. + * @param c the multiplicand. + * @receiver the product. + */ + public operator fun Double.times(c: Quaternion): Quaternion = + Quaternion(this * c.w, this * c.x, this * c.y, this * c.z) + + public override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z) + public override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg) + + public override fun symbol(value: String): Quaternion = when (value) { + "i" -> i + "j" -> j + "k" -> k + else -> super.symbol(value) + } +} + +/** + * Represents `double`-based quaternion. + * + * @property w The first component. + * @property x The second component. + * @property y The third component. + * @property z The fourth component. + */ +public data class Quaternion(val w: Double, val x: Double, val y: Double, val z: Double) : + FieldElement, + Comparable { + public constructor(w: Number, x: Number, y: Number, z: Number) : this( + w.toDouble(), + x.toDouble(), + y.toDouble(), + z.toDouble() + ) + + public constructor(wx: Complex, yz: Complex) : this(wx.re, wx.im, yz.re, yz.im) + + public override val context: QuaternionField + get() = QuaternionField + + public override fun div(k: Number): Quaternion { + val d = k.toDouble() + return Quaternion(w / d, x / d, y / d, z / d) + } + + public override fun unwrap(): Quaternion = this + public override fun Quaternion.wrap(): Quaternion = this + public override fun compareTo(other: Quaternion): Int = r.compareTo(other.r) + public override fun toString(): String = "($w + $x * i + $y * j + $z * k)" + + public companion object : MemorySpec { + public override val objectSize: Int + get() = 32 + + public override fun MemoryReader.read(offset: Int): Quaternion = + Quaternion(readDouble(offset), readDouble(offset + 8), readDouble(offset + 16), readDouble(offset + 24)) + + public override fun MemoryWriter.write(offset: Int, value: Quaternion) { + writeDouble(offset, value.w) + writeDouble(offset + 8, value.x) + writeDouble(offset + 16, value.y) + writeDouble(offset + 24, value.z) + } + } +} + +/** + * Creates a quaternion with real part equal to this real. + * + * @receiver the real part. + * @return the new quaternion. + */ +public fun Number.toQuaternion(): Quaternion = Quaternion(this, 0, 0, 0) + +/** + * Creates a new buffer of quaternions with the specified [size], where each element is calculated by calling the + * specified [init] function. + */ +public inline fun Buffer.Companion.quaternion(size: Int, init: (Int) -> Quaternion): Buffer = + MemoryBuffer.create(Quaternion, size, init) + +/** + * Creates a new buffer of quaternions with the specified [size], where each element is calculated by calling the + * specified [init] function. + */ +public inline fun MutableBuffer.Companion.quaternion(size: Int, init: (Int) -> Quaternion): MutableBuffer = + MutableMemoryBuffer.create(Quaternion, size, init)