Specify explicit API visbility, minor refactoring (error handling, etc.)
This commit is contained in:
parent
6b79e79d21
commit
f567f73d19
@ -20,7 +20,7 @@ repositories {
|
||||
sourceSets.register("benchmarks")
|
||||
|
||||
dependencies {
|
||||
implementation(project(":kmath-ast"))
|
||||
// implementation(project(":kmath-ast"))
|
||||
implementation(project(":kmath-core"))
|
||||
implementation(project(":kmath-coroutines"))
|
||||
implementation(project(":kmath-commons"))
|
||||
|
@ -1,6 +1,5 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
import scientifik.kmath.operations.RealField.pow
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.pow as kpow
|
||||
|
||||
@ -13,11 +12,10 @@ public interface ExtendedFieldOperations<T> :
|
||||
HyperbolicOperations<T>,
|
||||
PowerOperations<T>,
|
||||
ExponentialOperations<T> {
|
||||
public override fun tan(arg: T): T = sin(arg) / cos(arg)
|
||||
public override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
|
||||
|
||||
override fun tan(arg: T): T = sin(arg) / cos(arg)
|
||||
override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
|
||||
|
||||
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
||||
public override fun unaryOperation(operation: String, arg: T): T = when (operation) {
|
||||
TrigonometricOperations.COS_OPERATION -> cos(arg)
|
||||
TrigonometricOperations.SIN_OPERATION -> sin(arg)
|
||||
TrigonometricOperations.TAN_OPERATION -> tan(arg)
|
||||
@ -37,19 +35,18 @@ public interface ExtendedFieldOperations<T> :
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Advanced Number-like field that implements basic operations.
|
||||
*/
|
||||
public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
|
||||
override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2
|
||||
override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2
|
||||
override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
|
||||
override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg)
|
||||
override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
|
||||
override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
|
||||
public override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2
|
||||
public override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2
|
||||
public override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
|
||||
public override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg)
|
||||
public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
|
||||
public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
|
||||
|
||||
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
||||
public override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> power(left, right)
|
||||
else -> super.rightSideNumberOperation(operation, left, right)
|
||||
}
|
||||
@ -63,12 +60,11 @@ public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
|
||||
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
|
||||
*/
|
||||
public inline class Real(public val value: Double) : FieldElement<Double, Real, RealField> {
|
||||
override val context: RealField
|
||||
public override val context: RealField
|
||||
get() = RealField
|
||||
|
||||
override fun unwrap(): Double = value
|
||||
|
||||
override fun Double.wrap(): Real = Real(value)
|
||||
public override fun unwrap(): Double = value
|
||||
public override fun Double.wrap(): Real = Real(value)
|
||||
|
||||
public companion object
|
||||
}
|
||||
@ -78,49 +74,49 @@ public inline class Real(public val value: Double) : FieldElement<Double, Real,
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||
override val zero: Double
|
||||
public override val zero: Double
|
||||
get() = 0.0
|
||||
|
||||
override val one: Double
|
||||
public override val one: Double
|
||||
get() = 1.0
|
||||
|
||||
override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
|
||||
public override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> left pow right
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
}
|
||||
|
||||
override inline fun add(a: Double, b: Double): Double = a + b
|
||||
override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
|
||||
public override inline fun add(a: Double, b: Double): Double = a + b
|
||||
public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
|
||||
|
||||
override inline fun multiply(a: Double, b: Double): Double = a * b
|
||||
public override inline fun multiply(a: Double, b: Double): Double = a * b
|
||||
|
||||
override inline fun divide(a: Double, b: Double): Double = a / b
|
||||
public override inline fun divide(a: Double, b: Double): Double = a / b
|
||||
|
||||
override inline fun sin(arg: Double): Double = kotlin.math.sin(arg)
|
||||
override inline fun cos(arg: Double): Double = kotlin.math.cos(arg)
|
||||
override inline fun tan(arg: Double): Double = kotlin.math.tan(arg)
|
||||
override inline fun acos(arg: Double): Double = kotlin.math.acos(arg)
|
||||
override inline fun asin(arg: Double): Double = kotlin.math.asin(arg)
|
||||
override inline fun atan(arg: Double): Double = kotlin.math.atan(arg)
|
||||
public override inline fun sin(arg: Double): Double = kotlin.math.sin(arg)
|
||||
public override inline fun cos(arg: Double): Double = kotlin.math.cos(arg)
|
||||
public override inline fun tan(arg: Double): Double = kotlin.math.tan(arg)
|
||||
public override inline fun acos(arg: Double): Double = kotlin.math.acos(arg)
|
||||
public override inline fun asin(arg: Double): Double = kotlin.math.asin(arg)
|
||||
public override inline fun atan(arg: Double): Double = kotlin.math.atan(arg)
|
||||
|
||||
override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg)
|
||||
override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg)
|
||||
override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg)
|
||||
override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg)
|
||||
override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg)
|
||||
override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg)
|
||||
public override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg)
|
||||
public override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg)
|
||||
public override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg)
|
||||
public override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg)
|
||||
public override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg)
|
||||
public override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg)
|
||||
|
||||
override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble())
|
||||
override inline fun exp(arg: Double): Double = kotlin.math.exp(arg)
|
||||
override inline fun ln(arg: Double): Double = kotlin.math.ln(arg)
|
||||
public override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble())
|
||||
public override inline fun exp(arg: Double): Double = kotlin.math.exp(arg)
|
||||
public override inline fun ln(arg: Double): Double = kotlin.math.ln(arg)
|
||||
|
||||
override inline fun norm(arg: Double): Double = abs(arg)
|
||||
public override inline fun norm(arg: Double): Double = abs(arg)
|
||||
|
||||
override inline fun Double.unaryMinus(): Double = -this
|
||||
override inline fun Double.plus(b: Double): Double = this + b
|
||||
override inline fun Double.minus(b: Double): Double = this - b
|
||||
override inline fun Double.times(b: Double): Double = this * b
|
||||
override inline fun Double.div(b: Double): Double = this / b
|
||||
public override inline fun Double.unaryMinus(): Double = -this
|
||||
public override inline fun Double.plus(b: Double): Double = this + b
|
||||
public override inline fun Double.minus(b: Double): Double = this - b
|
||||
public override inline fun Double.times(b: Double): Double = this * b
|
||||
public override inline fun Double.div(b: Double): Double = this / b
|
||||
}
|
||||
|
||||
/**
|
||||
@ -128,49 +124,49 @@ public object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||
override val zero: Float
|
||||
public override val zero: Float
|
||||
get() = 0.0f
|
||||
|
||||
override val one: Float
|
||||
public override val one: Float
|
||||
get() = 1.0f
|
||||
|
||||
override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) {
|
||||
public override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> left pow right
|
||||
else -> super.binaryOperation(operation, left, right)
|
||||
}
|
||||
|
||||
override inline fun add(a: Float, b: Float): Float = a + b
|
||||
override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat()
|
||||
public override inline fun add(a: Float, b: Float): Float = a + b
|
||||
public override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat()
|
||||
|
||||
override inline fun multiply(a: Float, b: Float): Float = a * b
|
||||
public override inline fun multiply(a: Float, b: Float): Float = a * b
|
||||
|
||||
override inline fun divide(a: Float, b: Float): Float = a / b
|
||||
public override inline fun divide(a: Float, b: Float): Float = a / b
|
||||
|
||||
override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
|
||||
override inline fun cos(arg: Float): Float = kotlin.math.cos(arg)
|
||||
override inline fun tan(arg: Float): Float = kotlin.math.tan(arg)
|
||||
override inline fun acos(arg: Float): Float = kotlin.math.acos(arg)
|
||||
override inline fun asin(arg: Float): Float = kotlin.math.asin(arg)
|
||||
override inline fun atan(arg: Float): Float = kotlin.math.atan(arg)
|
||||
public override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
|
||||
public override inline fun cos(arg: Float): Float = kotlin.math.cos(arg)
|
||||
public override inline fun tan(arg: Float): Float = kotlin.math.tan(arg)
|
||||
public override inline fun acos(arg: Float): Float = kotlin.math.acos(arg)
|
||||
public override inline fun asin(arg: Float): Float = kotlin.math.asin(arg)
|
||||
public override inline fun atan(arg: Float): Float = kotlin.math.atan(arg)
|
||||
|
||||
override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg)
|
||||
override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg)
|
||||
override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg)
|
||||
override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg)
|
||||
override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg)
|
||||
override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg)
|
||||
public override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg)
|
||||
public override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg)
|
||||
public override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg)
|
||||
public override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg)
|
||||
public override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg)
|
||||
public override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg)
|
||||
|
||||
override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat())
|
||||
override inline fun exp(arg: Float): Float = kotlin.math.exp(arg)
|
||||
override inline fun ln(arg: Float): Float = kotlin.math.ln(arg)
|
||||
public override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat())
|
||||
public override inline fun exp(arg: Float): Float = kotlin.math.exp(arg)
|
||||
public override inline fun ln(arg: Float): Float = kotlin.math.ln(arg)
|
||||
|
||||
override inline fun norm(arg: Float): Float = abs(arg)
|
||||
public override inline fun norm(arg: Float): Float = abs(arg)
|
||||
|
||||
override inline fun Float.unaryMinus(): Float = -this
|
||||
override inline fun Float.plus(b: Float): Float = this + b
|
||||
override inline fun Float.minus(b: Float): Float = this - b
|
||||
override inline fun Float.times(b: Float): Float = this * b
|
||||
override inline fun Float.div(b: Float): Float = this / b
|
||||
public override inline fun Float.unaryMinus(): Float = -this
|
||||
public override inline fun Float.plus(b: Float): Float = this + b
|
||||
public override inline fun Float.minus(b: Float): Float = this - b
|
||||
public override inline fun Float.times(b: Float): Float = this * b
|
||||
public override inline fun Float.div(b: Float): Float = this / b
|
||||
}
|
||||
|
||||
/**
|
||||
@ -178,23 +174,23 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object IntRing : Ring<Int>, Norm<Int, Int> {
|
||||
override val zero: Int
|
||||
public override val zero: Int
|
||||
get() = 0
|
||||
|
||||
override val one: Int
|
||||
public override val one: Int
|
||||
get() = 1
|
||||
|
||||
override inline fun add(a: Int, b: Int): Int = a + b
|
||||
override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a
|
||||
public override inline fun add(a: Int, b: Int): Int = a + b
|
||||
public override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a
|
||||
|
||||
override inline fun multiply(a: Int, b: Int): Int = a * b
|
||||
public override inline fun multiply(a: Int, b: Int): Int = a * b
|
||||
|
||||
override inline fun norm(arg: Int): Int = abs(arg)
|
||||
public override inline fun norm(arg: Int): Int = abs(arg)
|
||||
|
||||
override inline fun Int.unaryMinus(): Int = -this
|
||||
override inline fun Int.plus(b: Int): Int = this + b
|
||||
override inline fun Int.minus(b: Int): Int = this - b
|
||||
override inline fun Int.times(b: Int): Int = this * b
|
||||
public override inline fun Int.unaryMinus(): Int = -this
|
||||
public override inline fun Int.plus(b: Int): Int = this + b
|
||||
public override inline fun Int.minus(b: Int): Int = this - b
|
||||
public override inline fun Int.times(b: Int): Int = this * b
|
||||
}
|
||||
|
||||
/**
|
||||
@ -202,23 +198,23 @@ public object IntRing : Ring<Int>, Norm<Int, Int> {
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object ShortRing : Ring<Short>, Norm<Short, Short> {
|
||||
override val zero: Short
|
||||
public override val zero: Short
|
||||
get() = 0
|
||||
|
||||
override val one: Short
|
||||
public override val one: Short
|
||||
get() = 1
|
||||
|
||||
override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
|
||||
override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort()
|
||||
public override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
|
||||
public override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort()
|
||||
|
||||
override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
|
||||
public override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
|
||||
|
||||
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
|
||||
public override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
|
||||
|
||||
override inline fun Short.unaryMinus(): Short = (-this).toShort()
|
||||
override inline fun Short.plus(b: Short): Short = (this + b).toShort()
|
||||
override inline fun Short.minus(b: Short): Short = (this - b).toShort()
|
||||
override inline fun Short.times(b: Short): Short = (this * b).toShort()
|
||||
public override inline fun Short.unaryMinus(): Short = (-this).toShort()
|
||||
public override inline fun Short.plus(b: Short): Short = (this + b).toShort()
|
||||
public override inline fun Short.minus(b: Short): Short = (this - b).toShort()
|
||||
public override inline fun Short.times(b: Short): Short = (this * b).toShort()
|
||||
}
|
||||
|
||||
/**
|
||||
@ -226,23 +222,23 @@ public object ShortRing : Ring<Short>, Norm<Short, Short> {
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
||||
override val zero: Byte
|
||||
public override val zero: Byte
|
||||
get() = 0
|
||||
|
||||
override val one: Byte
|
||||
public override val one: Byte
|
||||
get() = 1
|
||||
|
||||
override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
|
||||
override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte()
|
||||
public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
|
||||
public override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte()
|
||||
|
||||
override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
|
||||
public override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
|
||||
|
||||
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
|
||||
public override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
|
||||
|
||||
override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
|
||||
override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte()
|
||||
override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte()
|
||||
override inline fun Byte.times(b: Byte): Byte = (this * b).toByte()
|
||||
public override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
|
||||
public override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte()
|
||||
public override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte()
|
||||
public override inline fun Byte.times(b: Byte): Byte = (this * b).toByte()
|
||||
}
|
||||
|
||||
/**
|
||||
@ -250,21 +246,21 @@ public object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object LongRing : Ring<Long>, Norm<Long, Long> {
|
||||
override val zero: Long
|
||||
public override val zero: Long
|
||||
get() = 0
|
||||
|
||||
override val one: Long
|
||||
public override val one: Long
|
||||
get() = 1
|
||||
|
||||
override inline fun add(a: Long, b: Long): Long = a + b
|
||||
override inline fun multiply(a: Long, k: Number): Long = a * k.toLong()
|
||||
public override inline fun add(a: Long, b: Long): Long = a + b
|
||||
public override inline fun multiply(a: Long, k: Number): Long = a * k.toLong()
|
||||
|
||||
override inline fun multiply(a: Long, b: Long): Long = a * b
|
||||
public override inline fun multiply(a: Long, b: Long): Long = a * b
|
||||
|
||||
override fun norm(arg: Long): Long = abs(arg)
|
||||
public override fun norm(arg: Long): Long = abs(arg)
|
||||
|
||||
override inline fun Long.unaryMinus(): Long = (-this)
|
||||
override inline fun Long.plus(b: Long): Long = (this + b)
|
||||
override inline fun Long.minus(b: Long): Long = (this - b)
|
||||
override inline fun Long.times(b: Long): Long = (this * b)
|
||||
public override inline fun Long.unaryMinus(): Long = (-this)
|
||||
public override inline fun Long.plus(b: Long): Long = (this + b)
|
||||
public override inline fun Long.minus(b: Long): Long = (this - b)
|
||||
public override inline fun Long.times(b: Long): Long = (this * b)
|
||||
}
|
||||
|
@ -4,27 +4,27 @@ import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.FieldElement
|
||||
|
||||
public class BoxingNDField<T, F : Field<T>>(
|
||||
override val shape: IntArray,
|
||||
override val elementContext: F,
|
||||
public override val shape: IntArray,
|
||||
public override val elementContext: F,
|
||||
public val bufferFactory: BufferFactory<T>
|
||||
) : BufferedNDField<T, F> {
|
||||
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
|
||||
override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } }
|
||||
override val strides: Strides = DefaultStrides(shape)
|
||||
public override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
|
||||
public override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } }
|
||||
public override val strides: Strides = DefaultStrides(shape)
|
||||
|
||||
public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||
bufferFactory(size, initializer)
|
||||
|
||||
override fun check(vararg elements: NDBuffer<T>) {
|
||||
public override fun check(vararg elements: NDBuffer<T>) {
|
||||
check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
|
||||
}
|
||||
|
||||
override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
|
||||
public override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
|
||||
BufferedNDFieldElement(
|
||||
this,
|
||||
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
|
||||
|
||||
override fun map(arg: NDBuffer<T>, transform: F.(T) -> T): BufferedNDFieldElement<T, F> {
|
||||
public override fun map(arg: NDBuffer<T>, transform: F.(T) -> T): BufferedNDFieldElement<T, F> {
|
||||
check(arg)
|
||||
|
||||
return BufferedNDFieldElement(
|
||||
@ -36,7 +36,7 @@ public class BoxingNDField<T, F : Field<T>>(
|
||||
|
||||
}
|
||||
|
||||
override fun mapIndexed(
|
||||
public override fun mapIndexed(
|
||||
arg: NDBuffer<T>,
|
||||
transform: F.(index: IntArray, T) -> T
|
||||
): BufferedNDFieldElement<T, F> {
|
||||
@ -55,7 +55,7 @@ public class BoxingNDField<T, F : Field<T>>(
|
||||
// return BufferedNDFieldElement(this, buffer)
|
||||
}
|
||||
|
||||
override fun combine(
|
||||
public override fun combine(
|
||||
a: NDBuffer<T>,
|
||||
b: NDBuffer<T>,
|
||||
transform: F.(T, T) -> T
|
||||
@ -66,7 +66,7 @@ public class BoxingNDField<T, F : Field<T>>(
|
||||
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
|
||||
}
|
||||
|
||||
override fun NDBuffer<T>.toElement(): FieldElement<NDBuffer<T>, *, out BufferedNDField<T, F>> =
|
||||
public override fun NDBuffer<T>.toElement(): FieldElement<NDBuffer<T>, *, out BufferedNDField<T, F>> =
|
||||
BufferedNDFieldElement(this@BoxingNDField, buffer)
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,6 @@ public inline class LongBuffer(public val array: LongArray) : MutableBuffer<Long
|
||||
|
||||
override fun copy(): MutableBuffer<Long> =
|
||||
LongBuffer(array.copyOf())
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -8,7 +8,7 @@ import kotlin.math.*
|
||||
* [ExtendedFieldOperations] over [RealBuffer].
|
||||
*/
|
||||
public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
public override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(b.size == a.size) {
|
||||
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
|
||||
}
|
||||
@ -20,7 +20,7 @@ public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>
|
||||
} else RealBuffer(DoubleArray(a.size) { a[it] + b[it] })
|
||||
}
|
||||
|
||||
override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
|
||||
public override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
|
||||
val kValue = k.toDouble()
|
||||
|
||||
return if (a is RealBuffer) {
|
||||
@ -29,7 +29,7 @@ public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>
|
||||
} else RealBuffer(DoubleArray(a.size) { a[it] * kValue })
|
||||
}
|
||||
|
||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
public override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(b.size == a.size) {
|
||||
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
|
||||
}
|
||||
@ -42,7 +42,7 @@ public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>
|
||||
RealBuffer(DoubleArray(a.size) { a[it] * b[it] })
|
||||
}
|
||||
|
||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
public override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(b.size == a.size) {
|
||||
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
|
||||
}
|
||||
@ -54,87 +54,87 @@ public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>
|
||||
} else RealBuffer(DoubleArray(a.size) { a[it] / b[it] })
|
||||
}
|
||||
|
||||
override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { sin(array[it]) })
|
||||
} else RealBuffer(DoubleArray(arg.size) { sin(arg[it]) })
|
||||
|
||||
override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { cos(array[it]) })
|
||||
} else RealBuffer(DoubleArray(arg.size) { cos(arg[it]) })
|
||||
|
||||
override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { tan(array[it]) })
|
||||
} else RealBuffer(DoubleArray(arg.size) { tan(arg[it]) })
|
||||
|
||||
override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { asin(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { asin(arg[it]) })
|
||||
|
||||
override fun acos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun acos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { acos(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { acos(arg[it]) })
|
||||
|
||||
override fun atan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun atan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { atan(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { atan(arg[it]) })
|
||||
|
||||
override fun sinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun sinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { sinh(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) })
|
||||
|
||||
override fun cosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun cosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { cosh(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) })
|
||||
|
||||
override fun tanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun tanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { tanh(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) })
|
||||
|
||||
override fun asinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun asinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { asinh(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) })
|
||||
|
||||
override fun acosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun acosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { acosh(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) })
|
||||
|
||||
override fun atanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun atanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { atanh(array[it]) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) })
|
||||
|
||||
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
|
||||
} else
|
||||
RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
|
||||
|
||||
override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { exp(array[it]) })
|
||||
} else RealBuffer(DoubleArray(arg.size) { exp(arg[it]) })
|
||||
|
||||
override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
public override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
|
||||
val array = arg.array
|
||||
RealBuffer(DoubleArray(arg.size) { ln(array[it]) })
|
||||
} else
|
||||
@ -147,100 +147,100 @@ public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>
|
||||
* @property size the size of buffers to operate on.
|
||||
*/
|
||||
public class RealBufferField(public val size: Int) : ExtendedField<Buffer<Double>> {
|
||||
override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
|
||||
override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }
|
||||
public override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
|
||||
public override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }
|
||||
|
||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
public override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.add(a, b)
|
||||
}
|
||||
|
||||
override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
|
||||
public override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.multiply(a, k)
|
||||
}
|
||||
|
||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
public override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.multiply(a, b)
|
||||
}
|
||||
|
||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
public override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.divide(a, b)
|
||||
}
|
||||
|
||||
override fun sin(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun sin(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.sin(arg)
|
||||
}
|
||||
|
||||
override fun cos(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun cos(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.cos(arg)
|
||||
}
|
||||
|
||||
override fun tan(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun tan(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.tan(arg)
|
||||
}
|
||||
|
||||
override fun asin(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun asin(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.asin(arg)
|
||||
}
|
||||
|
||||
override fun acos(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun acos(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.acos(arg)
|
||||
}
|
||||
|
||||
override fun atan(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun atan(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.atan(arg)
|
||||
}
|
||||
|
||||
override fun sinh(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun sinh(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.sinh(arg)
|
||||
}
|
||||
|
||||
override fun cosh(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun cosh(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.cosh(arg)
|
||||
}
|
||||
|
||||
override fun tanh(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun tanh(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.tanh(arg)
|
||||
}
|
||||
|
||||
override fun asinh(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun asinh(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.asinh(arg)
|
||||
}
|
||||
|
||||
override fun acosh(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun acosh(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.acosh(arg)
|
||||
}
|
||||
|
||||
override fun atanh(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun atanh(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.atanh(arg)
|
||||
}
|
||||
|
||||
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer {
|
||||
public override fun power(arg: Buffer<Double>, pow: Number): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.power(arg, pow)
|
||||
}
|
||||
|
||||
override fun exp(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun exp(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.exp(arg)
|
||||
}
|
||||
|
||||
override fun ln(arg: Buffer<Double>): RealBuffer {
|
||||
public override fun ln(arg: Buffer<Double>): RealBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.ln(arg)
|
||||
}
|
||||
|
@ -1,25 +1,21 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import kotlin.contracts.contract
|
||||
|
||||
/**
|
||||
* Specialized [MutableBuffer] implementation over [ShortArray].
|
||||
*
|
||||
* @property array the underlying array.
|
||||
*/
|
||||
public inline class ShortBuffer(public val array: ShortArray) : MutableBuffer<Short> {
|
||||
override val size: Int get() = array.size
|
||||
public override val size: Int get() = array.size
|
||||
|
||||
override operator fun get(index: Int): Short = array[index]
|
||||
public override operator fun get(index: Int): Short = array[index]
|
||||
|
||||
override operator fun set(index: Int, value: Short) {
|
||||
public override operator fun set(index: Int, value: Short) {
|
||||
array[index] = value
|
||||
}
|
||||
|
||||
override operator fun iterator(): ShortIterator = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Short> =
|
||||
ShortBuffer(array.copyOf())
|
||||
public override operator fun iterator(): ShortIterator = array.iterator()
|
||||
public override fun copy(): MutableBuffer<Short> = ShortBuffer(array.copyOf())
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -4,25 +4,24 @@ package scientifik.kmath.structures
|
||||
* A structure that is guaranteed to be one-dimensional
|
||||
*/
|
||||
public interface Structure1D<T> : NDStructure<T>, Buffer<T> {
|
||||
override val dimension: Int get() = 1
|
||||
public override val dimension: Int get() = 1
|
||||
|
||||
override operator fun get(index: IntArray): T {
|
||||
public override operator fun get(index: IntArray): T {
|
||||
require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" }
|
||||
return get(index[0])
|
||||
}
|
||||
|
||||
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(::get).iterator()
|
||||
public override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(::get).iterator()
|
||||
}
|
||||
|
||||
/**
|
||||
* A 1D wrapper for nd-structure
|
||||
*/
|
||||
private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T> {
|
||||
private inline class Structure1DWrapper<T>(public val structure: NDStructure<T>) : Structure1D<T> {
|
||||
override val shape: IntArray get() = structure.shape
|
||||
override val size: Int get() = structure.shape[0]
|
||||
|
||||
override operator fun get(index: Int): T = structure[index]
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
}
|
||||
|
||||
@ -32,7 +31,6 @@ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Stru
|
||||
*/
|
||||
private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T> {
|
||||
override val shape: IntArray get() = intArrayOf(buffer.size)
|
||||
|
||||
override val size: Int get() = buffer.size
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||
|
@ -8,19 +8,19 @@ import java.math.MathContext
|
||||
* A field over [BigInteger].
|
||||
*/
|
||||
public object JBigIntegerField : Field<BigInteger> {
|
||||
override val zero: BigInteger
|
||||
public override val zero: BigInteger
|
||||
get() = BigInteger.ZERO
|
||||
|
||||
override val one: BigInteger
|
||||
public override val one: BigInteger
|
||||
get() = BigInteger.ONE
|
||||
|
||||
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
|
||||
override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
|
||||
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b)
|
||||
override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b)
|
||||
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
|
||||
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
|
||||
override operator fun BigInteger.unaryMinus(): BigInteger = negate()
|
||||
public override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
|
||||
public override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
|
||||
public override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b)
|
||||
public override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b)
|
||||
public override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
|
||||
public override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
|
||||
public override operator fun BigInteger.unaryMinus(): BigInteger = negate()
|
||||
}
|
||||
|
||||
/**
|
||||
@ -31,24 +31,24 @@ public object JBigIntegerField : Field<BigInteger> {
|
||||
public abstract class JBigDecimalFieldBase internal constructor(public val mathContext: MathContext = MathContext.DECIMAL64) :
|
||||
Field<BigDecimal>,
|
||||
PowerOperations<BigDecimal> {
|
||||
override val zero: BigDecimal
|
||||
public override val zero: BigDecimal
|
||||
get() = BigDecimal.ZERO
|
||||
|
||||
override val one: BigDecimal
|
||||
public override val one: BigDecimal
|
||||
get() = BigDecimal.ONE
|
||||
|
||||
override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b)
|
||||
override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b)
|
||||
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
|
||||
public override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b)
|
||||
public override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b)
|
||||
public override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
|
||||
|
||||
override fun multiply(a: BigDecimal, k: Number): BigDecimal =
|
||||
public override fun multiply(a: BigDecimal, k: Number): BigDecimal =
|
||||
a.multiply(k.toDouble().toBigDecimal(mathContext), mathContext)
|
||||
|
||||
override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext)
|
||||
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
|
||||
override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
|
||||
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
|
||||
override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)
|
||||
public override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext)
|
||||
public override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
|
||||
public override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
|
||||
public override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
|
||||
public override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3,10 +3,10 @@ package scientifik.kmath.chains
|
||||
/**
|
||||
* Performance optimized chain for integer values
|
||||
*/
|
||||
abstract class BlockingIntChain : Chain<Int> {
|
||||
abstract fun nextInt(): Int
|
||||
public abstract class BlockingIntChain : Chain<Int> {
|
||||
public abstract fun nextInt(): Int
|
||||
|
||||
override suspend fun next(): Int = nextInt()
|
||||
|
||||
fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() }
|
||||
public fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() }
|
||||
}
|
||||
|
@ -3,10 +3,10 @@ package scientifik.kmath.chains
|
||||
/**
|
||||
* Performance optimized chain for real values
|
||||
*/
|
||||
abstract class BlockingRealChain : Chain<Double> {
|
||||
abstract fun nextDouble(): Double
|
||||
public abstract class BlockingRealChain : Chain<Double> {
|
||||
public abstract fun nextDouble(): Double
|
||||
|
||||
override suspend fun next(): Double = nextDouble()
|
||||
|
||||
fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() }
|
||||
public fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() }
|
||||
}
|
||||
|
@ -3,20 +3,19 @@ package scientifik.kmath.chains
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.map
|
||||
import kotlinx.coroutines.flow.runningReduce
|
||||
import kotlinx.coroutines.flow.scan
|
||||
import kotlinx.coroutines.flow.scanReduce
|
||||
import scientifik.kmath.operations.Space
|
||||
import scientifik.kmath.operations.SpaceOperations
|
||||
import scientifik.kmath.operations.invoke
|
||||
|
||||
@ExperimentalCoroutinesApi
|
||||
fun <T> Flow<T>.cumulativeSum(space: SpaceOperations<T>): Flow<T> = space {
|
||||
scanReduce { sum: T, element: T -> sum + element }
|
||||
}
|
||||
public fun <T> Flow<T>.cumulativeSum(space: SpaceOperations<T>): Flow<T> =
|
||||
space { runningReduce { sum, element -> sum + element } }
|
||||
|
||||
@ExperimentalCoroutinesApi
|
||||
fun <T> Flow<T>.mean(space: Space<T>): Flow<T> = space {
|
||||
class Accumulator(var sum: T, var num: Int)
|
||||
public fun <T> Flow<T>.mean(space: Space<T>): Flow<T> = space {
|
||||
data class Accumulator(var sum: T, var num: Int)
|
||||
|
||||
scan(Accumulator(zero, 0)) { sum, element ->
|
||||
sum.apply {
|
||||
|
@ -11,18 +11,18 @@ import scientifik.kmath.structures.asBuffer
|
||||
/**
|
||||
* Create a [Flow] from buffer
|
||||
*/
|
||||
fun <T> Buffer<T>.asFlow(): Flow<T> = iterator().asFlow()
|
||||
public fun <T> Buffer<T>.asFlow(): Flow<T> = iterator().asFlow()
|
||||
|
||||
/**
|
||||
* Flat map a [Flow] of [Buffer] into continuous [Flow] of elements
|
||||
*/
|
||||
@FlowPreview
|
||||
fun <T> Flow<Buffer<out T>>.spread(): Flow<T> = flatMapConcat { it.asFlow() }
|
||||
public fun <T> Flow<Buffer<out T>>.spread(): Flow<T> = flatMapConcat { it.asFlow() }
|
||||
|
||||
/**
|
||||
* Collect incoming flow into fixed size chunks
|
||||
*/
|
||||
fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<Buffer<T>> = flow {
|
||||
public fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<Buffer<T>> = flow {
|
||||
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||
val list = ArrayList<T>(bufferSize)
|
||||
var counter = 0
|
||||
@ -30,6 +30,7 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<
|
||||
this@chunked.collect { element ->
|
||||
list.add(element)
|
||||
counter++
|
||||
|
||||
if (counter == bufferSize) {
|
||||
val buffer = bufferFactory(bufferSize) { list[it] }
|
||||
emit(buffer)
|
||||
@ -37,15 +38,14 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<
|
||||
counter = 0
|
||||
}
|
||||
}
|
||||
if (counter > 0) {
|
||||
emit(bufferFactory(counter) { list[it] })
|
||||
}
|
||||
|
||||
if (counter > 0) emit(bufferFactory(counter) { list[it] })
|
||||
}
|
||||
|
||||
/**
|
||||
* Specialized flow chunker for real buffer
|
||||
*/
|
||||
fun Flow<Double>.chunked(bufferSize: Int): Flow<RealBuffer> = flow {
|
||||
public fun Flow<Double>.chunked(bufferSize: Int): Flow<RealBuffer> = flow {
|
||||
require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
|
||||
|
||||
if (this@chunked is BlockingRealChain) {
|
||||
@ -66,9 +66,8 @@ fun Flow<Double>.chunked(bufferSize: Int): Flow<RealBuffer> = flow {
|
||||
counter = 0
|
||||
}
|
||||
}
|
||||
if (counter > 0) {
|
||||
emit(RealBuffer(counter) { array[it] })
|
||||
}
|
||||
|
||||
if (counter > 0) emit(RealBuffer(counter) { array[it] })
|
||||
}
|
||||
}
|
||||
|
||||
@ -76,9 +75,10 @@ fun Flow<Double>.chunked(bufferSize: Int): Flow<RealBuffer> = flow {
|
||||
* Map a flow to a moving window buffer. The window step is one.
|
||||
* In order to get different steps, one could use skip operation.
|
||||
*/
|
||||
fun <T> Flow<T>.windowed(window: Int): Flow<Buffer<T>> = flow {
|
||||
public fun <T> Flow<T>.windowed(window: Int): Flow<Buffer<T>> = flow {
|
||||
require(window > 1) { "Window size must be more than one" }
|
||||
val ringBuffer = RingBuffer.boxing<T>(window)
|
||||
|
||||
this@windowed.collect { element ->
|
||||
ringBuffer.push(element)
|
||||
emit(ringBuffer.snapshot())
|
||||
|
@ -10,28 +10,28 @@ import scientifik.kmath.structures.VirtualBuffer
|
||||
* Thread-safe ring buffer
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
class RingBuffer<T>(
|
||||
public class RingBuffer<T>(
|
||||
private val buffer: MutableBuffer<T?>,
|
||||
private var startIndex: Int = 0,
|
||||
size: Int = 0
|
||||
) : Buffer<T> {
|
||||
private val mutex: Mutex = Mutex()
|
||||
|
||||
override var size: Int = size
|
||||
public override var size: Int = size
|
||||
private set
|
||||
|
||||
override operator fun get(index: Int): T {
|
||||
public override operator fun get(index: Int): T {
|
||||
require(index >= 0) { "Index must be positive" }
|
||||
require(index < size) { "Index $index is out of circular buffer size $size" }
|
||||
return buffer[startIndex.forward(index)] as T
|
||||
}
|
||||
|
||||
fun isFull(): Boolean = size == buffer.size
|
||||
public fun isFull(): Boolean = size == buffer.size
|
||||
|
||||
/**
|
||||
* Iterator could provide wrong results if buffer is changed in initialization (iteration is safe)
|
||||
*/
|
||||
override operator fun iterator(): Iterator<T> = object : AbstractIterator<T>() {
|
||||
public override operator fun iterator(): Iterator<T> = object : AbstractIterator<T>() {
|
||||
private var count = size
|
||||
private var index = startIndex
|
||||
val copy = buffer.copy()
|
||||
@ -48,23 +48,17 @@ class RingBuffer<T>(
|
||||
/**
|
||||
* A safe snapshot operation
|
||||
*/
|
||||
suspend fun snapshot(): Buffer<T> {
|
||||
public suspend fun snapshot(): Buffer<T> {
|
||||
mutex.withLock {
|
||||
val copy = buffer.copy()
|
||||
return VirtualBuffer(size) { i ->
|
||||
copy[startIndex.forward(i)] as T
|
||||
}
|
||||
return VirtualBuffer(size) { i -> copy[startIndex.forward(i)] as T }
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun push(element: T) {
|
||||
public suspend fun push(element: T) {
|
||||
mutex.withLock {
|
||||
buffer[startIndex.forward(size)] = element
|
||||
if (isFull()) {
|
||||
startIndex++
|
||||
} else {
|
||||
size++
|
||||
}
|
||||
if (isFull()) startIndex++ else size++
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,8 +66,8 @@ class RingBuffer<T>(
|
||||
@Suppress("NOTHING_TO_INLINE")
|
||||
private inline fun Int.forward(n: Int): Int = (this + n) % (buffer.size)
|
||||
|
||||
companion object {
|
||||
inline fun <reified T : Any> build(size: Int, empty: T): RingBuffer<T> {
|
||||
public companion object {
|
||||
public inline fun <reified T : Any> build(size: Int, empty: T): RingBuffer<T> {
|
||||
val buffer = MutableBuffer.auto(size) { empty } as MutableBuffer<T?>
|
||||
return RingBuffer(buffer)
|
||||
}
|
||||
@ -81,7 +75,7 @@ class RingBuffer<T>(
|
||||
/**
|
||||
* Slow yet universal buffer
|
||||
*/
|
||||
fun <T> boxing(size: Int): RingBuffer<T> {
|
||||
public fun <T> boxing(size: Int): RingBuffer<T> {
|
||||
val buffer: MutableBuffer<T?> = MutableBuffer.boxing(size) { null }
|
||||
return RingBuffer(buffer)
|
||||
}
|
||||
|
@ -118,39 +118,39 @@ public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: Ge
|
||||
other: DMatrix<T, C1, C2>
|
||||
): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()
|
||||
|
||||
inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> =
|
||||
public inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> =
|
||||
DPoint.coerceUnsafe(context { this@dot dot vector })
|
||||
|
||||
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.times(value: T): DMatrix<T, R, C> =
|
||||
public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.times(value: T): DMatrix<T, R, C> =
|
||||
context { this@times.times(value) }.coerce()
|
||||
|
||||
inline operator fun <reified R : Dimension, reified C : Dimension> T.times(m: DMatrix<T, R, C>): DMatrix<T, R, C> =
|
||||
public inline operator fun <reified R : Dimension, reified C : Dimension> T.times(m: DMatrix<T, R, C>): DMatrix<T, R, C> =
|
||||
m * this
|
||||
|
||||
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.plus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
|
||||
public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.plus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
|
||||
context { this@plus + other }.coerce()
|
||||
|
||||
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.minus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
|
||||
public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.minus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
|
||||
context { this@minus + other }.coerce()
|
||||
|
||||
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.unaryMinus(): DMatrix<T, C, R> =
|
||||
public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.unaryMinus(): DMatrix<T, C, R> =
|
||||
context { this@unaryMinus.unaryMinus() }.coerce()
|
||||
|
||||
inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> =
|
||||
public inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> =
|
||||
context { (this@transpose as Matrix<T>).transpose() }.coerce()
|
||||
|
||||
/**
|
||||
* A square unit matrix
|
||||
*/
|
||||
inline fun <reified D : Dimension> one(): DMatrix<T, D, D> = produce { i, j ->
|
||||
public inline fun <reified D : Dimension> one(): DMatrix<T, D, D> = produce { i, j ->
|
||||
if (i == j) context.elementContext.one else context.elementContext.zero
|
||||
}
|
||||
|
||||
inline fun <reified R : Dimension, reified C : Dimension> zero(): DMatrix<T, R, C> = produce { _, _ ->
|
||||
public inline fun <reified R : Dimension, reified C : Dimension> zero(): DMatrix<T, R, C> = produce { _, _ ->
|
||||
context.elementContext.zero
|
||||
}
|
||||
|
||||
companion object {
|
||||
val real: DMatrixContext<Double, RealField> = DMatrixContext(MatrixContext.real)
|
||||
public companion object {
|
||||
public val real: DMatrixContext<Double, RealField> = DMatrixContext(MatrixContext.real)
|
||||
}
|
||||
}
|
||||
|
@ -12,39 +12,38 @@ import scientifik.kmath.structures.asBuffer
|
||||
import scientifik.kmath.structures.asIterable
|
||||
import kotlin.math.sqrt
|
||||
|
||||
typealias RealPoint = Point<Double>
|
||||
public typealias RealPoint = Point<Double>
|
||||
|
||||
fun DoubleArray.asVector(): RealVector = RealVector(this.asBuffer())
|
||||
fun List<Double>.asVector(): RealVector = RealVector(this.asBuffer())
|
||||
public fun DoubleArray.asVector(): RealVector = RealVector(this.asBuffer())
|
||||
public fun List<Double>.asVector(): RealVector = RealVector(this.asBuffer())
|
||||
|
||||
object VectorL2Norm : Norm<Point<out Number>, Double> {
|
||||
public object VectorL2Norm : Norm<Point<out Number>, Double> {
|
||||
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() })
|
||||
}
|
||||
|
||||
inline class RealVector(private val point: Point<Double>) :
|
||||
public inline class RealVector(private val point: Point<Double>) :
|
||||
SpaceElement<RealPoint, RealVector, VectorSpace<Double, RealField>>, RealPoint {
|
||||
public override val size: Int get() = point.size
|
||||
public override val context: VectorSpace<Double, RealField> get() = space(point.size)
|
||||
|
||||
override val context: VectorSpace<Double, RealField> get() = space(point.size)
|
||||
public override fun unwrap(): RealPoint = point
|
||||
|
||||
override fun unwrap(): RealPoint = point
|
||||
public override fun RealPoint.wrap(): RealVector = RealVector(this)
|
||||
|
||||
override fun RealPoint.wrap(): RealVector = RealVector(this)
|
||||
|
||||
override val size: Int get() = point.size
|
||||
|
||||
override operator fun get(index: Int): Double = point[index]
|
||||
|
||||
override operator fun iterator(): Iterator<Double> = point.iterator()
|
||||
|
||||
companion object {
|
||||
public companion object {
|
||||
private val spaceCache: MutableMap<Int, BufferVectorSpace<Double, RealField>> = hashMapOf()
|
||||
|
||||
inline operator fun invoke(dim: Int, initializer: (Int) -> Double): RealVector =
|
||||
public inline operator fun invoke(dim: Int, initializer: (Int) -> Double): RealVector =
|
||||
RealVector(RealBuffer(dim, initializer))
|
||||
|
||||
operator fun invoke(vararg values: Double): RealVector = values.asVector()
|
||||
public operator fun invoke(vararg values: Double): RealVector = values.asVector()
|
||||
|
||||
fun space(dim: Int): BufferVectorSpace<Double, RealField> = spaceCache.getOrPut(dim) {
|
||||
public fun space(dim: Int): BufferVectorSpace<Double, RealField> = spaceCache.getOrPut(dim) {
|
||||
BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) }
|
||||
}
|
||||
}
|
||||
|
@ -5,4 +5,4 @@ import scientifik.kmath.structures.RealBuffer
|
||||
/**
|
||||
* Simplified [RealBuffer] to array comparison
|
||||
*/
|
||||
fun RealBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles)
|
||||
public fun RealBuffer.contentEquals(vararg doubles: Double): Boolean = array.contentEquals(doubles)
|
||||
|
@ -138,11 +138,11 @@ public fun Matrix<Double>.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||
}
|
||||
|
||||
public fun Matrix<Double>.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||
columns[j].asIterable().min() ?: error("Cannot produce min on empty column")
|
||||
columns[j].asIterable().minOrNull() ?: error("Cannot produce min on empty column")
|
||||
}
|
||||
|
||||
public fun Matrix<Double>.maxByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||
columns[j].asIterable().max() ?: error("Cannot produce min on empty column")
|
||||
columns[j].asIterable().maxOrNull() ?: error("Cannot produce min on empty column")
|
||||
}
|
||||
|
||||
public fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||
@ -154,6 +154,6 @@ public fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j
|
||||
*/
|
||||
|
||||
public fun Matrix<Double>.sum(): Double = elements().map { (_, value) -> value }.sum()
|
||||
public fun Matrix<Double>.min(): Double? = elements().map { (_, value) -> value }.min()
|
||||
public fun Matrix<Double>.max(): Double? = elements().map { (_, value) -> value }.max()
|
||||
public fun Matrix<Double>.min(): Double? = elements().map { (_, value) -> value }.minOrNull()
|
||||
public fun Matrix<Double>.max(): Double? = elements().map { (_, value) -> value }.maxOrNull()
|
||||
public fun Matrix<Double>.average(): Double = elements().map { (_, value) -> value }.average()
|
||||
|
@ -36,8 +36,10 @@ public fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring
|
||||
/**
|
||||
* Represent a polynomial as a context-dependent function
|
||||
*/
|
||||
public fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, out C, T> =
|
||||
MathFunction { arg -> value(this, arg) }
|
||||
public fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, C, T> =
|
||||
object : MathFunction<T, C, T> {
|
||||
override fun C.invoke(arg: T): T = value(this, arg)
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent the polynomial as a regular context-less function
|
||||
|
@ -3,13 +3,14 @@ package scientifik.kmath.functions
|
||||
import scientifik.kmath.operations.Algebra
|
||||
import scientifik.kmath.operations.RealField
|
||||
|
||||
// TODO make fun interface when KT-41770 is fixed
|
||||
/**
|
||||
* A regular function that could be called only inside specific algebra context
|
||||
* @param T source type
|
||||
* @param C source algebra constraint
|
||||
* @param R result type
|
||||
*/
|
||||
public fun interface MathFunction<T, C : Algebra<T>, R> {
|
||||
public /*fun*/ interface MathFunction<T, C : Algebra<T>, R> {
|
||||
public operator fun C.invoke(arg: T): R
|
||||
}
|
||||
|
||||
|
@ -15,15 +15,11 @@ public class SplineInterpolator<T : Comparable<T>>(
|
||||
public override val algebra: Field<T>,
|
||||
public val bufferFactory: MutableBufferFactory<T>
|
||||
) : PolynomialInterpolator<T> {
|
||||
|
||||
//TODO possibly optimize zeroed buffers
|
||||
|
||||
public override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra {
|
||||
if (points.size < 3) {
|
||||
error("Can't use spline interpolator with less than 3 points")
|
||||
}
|
||||
require(points.size >= 3) { "Can't use spline interpolator with less than 3 points" }
|
||||
insureSorted(points)
|
||||
|
||||
// Number of intervals. The number of data points is n + 1.
|
||||
val n = points.size - 1
|
||||
// Differences between knot points
|
||||
@ -34,6 +30,7 @@ public class SplineInterpolator<T : Comparable<T>>(
|
||||
for (i in 1 until n) {
|
||||
val g = 2.0 * (points.x[i + 1] - points.x[i - 1]) - h[i - 1] * mu[i - 1]
|
||||
mu[i] = h[i] / g
|
||||
|
||||
z[i] =
|
||||
(3.0 * (points.y[i + 1] * h[i - 1] - points.x[i] * (points.x[i + 1] - points.x[i - 1]) + points.y[i - 1] * h[i]) / (h[i - 1] * h[i])
|
||||
- h[i - 1] * z[i - 1]) / g
|
||||
@ -54,7 +51,5 @@ public class SplineInterpolator<T : Comparable<T>>(
|
||||
putLeft(points.x[j], polynomial)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -5,16 +5,16 @@ package scientifik.kmath.histogram
|
||||
* TODO replace with atomics
|
||||
*/
|
||||
|
||||
expect class LongCounter() {
|
||||
fun decrement()
|
||||
fun increment()
|
||||
fun reset()
|
||||
fun sum(): Long
|
||||
fun add(l: Long)
|
||||
public expect class LongCounter() {
|
||||
public fun decrement()
|
||||
public fun increment()
|
||||
public fun reset()
|
||||
public fun sum(): Long
|
||||
public fun add(l: Long)
|
||||
}
|
||||
|
||||
expect class DoubleCounter() {
|
||||
fun reset()
|
||||
fun sum(): Double
|
||||
fun add(d: Double)
|
||||
}
|
||||
public expect class DoubleCounter() {
|
||||
public fun reset()
|
||||
public fun sum(): Double
|
||||
public fun add(d: Double)
|
||||
}
|
||||
|
@ -11,49 +11,47 @@ import kotlin.contracts.contract
|
||||
/**
|
||||
* The bin in the histogram. The histogram is by definition always done in the real space
|
||||
*/
|
||||
interface Bin<T : Any> : Domain<T> {
|
||||
public interface Bin<T : Any> : Domain<T> {
|
||||
/**
|
||||
* The value of this bin
|
||||
*/
|
||||
val value: Number
|
||||
val center: Point<T>
|
||||
public val value: Number
|
||||
public val center: Point<T>
|
||||
}
|
||||
|
||||
interface Histogram<T : Any, out B : Bin<T>> : Iterable<B> {
|
||||
|
||||
public interface Histogram<T : Any, out B : Bin<T>> : Iterable<B> {
|
||||
/**
|
||||
* Find existing bin, corresponding to given coordinates
|
||||
*/
|
||||
operator fun get(point: Point<out T>): B?
|
||||
public operator fun get(point: Point<out T>): B?
|
||||
|
||||
/**
|
||||
* Dimension of the histogram
|
||||
*/
|
||||
val dimension: Int
|
||||
public val dimension: Int
|
||||
|
||||
}
|
||||
|
||||
interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
|
||||
public interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
|
||||
|
||||
/**
|
||||
* Increment appropriate bin
|
||||
*/
|
||||
fun putWithWeight(point: Point<out T>, weight: Double)
|
||||
public fun putWithWeight(point: Point<out T>, weight: Double)
|
||||
|
||||
fun put(point: Point<out T>): Unit = putWithWeight(point, 1.0)
|
||||
public fun put(point: Point<out T>): Unit = putWithWeight(point, 1.0)
|
||||
}
|
||||
|
||||
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T): Unit = put(ArrayBuffer(point))
|
||||
public fun <T : Any> MutableHistogram<T, *>.put(vararg point: T): Unit = put(ArrayBuffer(point))
|
||||
|
||||
fun MutableHistogram<Double, *>.put(vararg point: Number): Unit =
|
||||
public fun MutableHistogram<Double, *>.put(vararg point: Number): Unit =
|
||||
put(RealBuffer(point.map { it.toDouble() }.toDoubleArray()))
|
||||
|
||||
fun MutableHistogram<Double, *>.put(vararg point: Double): Unit = put(RealBuffer(point))
|
||||
|
||||
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>): Unit = sequence.forEach { put(it) }
|
||||
public fun MutableHistogram<Double, *>.put(vararg point: Double): Unit = put(RealBuffer(point))
|
||||
public fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>): Unit = sequence.forEach { put(it) }
|
||||
|
||||
/**
|
||||
* Pass a sequence builder into histogram
|
||||
*/
|
||||
fun <T : Any> MutableHistogram<T, *>.fill(block: suspend SequenceScope<Point<T>>.() -> Unit): Unit =
|
||||
public fun <T : Any> MutableHistogram<T, *>.fill(block: suspend SequenceScope<Point<T>>.() -> Unit): Unit =
|
||||
fill(sequence(block).asIterable())
|
||||
|
@ -7,9 +7,12 @@ import scientifik.kmath.real.asVector
|
||||
import scientifik.kmath.structures.*
|
||||
import kotlin.math.floor
|
||||
|
||||
|
||||
data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val center: Point<T>, val sizes: Point<T>) {
|
||||
fun contains(vector: Point<out T>): Boolean {
|
||||
public data class BinDef<T : Comparable<T>>(
|
||||
public val space: SpaceOperations<Point<T>>,
|
||||
public val center: Point<T>,
|
||||
public val sizes: Point<T>
|
||||
) {
|
||||
public fun contains(vector: Point<out T>): Boolean {
|
||||
require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" }
|
||||
val upper = space { center + sizes / 2.0 }
|
||||
val lower = space { center - sizes / 2.0 }
|
||||
@ -18,21 +21,20 @@ data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val c
|
||||
}
|
||||
|
||||
|
||||
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> {
|
||||
override operator fun contains(point: Point<T>): Boolean = def.contains(point)
|
||||
|
||||
override val dimension: Int
|
||||
public class MultivariateBin<T : Comparable<T>>(public val def: BinDef<T>, public override val value: Number) : Bin<T> {
|
||||
public override val dimension: Int
|
||||
get() = def.center.size
|
||||
|
||||
override val center: Point<T>
|
||||
public override val center: Point<T>
|
||||
get() = def.center
|
||||
|
||||
public override operator fun contains(point: Point<T>): Boolean = def.contains(point)
|
||||
}
|
||||
|
||||
/**
|
||||
* Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions.
|
||||
*/
|
||||
class RealHistogram(
|
||||
public class RealHistogram(
|
||||
private val lower: Buffer<Double>,
|
||||
private val upper: Buffer<Double>,
|
||||
private val binNums: IntArray = IntArray(lower.size) { 20 }
|
||||
@ -40,7 +42,7 @@ class RealHistogram(
|
||||
private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 })
|
||||
private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
|
||||
private val weights: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() }
|
||||
override val dimension: Int get() = lower.size
|
||||
public override val dimension: Int get() = lower.size
|
||||
private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
|
||||
|
||||
init {
|
||||
@ -64,7 +66,7 @@ class RealHistogram(
|
||||
|
||||
private fun getValue(index: IntArray): Long = values[index].sum()
|
||||
|
||||
fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
|
||||
public fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
|
||||
|
||||
private fun getDef(index: IntArray): BinDef<Double> {
|
||||
val center = index.mapIndexed { axis, i ->
|
||||
@ -78,9 +80,9 @@ class RealHistogram(
|
||||
return BinDef(RealBufferFieldOperations, center, binSize)
|
||||
}
|
||||
|
||||
fun getDef(point: Buffer<out Double>): BinDef<Double> = getDef(getIndex(point))
|
||||
public fun getDef(point: Buffer<out Double>): BinDef<Double> = getDef(getIndex(point))
|
||||
|
||||
override operator fun get(point: Buffer<out Double>): MultivariateBin<Double>? {
|
||||
public override operator fun get(point: Buffer<out Double>): MultivariateBin<Double>? {
|
||||
val index = getIndex(point)
|
||||
return MultivariateBin(getDef(index), getValue(index))
|
||||
}
|
||||
@ -90,27 +92,27 @@ class RealHistogram(
|
||||
// values[index].increment()
|
||||
// }
|
||||
|
||||
override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
|
||||
public override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
|
||||
val index = getIndex(point)
|
||||
values[index].increment()
|
||||
weights[index].add(weight)
|
||||
}
|
||||
|
||||
override operator fun iterator(): Iterator<MultivariateBin<Double>> = weights.elements().map { (index, value) ->
|
||||
MultivariateBin(getDef(index), value.sum())
|
||||
}.iterator()
|
||||
public override operator fun iterator(): Iterator<MultivariateBin<Double>> =
|
||||
weights.elements().map { (index, value) -> MultivariateBin(getDef(index), value.sum()) }
|
||||
.iterator()
|
||||
|
||||
/**
|
||||
* Convert this histogram into NDStructure containing bin values but not bin descriptions
|
||||
*/
|
||||
fun values(): NDStructure<Number> = NDStructure.auto(values.shape) { values[it].sum() }
|
||||
public fun values(): NDStructure<Number> = NDStructure.auto(values.shape) { values[it].sum() }
|
||||
|
||||
/**
|
||||
* Sum of weights
|
||||
*/
|
||||
fun weights(): NDStructure<Double> = NDStructure.auto(weights.shape) { weights[it].sum() }
|
||||
public fun weights(): NDStructure<Double> = NDStructure.auto(weights.shape) { weights[it].sum() }
|
||||
|
||||
companion object {
|
||||
public companion object {
|
||||
/**
|
||||
* Use it like
|
||||
* ```
|
||||
@ -120,9 +122,9 @@ class RealHistogram(
|
||||
*)
|
||||
*```
|
||||
*/
|
||||
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): RealHistogram = RealHistogram(
|
||||
ranges.map { it.start }.asVector(),
|
||||
ranges.map { it.endInclusive }.asVector()
|
||||
public fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): RealHistogram = RealHistogram(
|
||||
ranges.map(ClosedFloatingPointRange<Double>::start).asVector(),
|
||||
ranges.map(ClosedFloatingPointRange<Double>::endInclusive).asVector()
|
||||
)
|
||||
|
||||
/**
|
||||
@ -134,10 +136,21 @@ class RealHistogram(
|
||||
*)
|
||||
*```
|
||||
*/
|
||||
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): RealHistogram = RealHistogram(
|
||||
ListBuffer(ranges.map { it.first.start }),
|
||||
ListBuffer(ranges.map { it.first.endInclusive }),
|
||||
ranges.map { it.second }.toIntArray()
|
||||
)
|
||||
public fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): RealHistogram =
|
||||
RealHistogram(
|
||||
ListBuffer(
|
||||
ranges
|
||||
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
|
||||
.map(ClosedFloatingPointRange<Double>::start)
|
||||
),
|
||||
|
||||
ListBuffer(
|
||||
ranges
|
||||
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
|
||||
.map(ClosedFloatingPointRange<Double>::endInclusive)
|
||||
),
|
||||
|
||||
ranges.map(Pair<ClosedFloatingPointRange<Double>, Int>::second).toIntArray()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFalse
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class MultivariateHistogramTest {
|
||||
internal class MultivariateHistogramTest {
|
||||
@Test
|
||||
fun testSinglePutHistogram() {
|
||||
val histogram = RealHistogram.fromRanges(
|
||||
|
@ -1,33 +1,37 @@
|
||||
package scientifik.kmath.histogram
|
||||
|
||||
actual class LongCounter {
|
||||
private var sum: Long = 0
|
||||
actual fun decrement() {
|
||||
public actual class LongCounter {
|
||||
private var sum: Long = 0L
|
||||
|
||||
public actual fun decrement() {
|
||||
sum--
|
||||
}
|
||||
|
||||
actual fun increment() {
|
||||
public actual fun increment() {
|
||||
sum++
|
||||
}
|
||||
|
||||
actual fun reset() {
|
||||
public actual fun reset() {
|
||||
sum = 0
|
||||
}
|
||||
|
||||
actual fun sum(): Long = sum
|
||||
actual fun add(l: Long) {
|
||||
public actual fun sum(): Long = sum
|
||||
|
||||
public actual fun add(l: Long) {
|
||||
sum += l
|
||||
}
|
||||
}
|
||||
|
||||
actual class DoubleCounter {
|
||||
public actual class DoubleCounter {
|
||||
private var sum: Double = 0.0
|
||||
actual fun reset() {
|
||||
|
||||
public actual fun reset() {
|
||||
sum = 0.0
|
||||
}
|
||||
|
||||
actual fun sum(): Double = sum
|
||||
actual fun add(d: Double) {
|
||||
public actual fun sum(): Double = sum
|
||||
|
||||
public actual fun add(d: Double) {
|
||||
sum += d
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3,5 +3,5 @@ package scientifik.kmath.histogram
|
||||
import java.util.concurrent.atomic.DoubleAdder
|
||||
import java.util.concurrent.atomic.LongAdder
|
||||
|
||||
actual typealias LongCounter = LongAdder
|
||||
actual typealias DoubleCounter = DoubleAdder
|
||||
public actual typealias LongCounter = LongAdder
|
||||
public actual typealias DoubleCounter = DoubleAdder
|
||||
|
@ -8,25 +8,26 @@ import kotlin.math.floor
|
||||
|
||||
//TODO move to common
|
||||
|
||||
class UnivariateBin(val position: Double, val size: Double, val counter: LongCounter = LongCounter()) : Bin<Double> {
|
||||
public class UnivariateBin(
|
||||
public val position: Double,
|
||||
public val size: Double,
|
||||
public val counter: LongCounter = LongCounter()
|
||||
) : Bin<Double> {
|
||||
//TODO add weighting
|
||||
override val value: Number get() = counter.sum()
|
||||
public override val value: Number get() = counter.sum()
|
||||
|
||||
override val center: RealVector get() = doubleArrayOf(position).asVector()
|
||||
public override val center: RealVector get() = doubleArrayOf(position).asVector()
|
||||
public override val dimension: Int get() = 1
|
||||
|
||||
operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2)
|
||||
|
||||
override fun contains(point: Buffer<Double>): Boolean = contains(point[0])
|
||||
|
||||
internal operator fun inc() = this.also { counter.increment() }
|
||||
|
||||
override val dimension: Int get() = 1
|
||||
public operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2)
|
||||
public override fun contains(point: Buffer<Double>): Boolean = contains(point[0])
|
||||
internal operator fun inc(): UnivariateBin = this.also { counter.increment() }
|
||||
}
|
||||
|
||||
/**
|
||||
* Univariate histogram with log(n) bin search speed
|
||||
*/
|
||||
class UnivariateHistogram private constructor(private val factory: (Double) -> UnivariateBin) :
|
||||
public class UnivariateHistogram private constructor(private val factory: (Double) -> UnivariateBin) :
|
||||
MutableHistogram<Double, UnivariateBin> {
|
||||
|
||||
private val bins: TreeMap<Double, UnivariateBin> = TreeMap()
|
||||
@ -46,16 +47,16 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
|
||||
synchronized(this) { bins.put(it.position, it) }
|
||||
}
|
||||
|
||||
override operator fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0])
|
||||
public override operator fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0])
|
||||
|
||||
override val dimension: Int get() = 1
|
||||
public override val dimension: Int get() = 1
|
||||
|
||||
override operator fun iterator(): Iterator<UnivariateBin> = bins.values.iterator()
|
||||
public override operator fun iterator(): Iterator<UnivariateBin> = bins.values.iterator()
|
||||
|
||||
/**
|
||||
* Thread safe put operation
|
||||
*/
|
||||
fun put(value: Double) {
|
||||
public fun put(value: Double) {
|
||||
(get(value) ?: createBin(value)).inc()
|
||||
}
|
||||
|
||||
@ -64,13 +65,13 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
|
||||
put(point[0])
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram = UnivariateHistogram { value ->
|
||||
public companion object {
|
||||
public fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram = UnivariateHistogram { value ->
|
||||
val center = start + binSize * floor((value - start) / binSize + 0.5)
|
||||
UnivariateBin(center, binSize)
|
||||
}
|
||||
|
||||
fun custom(borders: DoubleArray): UnivariateHistogram {
|
||||
public fun custom(borders: DoubleArray): UnivariateHistogram {
|
||||
val sorted = borders.sortedArray()
|
||||
|
||||
return UnivariateHistogram { value ->
|
||||
@ -79,10 +80,12 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
|
||||
Double.NEGATIVE_INFINITY,
|
||||
Double.MAX_VALUE
|
||||
)
|
||||
|
||||
value > sorted.last() -> UnivariateBin(
|
||||
Double.POSITIVE_INFINITY,
|
||||
Double.MAX_VALUE
|
||||
)
|
||||
|
||||
else -> {
|
||||
val index = (0 until sorted.size).first { value > sorted[it] }
|
||||
val left = sorted[index]
|
||||
@ -95,4 +98,4 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
|
||||
}
|
||||
}
|
||||
|
||||
fun UnivariateHistogram.fill(sequence: Iterable<Double>) = sequence.forEach { put(it) }
|
||||
public fun UnivariateHistogram.fill(sequence: Iterable<Double>): Unit = sequence.forEach(::put)
|
||||
|
@ -5,19 +5,19 @@ import scientifik.kmath.chains.collect
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import scientifik.kmath.structures.BufferFactory
|
||||
|
||||
interface Sampler<T : Any> {
|
||||
fun sample(generator: RandomGenerator): Chain<T>
|
||||
public interface Sampler<T : Any> {
|
||||
public fun sample(generator: RandomGenerator): Chain<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* A distribution of typed objects
|
||||
*/
|
||||
interface Distribution<T : Any> : Sampler<T> {
|
||||
public interface Distribution<T : Any> : Sampler<T> {
|
||||
/**
|
||||
* A probability value for given argument [arg].
|
||||
* For continuous distributions returns PDF
|
||||
*/
|
||||
fun probability(arg: T): Double
|
||||
public fun probability(arg: T): Double
|
||||
|
||||
/**
|
||||
* Create a chain of samples from this distribution.
|
||||
@ -28,20 +28,20 @@ interface Distribution<T : Any> : Sampler<T> {
|
||||
/**
|
||||
* An empty companion. Distribution factories should be written as its extensions
|
||||
*/
|
||||
companion object
|
||||
public companion object
|
||||
}
|
||||
|
||||
interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
|
||||
public interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
|
||||
/**
|
||||
* Cumulative distribution for ordered parameter (CDF)
|
||||
*/
|
||||
fun cumulative(arg: T): Double
|
||||
public fun cumulative(arg: T): Double
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute probability integral in an interval
|
||||
*/
|
||||
fun <T : Comparable<T>> UnivariateDistribution<T>.integral(from: T, to: T): Double {
|
||||
public fun <T : Comparable<T>> UnivariateDistribution<T>.integral(from: T, to: T): Double {
|
||||
require(to > from)
|
||||
return cumulative(to) - cumulative(from)
|
||||
}
|
||||
@ -49,7 +49,7 @@ fun <T : Comparable<T>> UnivariateDistribution<T>.integral(from: T, to: T): Doub
|
||||
/**
|
||||
* Sample a bunch of values
|
||||
*/
|
||||
fun <T : Any> Sampler<T>.sampleBuffer(
|
||||
public fun <T : Any> Sampler<T>.sampleBuffer(
|
||||
generator: RandomGenerator,
|
||||
size: Int,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
|
||||
@ -57,6 +57,7 @@ fun <T : Any> Sampler<T>.sampleBuffer(
|
||||
require(size > 1)
|
||||
//creating temporary storage once
|
||||
val tmp = ArrayList<T>(size)
|
||||
|
||||
return sample(generator).collect { chain ->
|
||||
//clear list from previous run
|
||||
tmp.clear()
|
||||
@ -72,5 +73,5 @@ fun <T : Any> Sampler<T>.sampleBuffer(
|
||||
/**
|
||||
* Generate a bunch of samples from real distributions
|
||||
*/
|
||||
fun Sampler<Double>.sampleBuffer(generator: RandomGenerator, size: Int) =
|
||||
sampleBuffer(generator, size, Buffer.Companion::real)
|
||||
public fun Sampler<Double>.sampleBuffer(generator: RandomGenerator, size: Int) =
|
||||
sampleBuffer(generator, size, Buffer.Companion::real)
|
||||
|
@ -12,33 +12,29 @@ import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
|
||||
public abstract class ContinuousSamplerDistribution : Distribution<Double> {
|
||||
|
||||
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingRealChain() {
|
||||
private val sampler = buildCMSampler(generator)
|
||||
|
||||
override fun nextDouble(): Double = sampler.sample()
|
||||
|
||||
override fun fork(): Chain<Double> = ContinuousSamplerChain(generator.fork())
|
||||
public override fun nextDouble(): Double = sampler.sample()
|
||||
public override fun fork(): Chain<Double> = ContinuousSamplerChain(generator.fork())
|
||||
}
|
||||
|
||||
protected abstract fun buildCMSampler(generator: RandomGenerator): ContinuousSampler
|
||||
|
||||
override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator)
|
||||
public override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator)
|
||||
}
|
||||
|
||||
public abstract class DiscreteSamplerDistribution : Distribution<Int> {
|
||||
|
||||
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingIntChain() {
|
||||
private val sampler = buildSampler(generator)
|
||||
|
||||
override fun nextInt(): Int = sampler.sample()
|
||||
|
||||
override fun fork(): Chain<Int> = ContinuousSamplerChain(generator.fork())
|
||||
public override fun nextInt(): Int = sampler.sample()
|
||||
public override fun fork(): Chain<Int> = ContinuousSamplerChain(generator.fork())
|
||||
}
|
||||
|
||||
protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler
|
||||
|
||||
override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator)
|
||||
public override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator)
|
||||
}
|
||||
|
||||
public enum class NormalSamplerMethod {
|
||||
@ -58,7 +54,7 @@ public fun Distribution.Companion.normal(
|
||||
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
|
||||
): Distribution<Double> = object : ContinuousSamplerDistribution() {
|
||||
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
|
||||
val provider: UniformRandomProvider = generator.asUniformRandomProvider()
|
||||
val provider = generator.asUniformRandomProvider()
|
||||
return normalSampler(method, provider)
|
||||
}
|
||||
|
||||
@ -76,34 +72,27 @@ public fun Distribution.Companion.normal(
|
||||
private val norm = sigma * sqrt(PI * 2)
|
||||
|
||||
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
|
||||
val provider: UniformRandomProvider = generator.asUniformRandomProvider()
|
||||
val provider = generator.asUniformRandomProvider()
|
||||
val normalizedSampler = normalSampler(method, provider)
|
||||
return GaussianSampler(normalizedSampler, mean, sigma)
|
||||
}
|
||||
|
||||
override fun probability(arg: Double): Double {
|
||||
return exp(-(arg - mean).pow(2) / 2 / sigma2) / norm
|
||||
}
|
||||
override fun probability(arg: Double): Double = exp(-(arg - mean).pow(2) / 2 / sigma2) / norm
|
||||
}
|
||||
|
||||
public fun Distribution.Companion.poisson(
|
||||
lambda: Double
|
||||
): DiscreteSamplerDistribution = object : DiscreteSamplerDistribution() {
|
||||
public fun Distribution.Companion.poisson(lambda: Double): DiscreteSamplerDistribution =
|
||||
object : DiscreteSamplerDistribution() {
|
||||
private val computedProb: MutableMap<Int, Double> = hashMapOf(0 to exp(-lambda))
|
||||
|
||||
override fun buildSampler(generator: RandomGenerator): DiscreteSampler {
|
||||
return PoissonSampler.of(generator.asUniformRandomProvider(), lambda)
|
||||
}
|
||||
override fun buildSampler(generator: RandomGenerator): DiscreteSampler =
|
||||
PoissonSampler.of(generator.asUniformRandomProvider(), lambda)
|
||||
|
||||
private val computedProb: HashMap<Int, Double> = hashMapOf(0 to exp(-lambda))
|
||||
override fun probability(arg: Int): Double {
|
||||
require(arg >= 0) { "The argument must be >= 0" }
|
||||
|
||||
override fun probability(arg: Int): Double {
|
||||
require(arg >= 0) { "The argument must be >= 0" }
|
||||
return if (arg > 40) {
|
||||
exp(-(arg - lambda).pow(2) / 2 / lambda) / sqrt(2 * PI * lambda)
|
||||
} else {
|
||||
computedProb.getOrPut(arg) {
|
||||
probability(arg - 1) * lambda / arg
|
||||
}
|
||||
return if (arg > 40)
|
||||
exp(-(arg - lambda).pow(2) / 2 / lambda) / sqrt(2 * PI * lambda)
|
||||
else
|
||||
computedProb.getOrPut(arg) { probability(arg - 1) * lambda / arg }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -27,7 +27,6 @@ include(
|
||||
":kmath-memory",
|
||||
":kmath-core",
|
||||
":kmath-functions",
|
||||
// ":kmath-io",
|
||||
":kmath-coroutines",
|
||||
":kmath-histograms",
|
||||
":kmath-commons",
|
||||
@ -38,6 +37,6 @@ include(
|
||||
":kmath-dimensions",
|
||||
":kmath-for-real",
|
||||
":kmath-geometry",
|
||||
":kmath-ast",
|
||||
// ":kmath-ast",
|
||||
":examples"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user