Merge pull request #118 from mipt-npm/hyp-trig-functions

Implement hyperbolic functions for various Algebras
This commit is contained in:
Iaroslav Postovalov 2020-08-11 22:18:05 +07:00 committed by GitHub
commit 5b215833ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 373 additions and 146 deletions

View File

@ -84,9 +84,9 @@ object MstExtendedField : ExtendedField<MST> {
override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) override fun sin(arg: MST): MST = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) override fun cos(arg: MST): MST = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
override fun asin(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ASIN_OPERATION, arg) override fun asin(arg: MST): MST = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
override fun acos(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ACOS_OPERATION, arg) override fun acos(arg: MST): MST = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
override fun atan(arg: MST): MST = unaryOperation(InverseTrigonometricOperations.ATAN_OPERATION, arg) override fun atan(arg: MST): MST = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
override fun add(a: MST, b: MST): MST = MstField.add(a, b) override fun add(a: MST, b: MST): MST = MstField.add(a, b)
override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k) override fun multiply(a: MST, k: Number): MST = MstField.multiply(a, k)
override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b) override fun multiply(a: MST, b: MST): MST = MstField.multiply(a, b)

View File

@ -17,7 +17,6 @@ class DerivativeStructureField(
) : ExtendedField<DerivativeStructure> { ) : ExtendedField<DerivativeStructure> {
override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) } override val zero: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size) }
override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) } override val one: DerivativeStructure by lazy { DerivativeStructure(order, parameters.size, 1.0) }
private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) -> private val variables: Map<String, DerivativeStructure> = parameters.mapValues { (key, value) ->
@ -60,10 +59,18 @@ class DerivativeStructureField(
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin()
override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos()
override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan()
override fun sinh(arg: DerivativeStructure): DerivativeStructure = arg.sinh()
override fun cosh(arg: DerivativeStructure): DerivativeStructure = arg.cosh()
override fun tanh(arg: DerivativeStructure): DerivativeStructure = arg.tanh()
override fun asinh(arg: DerivativeStructure): DerivativeStructure = arg.asinh()
override fun acosh(arg: DerivativeStructure): DerivativeStructure = arg.acosh()
override fun atanh(arg: DerivativeStructure): DerivativeStructure = arg.atanh()
override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) {
is Double -> arg.pow(pow) is Double -> arg.pow(pow)
is Int -> arg.pow(pow) is Int -> arg.pow(pow)
@ -71,9 +78,7 @@ class DerivativeStructureField(
} }
fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow) fun power(arg: DerivativeStructure, pow: DerivativeStructure): DerivativeStructure = arg.pow(pow)
override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp() override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())

View File

@ -139,15 +139,9 @@ open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> { ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg) override fun sin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.SIN_OPERATION, arg)
override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg) override fun cos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.COS_OPERATION, arg)
override fun asin(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ASIN_OPERATION, arg)
override fun asin(arg: Expression<T>): Expression<T> = override fun acos(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ACOS_OPERATION, arg)
unaryOperation(InverseTrigonometricOperations.ASIN_OPERATION, arg) override fun atan(arg: Expression<T>): Expression<T> = unaryOperation(TrigonometricOperations.ATAN_OPERATION, arg)
override fun acos(arg: Expression<T>): Expression<T> =
unaryOperation(InverseTrigonometricOperations.ACOS_OPERATION, arg)
override fun atan(arg: Expression<T>): Expression<T> =
unaryOperation(InverseTrigonometricOperations.ATAN_OPERATION, arg)
override fun power(arg: Expression<T>, pow: Number): Expression<T> = override fun power(arg: Expression<T>, pow: Number): Expression<T> =
binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow)) binaryOperation(PowerOperations.POW_OPERATION, arg, number(pow))

View File

@ -153,7 +153,7 @@ object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex> {
*/ */
operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this) operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this)
override fun norm(arg: Complex): Complex = arg.conjugate * arg override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg)
override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value) override fun symbol(value: String): Complex = if (value == "i") i else super.symbol(value)
} }

View File

