diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index 54c404f57..64ebe8da3 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -17,7 +17,6 @@ class DerivativeStructureField( ) : ExtendedField { override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) } - override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) } private val variables: Map = parameters.mapValues { (key, value) -> @@ -60,10 +59,18 @@ class DerivativeStructureField( override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() + override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan() override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() + override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh() + override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh() + override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh() + override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh() + override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh() + override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh() + override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { is Double -> arg.pow(pow) is Int -> arg.pow(pow) @@ -71,9 +78,7 @@ class DerivativeStructureField( } fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow) - override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp() - override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index 398ea4395..95cfc1b1d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -8,28 +8,47 @@ import scientifik.memory.MemorySpec import scientifik.memory.MemoryWriter import kotlin.math.* +/** + * A complex conjugate. + */ +val Complex.conjugate: Complex + get() = Complex(re, -im) + +/** + * Absolute value of complex number. + */ +val Complex.r: Double + get() = sqrt(re * re + im * im) + +/** + * An angle between vector represented by complex number and X axis. + */ +val Complex.theta: Double + get() = atan(im / re) + private val PI_DIV_2 = Complex(PI / 2, 0) /** - * A field for complex numbers + * A field for complex numbers. */ object ComplexField : ExtendedField { - override val zero: Complex = Complex(0.0, 0.0) + override val zero: Complex = Complex(0, 0) + override val one: Complex = Complex(1, 0) - override val one: Complex = Complex(1.0, 0.0) - - val i = Complex(0.0, 1.0) + /** + * The imaginary unit constant. + */ + val i = Complex(0, 1) override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) - override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble()) override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) override fun divide(a: Complex, b: Complex): Complex { - val norm = b.re * b.re + b.im * b.im - return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm) + val scale = b.re * b.re + b.im * b.im + return a * Complex(b.re / scale, -b.im / scale) } override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2 @@ -38,42 +57,40 @@ object ComplexField : ExtendedField { override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg) override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2 + override fun sinh(arg: Complex): Complex = (exp(arg) - exp(-arg)) / 2 + override fun cosh(arg: Complex): Complex = (exp(arg) + exp(-arg)) / 2 + override fun tanh(arg: Complex): Complex = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) + override fun asinh(arg: Complex): Complex = ln(sqrt(arg pow 2) + arg) + override fun acosh(arg: Complex): Complex = ln(arg + sqrt((arg - 1) * (arg + 1))) + override fun atanh(arg: Complex): Complex = (ln(arg + 1) - ln(1 - arg)) / 2 + override fun power(arg: Complex, pow: Number): Complex = arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta)) override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im)) - override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re) - operator fun Double.plus(c: Complex) = add(this.toComplex(), c) - - operator fun Double.minus(c: Complex) = add(this.toComplex(), -c) - - operator fun Complex.plus(d: Double) = d + this - - operator fun Complex.minus(d: Double) = add(this, -d.toComplex()) - - operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this) - - override fun symbol(value: String): Complex = if (value == "i") { - i - } else { - super.symbol(value) - } + operator fun Double.plus(c: Complex): Complex = add(toComplex(), c) + operator fun Double.minus(c: Complex): Complex = add(toComplex(), -c) + operator fun Complex.plus(d: Double): Complex = d + this + operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex()) + operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this) + override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) } /** - * Complex number class + * Complex number class. + * + * @property re the real part of the number. + * @property im the imaginary part of the number. */ data class Complex(val re: Double, val im: Double) : FieldElement, Comparable { constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble()) - override fun unwrap(): Complex = this - - override fun Complex.wrap(): Complex = this - override val context: ComplexField get() = ComplexField + override fun unwrap(): Complex = this + override fun Complex.wrap(): Complex = this override fun compareTo(other: Complex): Int = r.compareTo(other.r) companion object : MemorySpec { @@ -90,26 +107,12 @@ data class Complex(val re: Double, val im: Double) : FieldElement Complex): Buffer = + MemoryBuffer.create(Complex, size, init) -/** - * An angle between vector represented by complex number and X axis - */ -val Complex.theta: Double get() = atan(im / re) - -fun Double.toComplex() = Complex(this, 0.0) - -inline fun Buffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer { - return MemoryBuffer.create(Complex, size, init) -} - -inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer { - return MemoryBuffer.create(Complex, size, init) -} +inline fun MutableBuffer.Companion.complex(size: Int, crossinline init: (Int) -> Complex): Buffer = + MemoryBuffer.create(Complex, size, init) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index 953c5a112..9f137788c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -7,23 +7,31 @@ import kotlin.math.pow as kpow * Advanced Number-like field that implements basic operations */ interface ExtendedFieldOperations : - InverseTrigonometricOperations, + TrigonometricOperations, + HyperbolicTrigonometricOperations, PowerOperations, ExponentialOperations { override fun tan(arg: T): T = sin(arg) / cos(arg) + override fun tanh(arg: T): T = sinh(arg) / cosh(arg) override fun unaryOperation(operation: String, arg: T): T = when (operation) { TrigonometricOperations.COS_OPERATION -> cos(arg) TrigonometricOperations.SIN_OPERATION -> sin(arg) TrigonometricOperations.TAN_OPERATION -> tan(arg) - InverseTrigonometricOperations.ACOS_OPERATION -> acos(arg) - InverseTrigonometricOperations.ASIN_OPERATION -> asin(arg) - InverseTrigonometricOperations.ATAN_OPERATION -> atan(arg) + TrigonometricOperations.ACOS_OPERATION -> acos(arg) + TrigonometricOperations.ASIN_OPERATION -> asin(arg) + TrigonometricOperations.ATAN_OPERATION -> atan(arg) + HyperbolicTrigonometricOperations.COSH_OPERATION -> cos(arg) + HyperbolicTrigonometricOperations.SINH_OPERATION -> sin(arg) + HyperbolicTrigonometricOperations.TANH_OPERATION -> tan(arg) + HyperbolicTrigonometricOperations.ACOSH_OPERATION -> acos(arg) + HyperbolicTrigonometricOperations.ASINH_OPERATION -> asin(arg) + HyperbolicTrigonometricOperations.ATANH_OPERATION -> atan(arg) PowerOperations.SQRT_OPERATION -> sqrt(arg) ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.LN_OPERATION -> ln(arg) - else -> super.unaryOperation(operation, arg) + else -> super.unaryOperation(operation, arg) } } @@ -40,12 +48,13 @@ interface ExtendedField : ExtendedFieldOperations, Field { * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 */ inline class Real(val value: Double) : FieldElement { + override val context: RealField + get() = RealField + override fun unwrap(): Double = value override fun Double.wrap(): Real = Real(value) - override val context get() = RealField - companion object } @@ -54,72 +63,86 @@ inline class Real(val value: Double) : FieldElement { */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object RealField : ExtendedField, Norm { - override val zero: Double = 0.0 - override inline fun add(a: Double, b: Double) = a + b - override inline fun multiply(a: Double, b: Double) = a * b - override inline fun multiply(a: Double, k: Number) = a * k.toDouble() + override val zero: Double + get() = 0.0 - override val one: Double = 1.0 - override inline fun divide(a: Double, b: Double) = a / b + override val one: Double + get() = 1.0 - override inline fun sin(arg: Double) = kotlin.math.sin(arg) - override inline fun cos(arg: Double) = kotlin.math.cos(arg) + override inline fun add(a: Double, b: Double): Double = a + b + override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() + + override inline fun multiply(a: Double, b: Double): Double = a * b + + override inline fun divide(a: Double, b: Double): Double = a / b + + override inline fun sin(arg: Double): Double = kotlin.math.sin(arg) + override inline fun cos(arg: Double): Double = kotlin.math.cos(arg) override inline fun tan(arg: Double): Double = kotlin.math.tan(arg) override inline fun acos(arg: Double): Double = kotlin.math.acos(arg) override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) - override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble()) + override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg) + override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg) + override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg) + override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg) + override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg) + override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg) - override inline fun exp(arg: Double) = kotlin.math.exp(arg) - override inline fun ln(arg: Double) = kotlin.math.ln(arg) + override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble()) + override inline fun exp(arg: Double): Double = kotlin.math.exp(arg) + override inline fun ln(arg: Double): Double = kotlin.math.ln(arg) - override inline fun norm(arg: Double) = abs(arg) + override inline fun norm(arg: Double): Double = abs(arg) - override inline fun Double.unaryMinus() = -this - - override inline fun Double.plus(b: Double) = this + b - - override inline fun Double.minus(b: Double) = this - b - - override inline fun Double.times(b: Double) = this * b - - override inline fun Double.div(b: Double) = this / b + override inline fun Double.unaryMinus(): Double = -this + override inline fun Double.plus(b: Double): Double = this + b + override inline fun Double.minus(b: Double): Double = this - b + override inline fun Double.times(b: Double): Double = this * b + override inline fun Double.div(b: Double): Double = this / b } @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object FloatField : ExtendedField, Norm { - override val zero: Float = 0f - override inline fun add(a: Float, b: Float) = a + b - override inline fun multiply(a: Float, b: Float) = a * b - override inline fun multiply(a: Float, k: Number) = a * k.toFloat() + override val zero: Float + get() = 0.0f - override val one: Float = 1f - override inline fun divide(a: Float, b: Float) = a / b + override val one: Float + get() = 1.0f - override inline fun sin(arg: Float) = kotlin.math.sin(arg) - override inline fun cos(arg: Float) = kotlin.math.cos(arg) - override inline fun tan(arg: Float) = kotlin.math.tan(arg) - override inline fun acos(arg: Float) = kotlin.math.acos(arg) - override inline fun asin(arg: Float) = kotlin.math.asin(arg) - override inline fun atan(arg: Float) = kotlin.math.atan(arg) + override inline fun add(a: Float, b: Float): Float = a + b + override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat() - override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat()) + override inline fun multiply(a: Float, b: Float): Float = a * b - override inline fun exp(arg: Float) = kotlin.math.exp(arg) - override inline fun ln(arg: Float) = kotlin.math.ln(arg) + override inline fun divide(a: Float, b: Float): Float = a / b - override inline fun norm(arg: Float) = abs(arg) + override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) + override inline fun cos(arg: Float): Float = kotlin.math.cos(arg) + override inline fun tan(arg: Float): Float = kotlin.math.tan(arg) + override inline fun acos(arg: Float): Float = kotlin.math.acos(arg) + override inline fun asin(arg: Float): Float = kotlin.math.asin(arg) + override inline fun atan(arg: Float): Float = kotlin.math.atan(arg) - override inline fun Float.unaryMinus() = -this + override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg) + override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg) + override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg) + override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg) + override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg) + override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg) - override inline fun Float.plus(b: Float) = this + b + override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat()) + override inline fun exp(arg: Float): Float = kotlin.math.exp(arg) + override inline fun ln(arg: Float): Float = kotlin.math.ln(arg) - override inline fun Float.minus(b: Float) = this - b + override inline fun norm(arg: Float): Float = abs(arg) - override inline fun Float.times(b: Float) = this * b - - override inline fun Float.div(b: Float) = this / b + override inline fun Float.unaryMinus(): Float = -this + override inline fun Float.plus(b: Float): Float = this + b + override inline fun Float.minus(b: Float): Float = this - b + override inline fun Float.times(b: Float): Float = this * b + override inline fun Float.div(b: Float): Float = this / b } /** @@ -127,20 +150,22 @@ object FloatField : ExtendedField, Norm { */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object IntRing : Ring, Norm { - override val zero: Int = 0 - override inline fun add(a: Int, b: Int) = a + b - override inline fun multiply(a: Int, b: Int) = a * b - override inline fun multiply(a: Int, k: Number) = k.toInt() * a - override val one: Int = 1 + override val zero: Int + get() = 0 - override inline fun norm(arg: Int) = abs(arg) + override val one: Int + get() = 1 - override inline fun Int.unaryMinus() = -this + override inline fun add(a: Int, b: Int): Int = a + b + override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a + override inline fun multiply(a: Int, b: Int): Int = a * b + + override inline fun norm(arg: Int): Int = abs(arg) + + override inline fun Int.unaryMinus(): Int = -this override inline fun Int.plus(b: Int): Int = this + b - override inline fun Int.minus(b: Int): Int = this - b - override inline fun Int.times(b: Int): Int = this * b } @@ -149,21 +174,23 @@ object IntRing : Ring, Norm { */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object ShortRing : Ring, Norm { - override val zero: Short = 0 - override inline fun add(a: Short, b: Short) = (a + b).toShort() - override inline fun multiply(a: Short, b: Short) = (a * b).toShort() - override inline fun multiply(a: Short, k: Number) = (a * k.toShort()).toShort() - override val one: Short = 1 + override val zero: Short + get() = 0 + + override val one: Short + get() = 1 + + override inline fun add(a: Short, b: Short): Short = (a + b).toShort() + override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort() + + override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort() override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() - override inline fun Short.unaryMinus() = (-this).toShort() - - override inline fun Short.plus(b: Short) = (this + b).toShort() - - override inline fun Short.minus(b: Short) = (this - b).toShort() - - override inline fun Short.times(b: Short) = (this * b).toShort() + override inline fun Short.unaryMinus(): Short = (-this).toShort() + override inline fun Short.plus(b: Short): Short = (this + b).toShort() + override inline fun Short.minus(b: Short): Short = (this - b).toShort() + override inline fun Short.times(b: Short): Short = (this * b).toShort() } /** @@ -171,21 +198,23 @@ object ShortRing : Ring, Norm { */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object ByteRing : Ring, Norm { - override val zero: Byte = 0 - override inline fun add(a: Byte, b: Byte) = (a + b).toByte() - override inline fun multiply(a: Byte, b: Byte) = (a * b).toByte() - override inline fun multiply(a: Byte, k: Number) = (a * k.toByte()).toByte() - override val one: Byte = 1 + override val zero: Byte + get() = 0 + + override val one: Byte + get() = 1 + + override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() + override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte() + + override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() - override inline fun Byte.unaryMinus() = (-this).toByte() - - override inline fun Byte.plus(b: Byte) = (this + b).toByte() - - override inline fun Byte.minus(b: Byte) = (this - b).toByte() - - override inline fun Byte.times(b: Byte) = (this * b).toByte() + override inline fun Byte.unaryMinus(): Byte = (-this).toByte() + override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte() + override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte() + override inline fun Byte.times(b: Byte): Byte = (this * b).toByte() } /** @@ -193,19 +222,21 @@ object ByteRing : Ring, Norm { */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") object LongRing : Ring, Norm { - override val zero: Long = 0 - override inline fun add(a: Long, b: Long) = (a + b) - override inline fun multiply(a: Long, b: Long) = (a * b) - override inline fun multiply(a: Long, k: Number) = a * k.toLong() - override val one: Long = 1 + override val zero: Long + get() = 0 + + override val one: Long + get() = 1 + + override inline fun add(a: Long, b: Long): Long = a + b + override inline fun multiply(a: Long, k: Number): Long = a * k.toLong() + + override inline fun multiply(a: Long, b: Long): Long = a * b override fun norm(arg: Long): Long = abs(arg) - override inline fun Long.unaryMinus() = (-this) - - override inline fun Long.plus(b: Long) = (this + b) - - override inline fun Long.minus(b: Long) = (this - b) - - override inline fun Long.times(b: Long) = (this * b) + override inline fun Long.unaryMinus(): Long = (-this) + override inline fun Long.plus(b: Long): Long = (this + b) + override inline fun Long.minus(b: Long): Long = (this - b) + override inline fun Long.times(b: Long): Long = (this * b) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index 709f0260f..bd1b8efab 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -1,33 +1,46 @@ package scientifik.kmath.operations - -/* Trigonometric operations */ - /** * A container for trigonometric operations for specific type. Trigonometric operations are limited to fields. * * The operations are not exposed to class directly to avoid method bloat but instead are declared in the field. - * It also allows to override behavior for optional operations - * + * It also allows to override behavior for optional operations. */ interface TrigonometricOperations : FieldOperations { + /** + * Computes the sine of [arg] . + */ fun sin(arg: T): T + + /** + * Computes the cosine of [arg]. + */ fun cos(arg: T): T + + /** + * Computes the tangent of [arg]. + */ fun tan(arg: T): T + /** + * Computes the inverse sine of [arg]. + */ + fun asin(arg: T): T + + /** + * Computes the inverse cosine of [arg]. + */ + fun acos(arg: T): T + + /** + * Computes the inverse tangent of [arg]. + */ + fun atan(arg: T): T + companion object { const val SIN_OPERATION = "sin" const val COS_OPERATION = "cos" const val TAN_OPERATION = "tan" - } -} - -interface InverseTrigonometricOperations : TrigonometricOperations { - fun asin(arg: T): T - fun acos(arg: T): T - fun atan(arg: T): T - - companion object { const val ASIN_OPERATION = "asin" const val ACOS_OPERATION = "acos" const val ATAN_OPERATION = "atan" @@ -37,11 +50,64 @@ interface InverseTrigonometricOperations : TrigonometricOperations { fun >> sin(arg: T): T = arg.context.sin(arg) fun >> cos(arg: T): T = arg.context.cos(arg) fun >> tan(arg: T): T = arg.context.tan(arg) -fun >> asin(arg: T): T = arg.context.asin(arg) -fun >> acos(arg: T): T = arg.context.acos(arg) -fun >> atan(arg: T): T = arg.context.atan(arg) +fun >> asin(arg: T): T = arg.context.asin(arg) +fun >> acos(arg: T): T = arg.context.acos(arg) +fun >> atan(arg: T): T = arg.context.atan(arg) -/* Power and roots */ +/** + * A container for hyperbolic trigonometric operations for specific type. Trigonometric operations are limited to + * fields. + * + * The operations are not exposed to class directly to avoid method bloat but instead are declared in the field. It + * also allows to override behavior for optional operations. + */ +interface HyperbolicTrigonometricOperations : FieldOperations { + /** + * Computes the hyperbolic sine of [arg]. + */ + fun sinh(arg: T): T + + /** + * Computes the hyperbolic cosine of [arg]. + */ + fun cosh(arg: T): T + + /** + * Computes the hyperbolic tangent of [arg]. + */ + fun tanh(arg: T): T + + /** + * Computes the inverse hyperbolic sine of [arg]. + */ + fun asinh(arg: T): T + + /** + * Computes the inverse hyperbolic cosine of [arg]. + */ + fun acosh(arg: T): T + + /** + * Computes the inverse hyperbolic tangent of [arg]. + */ + fun atanh(arg: T): T + + companion object { + const val SINH_OPERATION = "sinh" + const val COSH_OPERATION = "cosh" + const val TANH_OPERATION = "tanh" + const val ASINH_OPERATION = "asinh" + const val ACOSH_OPERATION = "acosh" + const val ATANH_OPERATION = "atanh" + } +} + +fun >> sinh(arg: T): T = arg.context.sinh(arg) +fun >> cosh(arg: T): T = arg.context.cosh(arg) +fun >> tanh(arg: T): T = arg.context.tanh(arg) +fun >> asinh(arg: T): T = arg.context.asinh(arg) +fun >> acosh(arg: T): T = arg.context.acosh(arg) +fun >> atanh(arg: T): T = arg.context.atanh(arg) /** * A context extension to include power operations like square roots, etc @@ -62,8 +128,6 @@ infix fun >> T.pow(power: Double): T = co fun >> sqrt(arg: T): T = arg pow 0.5 fun >> sqr(arg: T): T = arg pow 2.0 -/* Exponential */ - interface ExponentialOperations : Algebra { fun exp(arg: T): T fun ln(arg: T): T diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt index c7e672c28..85c997c13 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt @@ -15,7 +15,6 @@ class ComplexNDField(override val shape: IntArray) : ExtendedNDField> { override val strides: Strides = DefaultStrides(shape) - override val elementContext: ComplexField get() = ComplexField override val zero by lazy { produce { zero } } override val one by lazy { produce { one } } @@ -45,6 +44,7 @@ class ComplexNDField(override val shape: IntArray) : transform: ComplexField.(index: IntArray, Complex) -> Complex ): ComplexNDElement { check(arg) + return BufferedNDFieldElement( this, buildBuffer(arg.strides.linearSize) { offset -> @@ -61,6 +61,7 @@ class ComplexNDField(override val shape: IntArray) : transform: ComplexField.(Complex, Complex) -> Complex ): ComplexNDElement { check(a, b) + return BufferedNDFieldElement( this, buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) @@ -69,23 +70,25 @@ class ComplexNDField(override val shape: IntArray) : override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> = BufferedNDFieldElement(this@ComplexNDField, buffer) - override fun power(arg: NDBuffer, pow: Number) = map(arg) { power(it, pow) } + override fun power(arg: NDBuffer, pow: Number): BufferedNDFieldElement = + map(arg) { power(it, pow) } - override fun exp(arg: NDBuffer) = map(arg) { exp(it) } + override fun exp(arg: NDBuffer): BufferedNDFieldElement = map(arg) { exp(it) } + override fun ln(arg: NDBuffer): BufferedNDFieldElement = map(arg) { ln(it) } - override fun ln(arg: NDBuffer) = map(arg) { ln(it) } + override fun sin(arg: NDBuffer): BufferedNDFieldElement = map(arg) { sin(it) } + override fun cos(arg: NDBuffer): BufferedNDFieldElement = map(arg) { cos(it) } + override fun tan(arg: NDBuffer): BufferedNDFieldElement = map(arg) { tan(it) } + override fun asin(arg: NDBuffer): BufferedNDFieldElement = map(arg) { asin(it) } + override fun acos(arg: NDBuffer): BufferedNDFieldElement = map(arg) { acos(it) } + override fun atan(arg: NDBuffer): BufferedNDFieldElement = map(arg) { atan(it) } - override fun sin(arg: NDBuffer) = map(arg) { sin(it) } - - override fun cos(arg: NDBuffer) = map(arg) { cos(it) } - - override fun tan(arg: NDBuffer): NDBuffer = map(arg) { tan(it) } - - override fun asin(arg: NDBuffer): NDBuffer = map(arg) { asin(it) } - - override fun acos(arg: NDBuffer): NDBuffer = map(arg) {acos(it)} - - override fun atan(arg: NDBuffer): NDBuffer = map(arg) {atan(it)} + override fun sinh(arg: NDBuffer): BufferedNDFieldElement = map(arg) { sinh(it) } + override fun cosh(arg: NDBuffer): BufferedNDFieldElement = map(arg) { cosh(it) } + override fun tanh(arg: NDBuffer): BufferedNDFieldElement = map(arg) { tanh(it) } + override fun asinh(arg: NDBuffer): BufferedNDFieldElement = map(arg) { asinh(it) } + override fun acosh(arg: NDBuffer): BufferedNDFieldElement = map(arg) { acosh(it) } + override fun atanh(arg: NDBuffer): BufferedNDFieldElement = map(arg) { atanh(it) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt index 826203d1f..62a92fdb8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt @@ -6,18 +6,19 @@ import kotlin.math.* /** - * A simple field over linear buffers of [Double] + * A simple field over linear buffers of [Double]. */ object RealBufferFieldOperations : ExtendedFieldOperations> { override fun add(a: Buffer, b: Buffer): RealBuffer { - require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } + require(b.size == a.size) { + "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " + } return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) - } else - RealBuffer(DoubleArray(a.size) { a[it] + b[it] }) + } else RealBuffer(DoubleArray(a.size) { a[it] + b[it] }) } override fun multiply(a: Buffer, k: Number): RealBuffer { @@ -26,57 +27,52 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { return if (a is RealBuffer) { val aArray = a.array RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) - } else - RealBuffer(DoubleArray(a.size) { a[it] * kValue }) + } else RealBuffer(DoubleArray(a.size) { a[it] * kValue }) } override fun multiply(a: Buffer, b: Buffer): RealBuffer { - require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } + require(b.size == a.size) { + "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " + } return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array RealBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) - } else - RealBuffer(DoubleArray(a.size) { a[it] * b[it] }) + } else RealBuffer(DoubleArray(a.size) { a[it] * b[it] }) } override fun divide(a: Buffer, b: Buffer): RealBuffer { - require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } + require(b.size == a.size) { + "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " + } return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) - } else - RealBuffer(DoubleArray(a.size) { a[it] / b[it] }) + } else RealBuffer(DoubleArray(a.size) { a[it] / b[it] }) } override fun sin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) - } else { - RealBuffer(DoubleArray(arg.size) { sin(arg[it]) }) - } + } else RealBuffer(DoubleArray(arg.size) { sin(arg[it]) }) override fun cos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) - } else - RealBuffer(DoubleArray(arg.size) { cos(arg[it]) }) + } else RealBuffer(DoubleArray(arg.size) { cos(arg[it]) }) override fun tan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { tan(array[it]) }) - } else - RealBuffer(DoubleArray(arg.size) { tan(arg[it]) }) + } else RealBuffer(DoubleArray(arg.size) { tan(arg[it]) }) override fun asin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { asin(array[it]) }) - } else { - RealBuffer(DoubleArray(arg.size) { asin(arg[it]) }) - } + } else RealBuffer(DoubleArray(arg.size) { asin(arg[it]) }) override fun acos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array @@ -90,23 +86,50 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { } else RealBuffer(DoubleArray(arg.size) { atan(arg[it]) }) + override fun sinh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { sinh(array[it]) }) + } else RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) }) + + override fun cosh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { cosh(array[it]) }) + } else RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) }) + + override fun tanh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { tanh(array[it]) }) + } else RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) }) + + override fun asinh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { asinh(array[it]) }) + } else RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) }) + + override fun acosh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { acosh(array[it]) }) + } else RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) }) + + override fun atanh(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { atanh(array[it]) }) + } else RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) }) + override fun power(arg: Buffer, pow: Number): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) - } else - RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) + } else RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) override fun exp(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) - } else - RealBuffer(DoubleArray(arg.size) { exp(arg[it]) }) + } else RealBuffer(DoubleArray(arg.size) { exp(arg[it]) }) override fun ln(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) - } else - RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) + } else RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } class RealBufferField(val size: Int) : ExtendedField> { @@ -163,6 +186,36 @@ class RealBufferField(val size: Int) : ExtendedField> { return RealBufferFieldOperations.atan(arg) } + override fun sinh(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.sinh(arg) + } + + override fun cosh(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.cosh(arg) + } + + override fun tanh(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.tanh(arg) + } + + override fun asinh(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.asinh(arg) + } + + override fun acosh(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.acosh(arg) + } + + override fun atanh(arg: Buffer): RealBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.atanh(arg) + } + override fun power(arg: Buffer, pow: Number): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.power(arg, pow) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index 8c90f90c7..26588b7b1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -12,8 +12,8 @@ class RealNDField(override val shape: IntArray) : override val strides: Strides = DefaultStrides(shape) override val elementContext: RealField get() = RealField - override val zero by lazy { produce { zero } } - override val one by lazy { produce { one } } + override val zero: BufferedNDFieldElement by lazy { produce { zero } } + override val one: RealNDElement by lazy { produce { one } } inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer = RealBuffer(DoubleArray(size) { initializer(it) }) @@ -40,6 +40,7 @@ class RealNDField(override val shape: IntArray) : transform: RealField.(index: IntArray, Double) -> Double ): RealNDElement { check(arg) + return BufferedNDFieldElement( this, buildBuffer(arg.strides.linearSize) { offset -> @@ -64,23 +65,25 @@ class RealNDField(override val shape: IntArray) : override fun NDBuffer.toElement(): FieldElement, *, out BufferedNDField> = BufferedNDFieldElement(this@RealNDField, buffer) - override fun power(arg: NDBuffer, pow: Number) = map(arg) { power(it, pow) } + override fun power(arg: NDBuffer, pow: Number): RealNDElement = map(arg) { power(it, pow) } override fun exp(arg: NDBuffer) = map(arg) { exp(it) } override fun ln(arg: NDBuffer) = map(arg) { ln(it) } - override fun sin(arg: NDBuffer) = map(arg) { sin(it) } + override fun sin(arg: NDBuffer): RealNDElement = map(arg) { sin(it) } + override fun cos(arg: NDBuffer): RealNDElement = map(arg) { cos(it) } + override fun tan(arg: NDBuffer): RealNDElement = map(arg) { tan(it) } + override fun asin(arg: NDBuffer): RealNDElement = map(arg) { asin(it) } + override fun acos(arg: NDBuffer): RealNDElement = map(arg) { acos(it) } + override fun atan(arg: NDBuffer): RealNDElement = map(arg) { atan(it) } - override fun cos(arg: NDBuffer) = map(arg) { cos(it) } - - override fun tan(arg: NDBuffer): NDBuffer = map(arg) { tan(it) } - - override fun asin(arg: NDBuffer): NDBuffer = map(arg) { asin(it) } - - override fun acos(arg: NDBuffer): NDBuffer = map(arg) { acos(it) } - - override fun atan(arg: NDBuffer): NDBuffer = map(arg) { atan(it) } + override fun sinh(arg: NDBuffer): RealNDElement = map(arg) { sinh(it) } + override fun cosh(arg: NDBuffer): RealNDElement = map(arg) { cosh(it) } + override fun tanh(arg: NDBuffer): RealNDElement = map(arg) { tanh(it) } + override fun asinh(arg: NDBuffer): RealNDElement = map(arg) { asinh(it) } + override fun acosh(arg: NDBuffer): RealNDElement = map(arg) { acosh(it) } + override fun atanh(arg: NDBuffer): RealNDElement = map(arg) { atanh(it) } } @@ -118,18 +121,14 @@ operator fun Function1.invoke(ndElement: RealNDElement) = /** * Summation operation for [BufferedNDElement] and single element */ -operator fun RealNDElement.plus(arg: Double) = - map { it + arg } +operator fun RealNDElement.plus(arg: Double): RealNDElement = map { it + arg } /** * Subtraction operation between [BufferedNDElement] and single element */ -operator fun RealNDElement.minus(arg: Double) = - map { it - arg } +operator fun RealNDElement.minus(arg: Double): RealNDElement = map { it - arg } /** * Produce a context for n-dimensional operations inside this real field */ -inline fun RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R { - return NDField.real(*shape).run(action) -} \ No newline at end of file +inline fun RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R = NDField.real(*shape).run(action)