@ -1,5 +1,6 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import scientifik.kmath.operations.RealField.pow
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow as kpow import kotlin.math.pow as kpow
@ -7,19 +8,28 @@ import kotlin.math.pow as kpow
* Advanced Number-like semifield that implements basic operations. * Advanced Number-like semifield that implements basic operations.
*/ */
interface ExtendedFieldOperations<T> : interface ExtendedFieldOperations<T> :
InverseTrigonometricOperations<T>, FieldOperations<T>,
TrigonometricOperations<T>,
HyperbolicOperations<T>,
PowerOperations<T>, PowerOperations<T>,
ExponentialOperations<T> { ExponentialOperations<T> {
override fun tan(arg: T): T = sin(arg) / cos(arg) override fun tan(arg: T): T = sin(arg) / cos(arg)
override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
override fun unaryOperation(operation: String, arg: T): T = when (operation) { override fun unaryOperation(operation: String, arg: T): T = when (operation) {
TrigonometricOperations.COS_OPERATION -> cos(arg) TrigonometricOperations.COS_OPERATION -> cos(arg)
TrigonometricOperations.SIN_OPERATION -> sin(arg) TrigonometricOperations.SIN_OPERATION -> sin(arg)
TrigonometricOperations.TAN_OPERATION -> tan(arg) TrigonometricOperations.TAN_OPERATION -> tan(arg)
InverseTrigonometricOperations.ACOS_OPERATION -> acos(arg) TrigonometricOperations.ACOS_OPERATION -> acos(arg)
InverseTrigonometricOperations.ASIN_OPERATION -> asin(arg) TrigonometricOperations.ASIN_OPERATION -> asin(arg)
InverseTrigonometricOperations.ATAN_OPERATION -> atan(arg) TrigonometricOperations.ATAN_OPERATION -> atan(arg)
HyperbolicOperations.COSH_OPERATION -> cosh(arg)
HyperbolicOperations.SINH_OPERATION -> sinh(arg)
HyperbolicOperations.TANH_OPERATION -> tanh(arg)
HyperbolicOperations.ACOSH_OPERATION -> acosh(arg)
HyperbolicOperations.ASINH_OPERATION -> asinh(arg)
HyperbolicOperations.ATANH_OPERATION -> atanh(arg)
PowerOperations.SQRT_OPERATION -> sqrt(arg) PowerOperations.SQRT_OPERATION -> sqrt(arg)
ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.EXP_OPERATION -> exp(arg)
ExponentialOperations.LN_OPERATION -> ln(arg) ExponentialOperations.LN_OPERATION -> ln(arg)
@ -32,6 +42,13 @@ interface ExtendedFieldOperations<T> :
* Advanced Number-like field that implements basic operations. * Advanced Number-like field that implements basic operations.
*/ */
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> { interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2
override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2
override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg)
override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
PowerOperations.POW_OPERATION -> power(left, right) PowerOperations.POW_OPERATION -> power(left, right)
else -> super.rightSideNumberOperation(operation, left, right) else -> super.rightSideNumberOperation(operation, left, right)
@ -46,12 +63,13 @@ interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
*/ */
inline class Real(val value: Double) : FieldElement<Double, Real, RealField> { inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
override val context: RealField
get() = RealField
override fun unwrap(): Double = value override fun unwrap(): Double = value
override fun Double.wrap(): Real = Real(value) override fun Double.wrap(): Real = Real(value)
override val context: RealField get() = RealField
companion object companion object
} }
@ -60,12 +78,22 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object RealField : ExtendedField<Double>, Norm<Double, Double> { object RealField : ExtendedField<Double>, Norm<Double, Double> {
override val zero: Double = 0.0 override val zero: Double
get() = 0.0
override val one: Double
get() = 1.0
override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
override inline fun add(a: Double, b: Double): Double = a + b override inline fun add(a: Double, b: Double): Double = a + b
override inline fun multiply(a: Double, b: Double): Double = a * b
override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
override val one: Double = 1.0 override inline fun multiply(a: Double, b: Double): Double = a * b
override inline fun divide(a: Double, b: Double): Double = a / b override inline fun divide(a: Double, b: Double): Double = a / b
override inline fun sin(arg: Double): Double = kotlin.math.sin(arg) override inline fun sin(arg: Double): Double = kotlin.math.sin(arg)
@ -75,27 +103,24 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) override inline fun asin(arg: Double): Double = kotlin.math.asin(arg)
override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) override inline fun atan(arg: Double): Double = kotlin.math.atan(arg)
override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble()) override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg)
override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg)
override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg)
override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg)
override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg)
override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg)
override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble())
override inline fun exp(arg: Double): Double = kotlin.math.exp(arg) override inline fun exp(arg: Double): Double = kotlin.math.exp(arg)
override inline fun ln(arg: Double): Double = kotlin.math.ln(arg) override inline fun ln(arg: Double): Double = kotlin.math.ln(arg)
override inline fun norm(arg: Double): Double = abs(arg) override inline fun norm(arg: Double): Double = abs(arg)
override inline fun Double.unaryMinus(): Double = -this override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(b: Double): Double = this + b override inline fun Double.plus(b: Double): Double = this + b
override inline fun Double.minus(b: Double): Double = this - b override inline fun Double.minus(b: Double): Double = this - b
override inline fun Double.times(b: Double): Double = this * b override inline fun Double.times(b: Double): Double = this * b
override inline fun Double.div(b: Double): Double = this / b override inline fun Double.div(b: Double): Double = this / b
override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
} }
/** /**
@ -103,12 +128,22 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object FloatField : ExtendedField<Float>, Norm<Float, Float> { object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override val zero: Float = 0f override val zero: Float
get() = 0.0f
override val one: Float
get() = 1.0f
override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
override inline fun add(a: Float, b: Float): Float = a + b override inline fun add(a: Float, b: Float): Float = a + b
override inline fun multiply(a: Float, b: Float): Float = a * b
override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat() override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat()
override val one: Float = 1f override inline fun multiply(a: Float, b: Float): Float = a * b
override inline fun divide(a: Float, b: Float): Float = a / b override inline fun divide(a: Float, b: Float): Float = a / b
override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
@ -118,108 +153,118 @@ object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override inline fun asin(arg: Float): Float = kotlin.math.asin(arg) override inline fun asin(arg: Float): Float = kotlin.math.asin(arg)
override inline fun atan(arg: Float): Float = kotlin.math.atan(arg) override inline fun atan(arg: Float): Float = kotlin.math.atan(arg)
override inline fun power(arg: Float, pow: Number): Float = arg.pow(pow.toFloat()) override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg)
override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg)
override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg)
override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg)
override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg)
override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg)
override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat())
override inline fun exp(arg: Float): Float = kotlin.math.exp(arg) override inline fun exp(arg: Float): Float = kotlin.math.exp(arg)
override inline fun ln(arg: Float): Float = kotlin.math.ln(arg) override inline fun ln(arg: Float): Float = kotlin.math.ln(arg)
override inline fun norm(arg: Float): Float = abs(arg) override inline fun norm(arg: Float): Float = abs(arg)
override inline fun Float.unaryMinus(): Float = -this override inline fun Float.unaryMinus(): Float = -this
override inline fun Float.plus(b: Float): Float = this + b override inline fun Float.plus(b: Float): Float = this + b
override inline fun Float.minus(b: Float): Float = this - b override inline fun Float.minus(b: Float): Float = this - b
override inline fun Float.times(b: Float): Float = this * b override inline fun Float.times(b: Float): Float = this * b
override inline fun Float.div(b: Float): Float = this / b override inline fun Float.div(b: Float): Float = this / b
} }
/** /**
* A field for [Int] without boxing. Does not produce corresponding field element * A field for [Int] without boxing. Does not produce corresponding ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object IntRing : Ring<Int>, Norm<Int, Int> { object IntRing : Ring<Int>, Norm<Int, Int> {
override val zero: Int = 0 override val zero: Int
get() = 0
override val one: Int
get() = 1
override inline fun add(a: Int, b: Int): Int = a + b override inline fun add(a: Int, b: Int): Int = a + b
override inline fun multiply(a: Int, b: Int): Int = a * b
override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a
override val one: Int = 1
override inline fun multiply(a: Int, b: Int): Int = a * b
override inline fun norm(arg: Int): Int = abs(arg) override inline fun norm(arg: Int): Int = abs(arg)
override inline fun Int.unaryMinus(): Int = -this override inline fun Int.unaryMinus(): Int = -this
override inline fun Int.plus(b: Int): Int = this + b override inline fun Int.plus(b: Int): Int = this + b
override inline fun Int.minus(b: Int): Int = this - b override inline fun Int.minus(b: Int): Int = this - b
override inline fun Int.times(b: Int): Int = this * b override inline fun Int.times(b: Int): Int = this * b
} }
/** /**
* A field for [Short] without boxing. Does not produce appropriate field element * A field for [Short] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object ShortRing : Ring<Short>, Norm<Short, Short> { object ShortRing : Ring<Short>, Norm<Short, Short> {
override val zero: Short = 0 override val zero: Short
get() = 0
override val one: Short
get() = 1
override inline fun add(a: Short, b: Short): Short = (a + b).toShort() override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort() override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort()
override val one: Short = 1
override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
override inline fun Short.unaryMinus(): Short = (-this).toShort() override inline fun Short.unaryMinus(): Short = (-this).toShort()
override inline fun Short.plus(b: Short): Short = (this + b).toShort() override inline fun Short.plus(b: Short): Short = (this + b).toShort()
override inline fun Short.minus(b: Short): Short = (this - b).toShort() override inline fun Short.minus(b: Short): Short = (this - b).toShort()
override inline fun Short.times(b: Short): Short = (this * b).toShort() override inline fun Short.times(b: Short): Short = (this * b).toShort()
} }
/** /**
* A field for [Byte] values * A field for [Byte] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object ByteRing : Ring<Byte>, Norm<Byte, Byte> { object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
override val zero: Byte = 0 override val zero: Byte
get() = 0
override val one: Byte
get() = 1
override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte() override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte()
override val one: Byte = 1
override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
override inline fun Byte.unaryMinus(): Byte = (-this).toByte() override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte() override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte()
override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte() override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte()
override inline fun Byte.times(b: Byte): Byte = (this * b).toByte() override inline fun Byte.times(b: Byte): Byte = (this * b).toByte()
} }
/** /**
* A field for [Long] values * A field for [Double] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
object LongRing : Ring<Long>, Norm<Long, Long> { object LongRing : Ring<Long>, Norm<Long, Long> {
override val zero: Long = 0 override val zero: Long
override inline fun add(a: Long, b: Long): Long = (a + b) get() = 0
override inline fun multiply(a: Long, b: Long): Long = (a * b)
override val one: Long
get() = 1
override inline fun add(a: Long, b: Long): Long = a + b
override inline fun multiply(a: Long, k: Number): Long = a * k.toLong() override inline fun multiply(a: Long, k: Number): Long = a * k.toLong()
override val one: Long = 1
override inline fun multiply(a: Long, b: Long): Long = a * b
override fun norm(arg: Long): Long = abs(arg) override fun norm(arg: Long): Long = abs(arg)
override inline fun Long.unaryMinus(): Long = (-this) override inline fun Long.unaryMinus(): Long = (-this)
override inline fun Long.plus(b: Long): Long = (this + b) override inline fun Long.plus(b: Long): Long = (this + b)
override inline fun Long.minus(b: Long): Long = (this - b) override inline fun Long.minus(b: Long): Long = (this - b)
override inline fun Long.times(b: Long): Long = (this * b) override inline fun Long.times(b: Long): Long = (this * b)
} }

View File

@ -1,12 +1,11 @@
package scientifik.kmath.operations package scientifik.kmath.operations
/** /**
* A container for trigonometric operations for specific type. They are limited to semifields. * A container for trigonometric operations for specific type.
* *
* The operations are not exposed to class directly to avoid method bloat but instead are declared in the field. * @param T the type of element of this structure.
* It also allows to override behavior for optional operations.
*/ */
interface TrigonometricOperations<T> : FieldOperations<T> { interface TrigonometricOperations<T> : Algebra<T> {
/** /**
* Computes the sine of [arg]. * Computes the sine of [arg].
*/ */
@ -22,31 +21,6 @@ interface TrigonometricOperations<T> : FieldOperations<T> {
*/ */
fun tan(arg: T): T fun tan(arg: T): T
companion object {
/**
* The identifier of sine.
*/
const val SIN_OPERATION: String = "sin"
/**
* The identifier of cosine.
*/
const val COS_OPERATION: String = "cos"
/**
* The identifier of tangent.
*/
const val TAN_OPERATION: String = "tan"
}
}
/**
* A container for inverse trigonometric operations for specific type. They are limited to semifields.
*
* The operations are not exposed to class directly to avoid method bloat but instead are declared in the field.
* It also allows to override behavior for optional operations.
*/
interface InverseTrigonometricOperations<T> : TrigonometricOperations<T> {
/** /**
* Computes the inverse sine of [arg]. * Computes the inverse sine of [arg].
*/ */
@ -63,6 +37,21 @@ interface InverseTrigonometricOperations<T> : TrigonometricOperations<T> {
fun atan(arg: T): T fun atan(arg: T): T
companion object { companion object {
/**
* The identifier of sine.
*/
const val SIN_OPERATION: String = "sin"
/**
* The identifier of cosine.
*/
const val COS_OPERATION: String = "cos"
/**
* The identifier of tangent.
*/
const val TAN_OPERATION: String = "tan"
/** /**
* The identifier of inverse sine. * The identifier of inverse sine.
*/ */
@ -98,20 +87,121 @@ fun <T : MathElement<out TrigonometricOperations<T>>> tan(arg: T): T = arg.conte
/** /**
* Computes the inverse sine of [arg]. * Computes the inverse sine of [arg].
*/ */
fun <T : MathElement<out InverseTrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg) fun <T : MathElement<out TrigonometricOperations<T>>> asin(arg: T): T = arg.context.asin(arg)
/** /**
* Computes the inverse cosine of [arg]. * Computes the inverse cosine of [arg].
*/ */
fun <T : MathElement<out InverseTrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg) fun <T : MathElement<out TrigonometricOperations<T>>> acos(arg: T): T = arg.context.acos(arg)
/** /**
* Computes the inverse tangent of [arg]. * Computes the inverse tangent of [arg].
*/ */
fun <T : MathElement<out InverseTrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg) fun <T : MathElement<out TrigonometricOperations<T>>> atan(arg: T): T = arg.context.atan(arg)
/**
* A container for hyperbolic trigonometric operations for specific type.
*
* @param T the type of element of this structure.
*/
interface HyperbolicOperations<T> : Algebra<T> {
/**
* Computes the hyperbolic sine of [arg].
*/
fun sinh(arg: T): T
/**
* Computes the hyperbolic cosine of [arg].
*/
fun cosh(arg: T): T
/**
* Computes the hyperbolic tangent of [arg].
*/
fun tanh(arg: T): T
/**
* Computes the inverse hyperbolic sine of [arg].
*/
fun asinh(arg: T): T
/**
* Computes the inverse hyperbolic cosine of [arg].
*/
fun acosh(arg: T): T
/**
* Computes the inverse hyperbolic tangent of [arg].
*/
fun atanh(arg: T): T
companion object {
/**
* The identifier of hyperbolic sine.
*/
const val SINH_OPERATION: String = "sinh"
/**
* The identifier of hyperbolic cosine.
*/
const val COSH_OPERATION: String = "cosh"
/**
* The identifier of hyperbolic tangent.
*/
const val TANH_OPERATION: String = "tanh"
/**
* The identifier of inverse hyperbolic sine.
*/
const val ASINH_OPERATION: String = "asinh"
/**
* The identifier of inverse hyperbolic cosine.
*/
const val ACOSH_OPERATION: String = "acosh"
/**
* The identifier of inverse hyperbolic tangent.
*/
const val ATANH_OPERATION: String = "atanh"
}
}
/**
* Computes the hyperbolic sine of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> sinh(arg: T): T = arg.context.sinh(arg)
/**
* Computes the hyperbolic cosine of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> cosh(arg: T): T = arg.context.cosh(arg)
/**
* Computes the hyperbolic tangent of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> tanh(arg: T): T = arg.context.tanh(arg)
/**
* Computes the inverse hyperbolic sine of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> asinh(arg: T): T = arg.context.asinh(arg)
/**
* Computes the inverse hyperbolic cosine of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> acosh(arg: T): T = arg.context.acosh(arg)
/**
* Computes the inverse hyperbolic tangent of [arg].
*/
fun <T : MathElement<out HyperbolicOperations<T>>> atanh(arg: T): T = arg.context.atanh(arg)
/** /**
* A context extension to include power operations based on exponentiation. * A context extension to include power operations based on exponentiation.
*
* @param T the type of element of this structure.
*/ */
interface PowerOperations<T> : Algebra<T> { interface PowerOperations<T> : Algebra<T> {
/** /**
@ -163,6 +253,8 @@ fun <T : MathElement<out PowerOperations<T>>> sqr(arg: T): T = arg pow 2.0
/** /**
* A container for operations related to `exp` and `ln` functions. * A container for operations related to `exp` and `ln` functions.
*
* @param T the type of element of this structure.
*/ */
interface ExponentialOperations<T> : Algebra<T> { interface ExponentialOperations<T> : Algebra<T> {
/** /**
@ -200,6 +292,9 @@ fun <T : MathElement<out ExponentialOperations<T>>> ln(arg: T): T = arg.context.
/** /**
* A container for norm functional on element. * A container for norm functional on element.
*
* @param T the type of element having norm defined.
* @param R the type of norm.
*/ */
interface Norm<in T : Any, out R> { interface Norm<in T : Any, out R> {
/** /**

View File

@ -15,7 +15,6 @@ class ComplexNDField(override val shape: IntArray) :
ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>> { ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>> {
override val strides: Strides = DefaultStrides(shape) override val strides: Strides = DefaultStrides(shape)
override val elementContext: ComplexField get() = ComplexField override val elementContext: ComplexField get() = ComplexField
override val zero: ComplexNDElement by lazy { produce { zero } } override val zero: ComplexNDElement by lazy { produce { zero } }
override val one: ComplexNDElement by lazy { produce { one } } override val one: ComplexNDElement by lazy { produce { one } }
@ -45,6 +44,7 @@ class ComplexNDField(override val shape: IntArray) :
transform: ComplexField.(index: IntArray, Complex) -> Complex transform: ComplexField.(index: IntArray, Complex) -> Complex
): ComplexNDElement { ): ComplexNDElement {
check(arg) check(arg)
return BufferedNDFieldElement( return BufferedNDFieldElement(
this, this,
buildBuffer(arg.strides.linearSize) { offset -> buildBuffer(arg.strides.linearSize) { offset ->
@ -61,6 +61,7 @@ class ComplexNDField(override val shape: IntArray) :
transform: ComplexField.(Complex, Complex) -> Complex transform: ComplexField.(Complex, Complex) -> Complex
): ComplexNDElement { ): ComplexNDElement {
check(a, b) check(a, b)
return BufferedNDFieldElement( return BufferedNDFieldElement(
this, this,
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
@ -69,23 +70,25 @@ class ComplexNDField(override val shape: IntArray) :
override fun NDBuffer<Complex>.toElement(): FieldElement<NDBuffer<Complex>, *, out BufferedNDField<Complex, ComplexField>> = override fun NDBuffer<Complex>.toElement(): FieldElement<NDBuffer<Complex>, *, out BufferedNDField<Complex, ComplexField>> =
BufferedNDFieldElement(this@ComplexNDField, buffer) BufferedNDFieldElement(this@ComplexNDField, buffer)
override fun power(arg: NDBuffer<Complex>, pow: Number): ComplexNDElement = map(arg) { power(it, pow) } override fun power(arg: NDBuffer<Complex>, pow: Number): ComplexNDElement =
map(arg) { power(it, pow) }
override fun exp(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { exp(it) } override fun exp(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { exp(it) }
override fun ln(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { ln(it) } override fun ln(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { ln(it) }
override fun sin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { sin(it) } override fun sin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { sin(it) }
override fun cos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { cos(it) } override fun cos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { cos(it) }
override fun tan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { tan(it) } override fun tan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { tan(it) }
override fun asin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { asin(it) } override fun asin(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { asin(it) }
override fun acos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { acos(it) } override fun acos(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { acos(it) }
override fun atan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { atan(it) } override fun atan(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { atan(it) }
override fun sinh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { sinh(it) }
override fun cosh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { cosh(it) }
override fun tanh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { tanh(it) }
override fun asinh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { asinh(it) }
override fun acosh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { acosh(it) }
override fun atanh(arg: NDBuffer<Complex>): ComplexNDElement = map(arg) { atanh(it) }
} }

View File

@ -10,14 +10,15 @@ import kotlin.math.*
*/ */
object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> { object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
}
return if (a is RealBuffer && b is RealBuffer) { return if (a is RealBuffer && b is RealBuffer) {
val aArray = a.array val aArray = a.array
val bArray = b.array val bArray = b.array
RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
} else } else RealBuffer(DoubleArray(a.size) { a[it] + b[it] })
RealBuffer(DoubleArray(a.size) { a[it] + b[it] })
} }
override fun multiply(a: Buffer<Double>, k: Number): RealBuffer { override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
@ -26,12 +27,13 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
return if (a is RealBuffer) { return if (a is RealBuffer) {
val aArray = a.array val aArray = a.array
RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) RealBuffer(DoubleArray(a.size) { aArray[it] * kValue })
} else } else RealBuffer(DoubleArray(a.size) { a[it] * kValue })
RealBuffer(DoubleArray(a.size) { a[it] * kValue })
} }
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
}
return if (a is RealBuffer && b is RealBuffer) { return if (a is RealBuffer && b is RealBuffer) {
val aArray = a.array val aArray = a.array
@ -42,34 +44,31 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
} }
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
}
return if (a is RealBuffer && b is RealBuffer) { return if (a is RealBuffer && b is RealBuffer) {
val aArray = a.array val aArray = a.array
val bArray = b.array val bArray = b.array
RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
} else } else RealBuffer(DoubleArray(a.size) { a[it] / b[it] })
RealBuffer(DoubleArray(a.size) { a[it] / b[it] })
} }
override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) RealBuffer(DoubleArray(arg.size) { sin(array[it]) })
} else { } else RealBuffer(DoubleArray(arg.size) { sin(arg[it]) })
RealBuffer(DoubleArray(arg.size) { sin(arg[it]) })
}
override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) RealBuffer(DoubleArray(arg.size) { cos(array[it]) })
} else } else RealBuffer(DoubleArray(arg.size) { cos(arg[it]) })
RealBuffer(DoubleArray(arg.size) { cos(arg[it]) })
override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { tan(array[it]) }) RealBuffer(DoubleArray(arg.size) { tan(array[it]) })
} else } else RealBuffer(DoubleArray(arg.size) { tan(arg[it]) })
RealBuffer(DoubleArray(arg.size) { tan(arg[it]) })
override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
@ -90,23 +89,50 @@ object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
} else } else
RealBuffer(DoubleArray(arg.size) { atan(arg[it]) }) RealBuffer(DoubleArray(arg.size) { atan(arg[it]) })
override fun sinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { sinh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) })
override fun cosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { cosh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) })
override fun tanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { tanh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) })
override fun asinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { asinh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) })
override fun acosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { acosh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) })
override fun atanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array
RealBuffer(DoubleArray(arg.size) { atanh(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) })
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) { override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
} else } else RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) RealBuffer(DoubleArray(arg.size) { exp(array[it]) })
} else } else RealBuffer(DoubleArray(arg.size) { exp(arg[it]) })
RealBuffer(DoubleArray(arg.size) { exp(arg[it]) })
override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) RealBuffer(DoubleArray(arg.size) { ln(array[it]) })
} else } else RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
RealBuffer(DoubleArray(arg.size) { ln(arg[it]) })
} }
/** /**
@ -168,6 +194,36 @@ class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
return RealBufferFieldOperations.atan(arg) return RealBufferFieldOperations.atan(arg)
} }
override fun sinh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.sinh(arg)
}
override fun cosh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.cosh(arg)
}
override fun tanh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.tanh(arg)
}
override fun asinh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.asinh(arg)
}
override fun acosh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.acosh(arg)
}
override fun atanh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.atanh(arg)
}
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer { override fun power(arg: Buffer<Double>, pow: Number): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.power(arg, pow) return RealBufferFieldOperations.power(arg, pow)

View File

@ -40,6 +40,7 @@ class RealNDField(override val shape: IntArray) :
transform: RealField.(index: IntArray, Double) -> Double transform: RealField.(index: IntArray, Double) -> Double
): RealNDElement { ): RealNDElement {
check(arg) check(arg)
return BufferedNDFieldElement( return BufferedNDFieldElement(
this, this,
buildBuffer(arg.strides.linearSize) { offset -> buildBuffer(arg.strides.linearSize) { offset ->
@ -71,16 +72,18 @@ class RealNDField(override val shape: IntArray) :
override fun ln(arg: NDBuffer<Double>): RealNDElement = map(arg) { ln(it) } override fun ln(arg: NDBuffer<Double>): RealNDElement = map(arg) { ln(it) }
override fun sin(arg: NDBuffer<Double>): RealNDElement = map(arg) { sin(it) } override fun sin(arg: NDBuffer<Double>): RealNDElement = map(arg) { sin(it) }
override fun cos(arg: NDBuffer<Double>): RealNDElement = map(arg) { cos(it) } override fun cos(arg: NDBuffer<Double>): RealNDElement = map(arg) { cos(it) }
override fun tan(arg: NDBuffer<Double>): RealNDElement = map(arg) { tan(it) }
override fun asin(arg: NDBuffer<Double>): RealNDElement = map(arg) { asin(it) }
override fun acos(arg: NDBuffer<Double>): RealNDElement = map(arg) { acos(it) }
override fun atan(arg: NDBuffer<Double>): RealNDElement = map(arg) { atan(it) }
override fun tan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { tan(it) } override fun sinh(arg: NDBuffer<Double>): RealNDElement = map(arg) { sinh(it) }
override fun cosh(arg: NDBuffer<Double>): RealNDElement = map(arg) { cosh(it) }
override fun asin(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { asin(it) } override fun tanh(arg: NDBuffer<Double>): RealNDElement = map(arg) { tanh(it) }
override fun asinh(arg: NDBuffer<Double>): RealNDElement = map(arg) { asinh(it) }
override fun acos(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { acos(it) } override fun acosh(arg: NDBuffer<Double>): RealNDElement = map(arg) { acosh(it) }
override fun atanh(arg: NDBuffer<Double>): RealNDElement = map(arg) { atanh(it) }
override fun atan(arg: NDBuffer<Double>): NDBuffer<Double> = map(arg) { atan(it) }
} }
@ -130,6 +133,5 @@ operator fun RealNDElement.minus(arg: Double): RealNDElement =
/** /**
* Produce a context for n-dimensional operations inside this real field * Produce a context for n-dimensional operations inside this real field
*/ */
inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R {
return NDField.real(*shape).run(action) inline fun <R> RealField.nd(vararg shape: Int, action: RealNDField.() -> R): R = NDField.real(*shape).run(action)
}

View File

@ -1,7 +1,10 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import kotlin.math.PI
import kotlin.math.abs
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class ComplexFieldTest { internal class ComplexFieldTest {
@Test @Test
@ -34,6 +37,25 @@ internal class ComplexFieldTest {
assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(0, 0) }) assertEquals(Complex(Double.NaN, Double.NaN), ComplexField { Complex(1, 1) / Complex(0, 0) })
} }
@Test
fun testSine() {
assertEquals(ComplexField { i * sinh(one) }, ComplexField { sin(i) })
assertEquals(ComplexField { i * sinh(PI.toComplex()) }, ComplexField { sin(i * PI.toComplex()) })
}
@Test
fun testInverseSine() {
assertEquals(Complex(0, -0.0), ComplexField { asin(zero) })
assertTrue(abs(ComplexField { i * asinh(one) }.r - ComplexField { asin(i) }.r) < 0.000000000000001)
}
@Test
fun testInverseHyperbolicSine() {
assertEquals(
ComplexField { i * PI.toComplex() / 2 },
ComplexField { asinh(i) })
}
@Test @Test
fun testPower() { fun testPower() {
assertEquals(ComplexField.zero, ComplexField { zero pow 2 }) assertEquals(ComplexField.zero, ComplexField { zero pow 2 })
@ -43,4 +65,9 @@ internal class ComplexFieldTest {
ComplexField { i * 8 }.let { it.im.toInt() to it.re.toInt() }, ComplexField { i * 8 }.let { it.im.toInt() to it.re.toInt() },
ComplexField { Complex(2, 2) pow 2 }.let { it.im.toInt() to it.re.toInt() }) ComplexField { Complex(2, 2) pow 2 }.let { it.im.toInt() to it.re.toInt() })
} }
@Test
fun testNorm() {
assertEquals(2.toComplex(), ComplexField { norm(2 * i) })
}
} }