Specify explicit API visbility, minor refactoring (error handling, etc.)

This commit is contained in:
Iaroslav Postovalov 2020-09-09 11:28:54 +07:00
parent 6b79e79d21
commit f567f73d19
No known key found for this signature in database
GPG Key ID: 70D5F4DCB0972F1B
30 changed files with 401 additions and 415 deletions

View File

@ -20,7 +20,7 @@ repositories {
sourceSets.register("benchmarks") sourceSets.register("benchmarks")
dependencies { dependencies {
implementation(project(":kmath-ast")) // implementation(project(":kmath-ast"))
implementation(project(":kmath-core")) implementation(project(":kmath-core"))
implementation(project(":kmath-coroutines")) implementation(project(":kmath-coroutines"))
implementation(project(":kmath-commons")) implementation(project(":kmath-commons"))

View File

@ -1,6 +1,5 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import scientifik.kmath.operations.RealField.pow
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow as kpow import kotlin.math.pow as kpow
@ -13,11 +12,10 @@ public interface ExtendedFieldOperations<T> :
HyperbolicOperations<T>, HyperbolicOperations<T>,
PowerOperations<T>, PowerOperations<T>,
ExponentialOperations<T> { ExponentialOperations<T> {
public override fun tan(arg: T): T = sin(arg) / cos(arg)
public override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
override fun tan(arg: T): T = sin(arg) / cos(arg) public override fun unaryOperation(operation: String, arg: T): T = when (operation) {
override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
override fun unaryOperation(operation: String, arg: T): T = when (operation) {
TrigonometricOperations.COS_OPERATION -> cos(arg) TrigonometricOperations.COS_OPERATION -> cos(arg)
TrigonometricOperations.SIN_OPERATION -> sin(arg) TrigonometricOperations.SIN_OPERATION -> sin(arg)
TrigonometricOperations.TAN_OPERATION -> tan(arg) TrigonometricOperations.TAN_OPERATION -> tan(arg)
@ -37,19 +35,18 @@ public interface ExtendedFieldOperations<T> :
} }
} }
/** /**
* Advanced Number-like field that implements basic operations. * Advanced Number-like field that implements basic operations.
*/ */
public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> { public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2 public override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2
override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2 public override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2
override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) public override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg) public override fun asinh(arg: T): T = ln(sqrt(arg * arg + one) + arg)
override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one))) public override fun acosh(arg: T): T = ln(arg + sqrt((arg - one) * (arg + one)))
override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2 public override fun atanh(arg: T): T = (ln(arg + one) - ln(one - arg)) / 2
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { public override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
PowerOperations.POW_OPERATION -> power(left, right) PowerOperations.POW_OPERATION -> power(left, right)
else -> super.rightSideNumberOperation(operation, left, right) else -> super.rightSideNumberOperation(operation, left, right)
} }
@ -63,12 +60,11 @@ public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
*/ */
public inline class Real(public val value: Double) : FieldElement<Double, Real, RealField> { public inline class Real(public val value: Double) : FieldElement<Double, Real, RealField> {
override val context: RealField public override val context: RealField
get() = RealField get() = RealField
override fun unwrap(): Double = value public override fun unwrap(): Double = value
public override fun Double.wrap(): Real = Real(value)
override fun Double.wrap(): Real = Real(value)
public companion object public companion object
} }
@ -78,49 +74,49 @@ public inline class Real(public val value: Double) : FieldElement<Double, Real,
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object RealField : ExtendedField<Double>, Norm<Double, Double> { public object RealField : ExtendedField<Double>, Norm<Double, Double> {
override val zero: Double public override val zero: Double
get() = 0.0 get() = 0.0
override val one: Double public override val one: Double
get() = 1.0 get() = 1.0
override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) { public override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
PowerOperations.POW_OPERATION -> left pow right PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right) else -> super.binaryOperation(operation, left, right)
} }
override inline fun add(a: Double, b: Double): Double = a + b public override inline fun add(a: Double, b: Double): Double = a + b
override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble() public override inline fun multiply(a: Double, k: Number): Double = a * k.toDouble()
override inline fun multiply(a: Double, b: Double): Double = a * b public override inline fun multiply(a: Double, b: Double): Double = a * b
override inline fun divide(a: Double, b: Double): Double = a / b public override inline fun divide(a: Double, b: Double): Double = a / b
override inline fun sin(arg: Double): Double = kotlin.math.sin(arg) public override inline fun sin(arg: Double): Double = kotlin.math.sin(arg)
override inline fun cos(arg: Double): Double = kotlin.math.cos(arg) public override inline fun cos(arg: Double): Double = kotlin.math.cos(arg)
override inline fun tan(arg: Double): Double = kotlin.math.tan(arg) public override inline fun tan(arg: Double): Double = kotlin.math.tan(arg)
override inline fun acos(arg: Double): Double = kotlin.math.acos(arg) public override inline fun acos(arg: Double): Double = kotlin.math.acos(arg)
override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) public override inline fun asin(arg: Double): Double = kotlin.math.asin(arg)
override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) public override inline fun atan(arg: Double): Double = kotlin.math.atan(arg)
override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg) public override inline fun sinh(arg: Double): Double = kotlin.math.sinh(arg)
override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg) public override inline fun cosh(arg: Double): Double = kotlin.math.cosh(arg)
override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg) public override inline fun tanh(arg: Double): Double = kotlin.math.tanh(arg)
override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg) public override inline fun asinh(arg: Double): Double = kotlin.math.asinh(arg)
override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg) public override inline fun acosh(arg: Double): Double = kotlin.math.acosh(arg)
override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg) public override inline fun atanh(arg: Double): Double = kotlin.math.atanh(arg)
override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble()) public override inline fun power(arg: Double, pow: Number): Double = arg.kpow(pow.toDouble())
override inline fun exp(arg: Double): Double = kotlin.math.exp(arg) public override inline fun exp(arg: Double): Double = kotlin.math.exp(arg)
override inline fun ln(arg: Double): Double = kotlin.math.ln(arg) public override inline fun ln(arg: Double): Double = kotlin.math.ln(arg)
override inline fun norm(arg: Double): Double = abs(arg) public override inline fun norm(arg: Double): Double = abs(arg)
override inline fun Double.unaryMinus(): Double = -this public override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(b: Double): Double = this + b public override inline fun Double.plus(b: Double): Double = this + b
override inline fun Double.minus(b: Double): Double = this - b public override inline fun Double.minus(b: Double): Double = this - b
override inline fun Double.times(b: Double): Double = this * b public override inline fun Double.times(b: Double): Double = this * b
override inline fun Double.div(b: Double): Double = this / b public override inline fun Double.div(b: Double): Double = this / b
} }
/** /**
@ -128,49 +124,49 @@ public object RealField : ExtendedField<Double>, Norm<Double, Double> {
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object FloatField : ExtendedField<Float>, Norm<Float, Float> { public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override val zero: Float public override val zero: Float
get() = 0.0f get() = 0.0f
override val one: Float public override val one: Float
get() = 1.0f get() = 1.0f
override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) { public override fun binaryOperation(operation: String, left: Float, right: Float): Float = when (operation) {
PowerOperations.POW_OPERATION -> left pow right PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right) else -> super.binaryOperation(operation, left, right)
} }
override inline fun add(a: Float, b: Float): Float = a + b public override inline fun add(a: Float, b: Float): Float = a + b
override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat() public override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat()
override inline fun multiply(a: Float, b: Float): Float = a * b public override inline fun multiply(a: Float, b: Float): Float = a * b
override inline fun divide(a: Float, b: Float): Float = a / b public override inline fun divide(a: Float, b: Float): Float = a / b
override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) public override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
override inline fun cos(arg: Float): Float = kotlin.math.cos(arg) public override inline fun cos(arg: Float): Float = kotlin.math.cos(arg)
override inline fun tan(arg: Float): Float = kotlin.math.tan(arg) public override inline fun tan(arg: Float): Float = kotlin.math.tan(arg)
override inline fun acos(arg: Float): Float = kotlin.math.acos(arg) public override inline fun acos(arg: Float): Float = kotlin.math.acos(arg)
override inline fun asin(arg: Float): Float = kotlin.math.asin(arg) public override inline fun asin(arg: Float): Float = kotlin.math.asin(arg)
override inline fun atan(arg: Float): Float = kotlin.math.atan(arg) public override inline fun atan(arg: Float): Float = kotlin.math.atan(arg)
override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg) public override inline fun sinh(arg: Float): Float = kotlin.math.sinh(arg)
override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg) public override inline fun cosh(arg: Float): Float = kotlin.math.cosh(arg)
override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg) public override inline fun tanh(arg: Float): Float = kotlin.math.tanh(arg)
override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg) public override inline fun asinh(arg: Float): Float = kotlin.math.asinh(arg)
override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg) public override inline fun acosh(arg: Float): Float = kotlin.math.acosh(arg)
override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg) public override inline fun atanh(arg: Float): Float = kotlin.math.atanh(arg)
override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat()) public override inline fun power(arg: Float, pow: Number): Float = arg.kpow(pow.toFloat())
override inline fun exp(arg: Float): Float = kotlin.math.exp(arg) public override inline fun exp(arg: Float): Float = kotlin.math.exp(arg)
override inline fun ln(arg: Float): Float = kotlin.math.ln(arg) public override inline fun ln(arg: Float): Float = kotlin.math.ln(arg)
override inline fun norm(arg: Float): Float = abs(arg) public override inline fun norm(arg: Float): Float = abs(arg)
override inline fun Float.unaryMinus(): Float = -this public override inline fun Float.unaryMinus(): Float = -this
override inline fun Float.plus(b: Float): Float = this + b public override inline fun Float.plus(b: Float): Float = this + b
override inline fun Float.minus(b: Float): Float = this - b public override inline fun Float.minus(b: Float): Float = this - b
override inline fun Float.times(b: Float): Float = this * b public override inline fun Float.times(b: Float): Float = this * b
override inline fun Float.div(b: Float): Float = this / b public override inline fun Float.div(b: Float): Float = this / b
} }
/** /**
@ -178,23 +174,23 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object IntRing : Ring<Int>, Norm<Int, Int> { public object IntRing : Ring<Int>, Norm<Int, Int> {
override val zero: Int public override val zero: Int
get() = 0 get() = 0
override val one: Int public override val one: Int
get() = 1 get() = 1
override inline fun add(a: Int, b: Int): Int = a + b public override inline fun add(a: Int, b: Int): Int = a + b
override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a public override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a
override inline fun multiply(a: Int, b: Int): Int = a * b public override inline fun multiply(a: Int, b: Int): Int = a * b
override inline fun norm(arg: Int): Int = abs(arg) public override inline fun norm(arg: Int): Int = abs(arg)
override inline fun Int.unaryMinus(): Int = -this public override inline fun Int.unaryMinus(): Int = -this
override inline fun Int.plus(b: Int): Int = this + b public override inline fun Int.plus(b: Int): Int = this + b
override inline fun Int.minus(b: Int): Int = this - b public override inline fun Int.minus(b: Int): Int = this - b
override inline fun Int.times(b: Int): Int = this * b public override inline fun Int.times(b: Int): Int = this * b
} }
/** /**
@ -202,23 +198,23 @@ public object IntRing : Ring<Int>, Norm<Int, Int> {
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object ShortRing : Ring<Short>, Norm<Short, Short> { public object ShortRing : Ring<Short>, Norm<Short, Short> {
override val zero: Short public override val zero: Short
get() = 0 get() = 0
override val one: Short public override val one: Short
get() = 1 get() = 1
override inline fun add(a: Short, b: Short): Short = (a + b).toShort() public override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort() public override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort()
override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort() public override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() public override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
override inline fun Short.unaryMinus(): Short = (-this).toShort() public override inline fun Short.unaryMinus(): Short = (-this).toShort()
override inline fun Short.plus(b: Short): Short = (this + b).toShort() public override inline fun Short.plus(b: Short): Short = (this + b).toShort()
override inline fun Short.minus(b: Short): Short = (this - b).toShort() public override inline fun Short.minus(b: Short): Short = (this - b).toShort()
override inline fun Short.times(b: Short): Short = (this * b).toShort() public override inline fun Short.times(b: Short): Short = (this * b).toShort()
} }
/** /**
@ -226,23 +222,23 @@ public object ShortRing : Ring<Short>, Norm<Short, Short> {
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object ByteRing : Ring<Byte>, Norm<Byte, Byte> { public object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
override val zero: Byte public override val zero: Byte
get() = 0 get() = 0
override val one: Byte public override val one: Byte
get() = 1 get() = 1
override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte() public override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte()
override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() public override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() public override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
override inline fun Byte.unaryMinus(): Byte = (-this).toByte() public override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte() public override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte()
override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte() public override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte()
override inline fun Byte.times(b: Byte): Byte = (this * b).toByte() public override inline fun Byte.times(b: Byte): Byte = (this * b).toByte()
} }
/** /**
@ -250,21 +246,21 @@ public object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public object LongRing : Ring<Long>, Norm<Long, Long> { public object LongRing : Ring<Long>, Norm<Long, Long> {
override val zero: Long public override val zero: Long
get() = 0 get() = 0
override val one: Long public override val one: Long
get() = 1 get() = 1
override inline fun add(a: Long, b: Long): Long = a + b public override inline fun add(a: Long, b: Long): Long = a + b
override inline fun multiply(a: Long, k: Number): Long = a * k.toLong() public override inline fun multiply(a: Long, k: Number): Long = a * k.toLong()
override inline fun multiply(a: Long, b: Long): Long = a * b public override inline fun multiply(a: Long, b: Long): Long = a * b
override fun norm(arg: Long): Long = abs(arg) public override fun norm(arg: Long): Long = abs(arg)
override inline fun Long.unaryMinus(): Long = (-this) public override inline fun Long.unaryMinus(): Long = (-this)
override inline fun Long.plus(b: Long): Long = (this + b) public override inline fun Long.plus(b: Long): Long = (this + b)
override inline fun Long.minus(b: Long): Long = (this - b) public override inline fun Long.minus(b: Long): Long = (this - b)
override inline fun Long.times(b: Long): Long = (this * b) public override inline fun Long.times(b: Long): Long = (this * b)
} }

View File

@ -4,27 +4,27 @@ import scientifik.kmath.operations.Field
import scientifik.kmath.operations.FieldElement import scientifik.kmath.operations.FieldElement
public class BoxingNDField<T, F : Field<T>>( public class BoxingNDField<T, F : Field<T>>(
override val shape: IntArray, public override val shape: IntArray,
override val elementContext: F, public override val elementContext: F,
public val bufferFactory: BufferFactory<T> public val bufferFactory: BufferFactory<T>
) : BufferedNDField<T, F> { ) : BufferedNDField<T, F> {
override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } } public override val zero: BufferedNDFieldElement<T, F> by lazy { produce { zero } }
override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } } public override val one: BufferedNDFieldElement<T, F> by lazy { produce { one } }
override val strides: Strides = DefaultStrides(shape) public override val strides: Strides = DefaultStrides(shape)
public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
bufferFactory(size, initializer) bufferFactory(size, initializer)
override fun check(vararg elements: NDBuffer<T>) { public override fun check(vararg elements: NDBuffer<T>) {
check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
} }
override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> = public override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
BufferedNDFieldElement( BufferedNDFieldElement(
this, this,
buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }) buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) })
override fun map(arg: NDBuffer<T>, transform: F.(T) -> T): BufferedNDFieldElement<T, F> { public override fun map(arg: NDBuffer<T>, transform: F.(T) -> T): BufferedNDFieldElement<T, F> {
check(arg) check(arg)
return BufferedNDFieldElement( return BufferedNDFieldElement(
@ -36,7 +36,7 @@ public class BoxingNDField<T, F : Field<T>>(
} }
override fun mapIndexed( public override fun mapIndexed(
arg: NDBuffer<T>, arg: NDBuffer<T>,
transform: F.(index: IntArray, T) -> T transform: F.(index: IntArray, T) -> T
): BufferedNDFieldElement<T, F> { ): BufferedNDFieldElement<T, F> {
@ -55,7 +55,7 @@ public class BoxingNDField<T, F : Field<T>>(
// return BufferedNDFieldElement(this, buffer) // return BufferedNDFieldElement(this, buffer)
} }
override fun combine( public override fun combine(
a: NDBuffer<T>, a: NDBuffer<T>,
b: NDBuffer<T>, b: NDBuffer<T>,
transform: F.(T, T) -> T transform: F.(T, T) -> T
@ -66,7 +66,7 @@ public class BoxingNDField<T, F : Field<T>>(
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }) buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
} }
override fun NDBuffer<T>.toElement(): FieldElement<NDBuffer<T>, *, out BufferedNDField<T, F>> = public override fun NDBuffer<T>.toElement(): FieldElement<NDBuffer<T>, *, out BufferedNDField<T, F>> =
BufferedNDFieldElement(this@BoxingNDField, buffer) BufferedNDFieldElement(this@BoxingNDField, buffer)
} }

View File

@ -21,7 +21,6 @@ public inline class LongBuffer(public val array: LongArray) : MutableBuffer<Long
override fun copy(): MutableBuffer<Long> = override fun copy(): MutableBuffer<Long> =
LongBuffer(array.copyOf()) LongBuffer(array.copyOf())
} }
/** /**

View File

@ -8,7 +8,7 @@ import kotlin.math.*
* [ExtendedFieldOperations] over [RealBuffer]. * [ExtendedFieldOperations] over [RealBuffer].
*/ */
public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> { public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { public override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
} }
@ -20,7 +20,7 @@ public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>
} else RealBuffer(DoubleArray(a.size) { a[it] + b[it] }) } else RealBuffer(DoubleArray(a.size) { a[it] + b[it] })
} }
override fun multiply(a: Buffer<Double>, k: Number): RealBuffer { public override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
val kValue = k.toDouble() val kValue = k.toDouble()
return if (a is RealBuffer) { return if (a is RealBuffer) {
@ -29,7 +29,7 @@ public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>
} else RealBuffer(DoubleArray(a.size) { a[it] * kValue }) } else RealBuffer(DoubleArray(a.size) { a[it] * kValue })
} }
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { public override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
} }
@ -42,7 +42,7 @@ public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>
RealBuffer(DoubleArray(a.size) { a[it] * b[it] }) RealBuffer(DoubleArray(a.size) { a[it] * b[it] })
} }
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { public override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(b.size == a.size) { require(b.size == a.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
} }
@ -54,87 +54,87 @@ public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>
} else RealBuffer(DoubleArray(a.size) { a[it] / b[it] }) } else RealBuffer(DoubleArray(a.size) { a[it] / b[it] })
} }
override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun sin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) RealBuffer(DoubleArray(arg.size) { sin(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { sin(arg[it]) }) } else RealBuffer(DoubleArray(arg.size) { sin(arg[it]) })
override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun cos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) RealBuffer(DoubleArray(arg.size) { cos(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { cos(arg[it]) }) } else RealBuffer(DoubleArray(arg.size) { cos(arg[it]) })
override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun tan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { tan(array[it]) }) RealBuffer(DoubleArray(arg.size) { tan(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { tan(arg[it]) }) } else RealBuffer(DoubleArray(arg.size) { tan(arg[it]) })
override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun asin(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { asin(array[it]) }) RealBuffer(DoubleArray(arg.size) { asin(array[it]) })
} else } else
RealBuffer(DoubleArray(arg.size) { asin(arg[it]) }) RealBuffer(DoubleArray(arg.size) { asin(arg[it]) })
override fun acos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun acos(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { acos(array[it]) }) RealBuffer(DoubleArray(arg.size) { acos(array[it]) })
} else } else
RealBuffer(DoubleArray(arg.size) { acos(arg[it]) }) RealBuffer(DoubleArray(arg.size) { acos(arg[it]) })
override fun atan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun atan(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { atan(array[it]) }) RealBuffer(DoubleArray(arg.size) { atan(array[it]) })
} else } else
RealBuffer(DoubleArray(arg.size) { atan(arg[it]) }) RealBuffer(DoubleArray(arg.size) { atan(arg[it]) })
override fun sinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun sinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { sinh(array[it]) }) RealBuffer(DoubleArray(arg.size) { sinh(array[it]) })
} else } else
RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) }) RealBuffer(DoubleArray(arg.size) { sinh(arg[it]) })
override fun cosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun cosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { cosh(array[it]) }) RealBuffer(DoubleArray(arg.size) { cosh(array[it]) })
} else } else
RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) }) RealBuffer(DoubleArray(arg.size) { cosh(arg[it]) })
override fun tanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun tanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { tanh(array[it]) }) RealBuffer(DoubleArray(arg.size) { tanh(array[it]) })
} else } else
RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) }) RealBuffer(DoubleArray(arg.size) { tanh(arg[it]) })
override fun asinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun asinh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { asinh(array[it]) }) RealBuffer(DoubleArray(arg.size) { asinh(array[it]) })
} else } else
RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) }) RealBuffer(DoubleArray(arg.size) { asinh(arg[it]) })
override fun acosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun acosh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { acosh(array[it]) }) RealBuffer(DoubleArray(arg.size) { acosh(array[it]) })
} else } else
RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) }) RealBuffer(DoubleArray(arg.size) { acosh(arg[it]) })
override fun atanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun atanh(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { atanh(array[it]) }) RealBuffer(DoubleArray(arg.size) { atanh(array[it]) })
} else } else
RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) }) RealBuffer(DoubleArray(arg.size) { atanh(arg[it]) })
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) { public override fun power(arg: Buffer<Double>, pow: Number): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
} else } else
RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun exp(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) RealBuffer(DoubleArray(arg.size) { exp(array[it]) })
} else RealBuffer(DoubleArray(arg.size) { exp(arg[it]) }) } else RealBuffer(DoubleArray(arg.size) { exp(arg[it]) })
override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) { public override fun ln(arg: Buffer<Double>): RealBuffer = if (arg is RealBuffer) {
val array = arg.array val array = arg.array
RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) RealBuffer(DoubleArray(arg.size) { ln(array[it]) })
} else } else
@ -147,100 +147,100 @@ public object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>
* @property size the size of buffers to operate on. * @property size the size of buffers to operate on.
*/ */
public class RealBufferField(public val size: Int) : ExtendedField<Buffer<Double>> { public class RealBufferField(public val size: Int) : ExtendedField<Buffer<Double>> {
override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } } public override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } } public override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }
override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { public override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
return RealBufferFieldOperations.add(a, b) return RealBufferFieldOperations.add(a, b)
} }
override fun multiply(a: Buffer<Double>, k: Number): RealBuffer { public override fun multiply(a: Buffer<Double>, k: Number): RealBuffer {
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
return RealBufferFieldOperations.multiply(a, k) return RealBufferFieldOperations.multiply(a, k)
} }
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { public override fun multiply(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
return RealBufferFieldOperations.multiply(a, b) return RealBufferFieldOperations.multiply(a, b)
} }
override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer { public override fun divide(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
return RealBufferFieldOperations.divide(a, b) return RealBufferFieldOperations.divide(a, b)
} }
override fun sin(arg: Buffer<Double>): RealBuffer { public override fun sin(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.sin(arg) return RealBufferFieldOperations.sin(arg)
} }
override fun cos(arg: Buffer<Double>): RealBuffer { public override fun cos(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.cos(arg) return RealBufferFieldOperations.cos(arg)
} }
override fun tan(arg: Buffer<Double>): RealBuffer { public override fun tan(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.tan(arg) return RealBufferFieldOperations.tan(arg)
} }
override fun asin(arg: Buffer<Double>): RealBuffer { public override fun asin(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.asin(arg) return RealBufferFieldOperations.asin(arg)
} }
override fun acos(arg: Buffer<Double>): RealBuffer { public override fun acos(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.acos(arg) return RealBufferFieldOperations.acos(arg)
} }
override fun atan(arg: Buffer<Double>): RealBuffer { public override fun atan(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.atan(arg) return RealBufferFieldOperations.atan(arg)
} }
override fun sinh(arg: Buffer<Double>): RealBuffer { public override fun sinh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.sinh(arg) return RealBufferFieldOperations.sinh(arg)
} }
override fun cosh(arg: Buffer<Double>): RealBuffer { public override fun cosh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.cosh(arg) return RealBufferFieldOperations.cosh(arg)
} }
override fun tanh(arg: Buffer<Double>): RealBuffer { public override fun tanh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.tanh(arg) return RealBufferFieldOperations.tanh(arg)
} }
override fun asinh(arg: Buffer<Double>): RealBuffer { public override fun asinh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.asinh(arg) return RealBufferFieldOperations.asinh(arg)
} }
override fun acosh(arg: Buffer<Double>): RealBuffer { public override fun acosh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.acosh(arg) return RealBufferFieldOperations.acosh(arg)
} }
override fun atanh(arg: Buffer<Double>): RealBuffer { public override fun atanh(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.atanh(arg) return RealBufferFieldOperations.atanh(arg)
} }
override fun power(arg: Buffer<Double>, pow: Number): RealBuffer { public override fun power(arg: Buffer<Double>, pow: Number): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.power(arg, pow) return RealBufferFieldOperations.power(arg, pow)
} }
override fun exp(arg: Buffer<Double>): RealBuffer { public override fun exp(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.exp(arg) return RealBufferFieldOperations.exp(arg)
} }
override fun ln(arg: Buffer<Double>): RealBuffer { public override fun ln(arg: Buffer<Double>): RealBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.ln(arg) return RealBufferFieldOperations.ln(arg)
} }

View File

@ -1,25 +1,21 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import kotlin.contracts.contract
/** /**
* Specialized [MutableBuffer] implementation over [ShortArray]. * Specialized [MutableBuffer] implementation over [ShortArray].
* *
* @property array the underlying array. * @property array the underlying array.
*/ */
public inline class ShortBuffer(public val array: ShortArray) : MutableBuffer<Short> { public inline class ShortBuffer(public val array: ShortArray) : MutableBuffer<Short> {
override val size: Int get() = array.size public override val size: Int get() = array.size
override operator fun get(index: Int): Short = array[index] public override operator fun get(index: Int): Short = array[index]
override operator fun set(index: Int, value: Short) { public override operator fun set(index: Int, value: Short) {
array[index] = value array[index] = value
} }
override operator fun iterator(): ShortIterator = array.iterator() public override operator fun iterator(): ShortIterator = array.iterator()
public override fun copy(): MutableBuffer<Short> = ShortBuffer(array.copyOf())
override fun copy(): MutableBuffer<Short> =
ShortBuffer(array.copyOf())
} }
/** /**

View File

@ -4,25 +4,24 @@ package scientifik.kmath.structures
* A structure that is guaranteed to be one-dimensional * A structure that is guaranteed to be one-dimensional
*/ */
public interface Structure1D<T> : NDStructure<T>, Buffer<T> { public interface Structure1D<T> : NDStructure<T>, Buffer<T> {
override val dimension: Int get() = 1 public override val dimension: Int get() = 1
override operator fun get(index: IntArray): T { public override operator fun get(index: IntArray): T {
require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" } require(index.size == 1) { "Index dimension mismatch. Expected 1 but found ${index.size}" }
return get(index[0]) return get(index[0])
} }
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(::get).iterator() public override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map(::get).iterator()
} }
/** /**
* A 1D wrapper for nd-structure * A 1D wrapper for nd-structure
*/ */
private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Structure1D<T> { private inline class Structure1DWrapper<T>(public val structure: NDStructure<T>) : Structure1D<T> {
override val shape: IntArray get() = structure.shape override val shape: IntArray get() = structure.shape
override val size: Int get() = structure.shape[0] override val size: Int get() = structure.shape[0]
override operator fun get(index: Int): T = structure[index] override operator fun get(index: Int): T = structure[index]
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements() override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
} }
@ -32,7 +31,6 @@ private inline class Structure1DWrapper<T>(val structure: NDStructure<T>) : Stru
*/ */
private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T> { private inline class Buffer1DWrapper<T>(val buffer: Buffer<T>) : Structure1D<T> {
override val shape: IntArray get() = intArrayOf(buffer.size) override val shape: IntArray get() = intArrayOf(buffer.size)
override val size: Int get() = buffer.size override val size: Int get() = buffer.size
override fun elements(): Sequence<Pair<IntArray, T>> = override fun elements(): Sequence<Pair<IntArray, T>> =

View File

@ -8,19 +8,19 @@ import java.math.MathContext
* A field over [BigInteger]. * A field over [BigInteger].
*/ */
public object JBigIntegerField : Field<BigInteger> { public object JBigIntegerField : Field<BigInteger> {
override val zero: BigInteger public override val zero: BigInteger
get() = BigInteger.ZERO get() = BigInteger.ZERO
override val one: BigInteger public override val one: BigInteger
get() = BigInteger.ONE get() = BigInteger.ONE
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) public override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b) public override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b) public override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b)
override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b) public override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b)
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger()) public override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b) public override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
override operator fun BigInteger.unaryMinus(): BigInteger = negate() public override operator fun BigInteger.unaryMinus(): BigInteger = negate()
} }
/** /**
@ -31,24 +31,24 @@ public object JBigIntegerField : Field<BigInteger> {
public abstract class JBigDecimalFieldBase internal constructor(public val mathContext: MathContext = MathContext.DECIMAL64) : public abstract class JBigDecimalFieldBase internal constructor(public val mathContext: MathContext = MathContext.DECIMAL64) :
Field<BigDecimal>, Field<BigDecimal>,
PowerOperations<BigDecimal> { PowerOperations<BigDecimal> {
override val zero: BigDecimal public override val zero: BigDecimal
get() = BigDecimal.ZERO get() = BigDecimal.ZERO
override val one: BigDecimal public override val one: BigDecimal
get() = BigDecimal.ONE get() = BigDecimal.ONE
override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b) public override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b)
override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) public override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b)
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) public override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
override fun multiply(a: BigDecimal, k: Number): BigDecimal = public override fun multiply(a: BigDecimal, k: Number): BigDecimal =
a.multiply(k.toDouble().toBigDecimal(mathContext), mathContext) a.multiply(k.toDouble().toBigDecimal(mathContext), mathContext)
override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext) public override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext)
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) public override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) public override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) public override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) public override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)
} }
/** /**

View File

@ -3,10 +3,10 @@ package scientifik.kmath.chains
/** /**
* Performance optimized chain for integer values * Performance optimized chain for integer values
*/ */
abstract class BlockingIntChain : Chain<Int> { public abstract class BlockingIntChain : Chain<Int> {
abstract fun nextInt(): Int public abstract fun nextInt(): Int
override suspend fun next(): Int = nextInt() override suspend fun next(): Int = nextInt()
fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() } public fun nextBlock(size: Int): IntArray = IntArray(size) { nextInt() }
} }

View File

@ -3,10 +3,10 @@ package scientifik.kmath.chains
/** /**
* Performance optimized chain for real values * Performance optimized chain for real values
*/ */
abstract class BlockingRealChain : Chain<Double> { public abstract class BlockingRealChain : Chain<Double> {
abstract fun nextDouble(): Double public abstract fun nextDouble(): Double
override suspend fun next(): Double = nextDouble() override suspend fun next(): Double = nextDouble()
fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() } public fun nextBlock(size: Int): DoubleArray = DoubleArray(size) { nextDouble() }
} }

View File

@ -3,20 +3,19 @@ package scientifik.kmath.chains
import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.runningReduce
import kotlinx.coroutines.flow.scan import kotlinx.coroutines.flow.scan
import kotlinx.coroutines.flow.scanReduce
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.operations.SpaceOperations import scientifik.kmath.operations.SpaceOperations
import scientifik.kmath.operations.invoke import scientifik.kmath.operations.invoke
@ExperimentalCoroutinesApi @ExperimentalCoroutinesApi
fun <T> Flow<T>.cumulativeSum(space: SpaceOperations<T>): Flow<T> = space { public fun <T> Flow<T>.cumulativeSum(space: SpaceOperations<T>): Flow<T> =
scanReduce { sum: T, element: T -> sum + element } space { runningReduce { sum, element -> sum + element } }
}
@ExperimentalCoroutinesApi @ExperimentalCoroutinesApi
fun <T> Flow<T>.mean(space: Space<T>): Flow<T> = space { public fun <T> Flow<T>.mean(space: Space<T>): Flow<T> = space {
class Accumulator(var sum: T, var num: Int) data class Accumulator(var sum: T, var num: Int)
scan(Accumulator(zero, 0)) { sum, element -> scan(Accumulator(zero, 0)) { sum, element ->
sum.apply { sum.apply {

View File

@ -11,18 +11,18 @@ import scientifik.kmath.structures.asBuffer
/** /**
* Create a [Flow] from buffer * Create a [Flow] from buffer
*/ */
fun <T> Buffer<T>.asFlow(): Flow<T> = iterator().asFlow() public fun <T> Buffer<T>.asFlow(): Flow<T> = iterator().asFlow()
/** /**
* Flat map a [Flow] of [Buffer] into continuous [Flow] of elements * Flat map a [Flow] of [Buffer] into continuous [Flow] of elements
*/ */
@FlowPreview @FlowPreview
fun <T> Flow<Buffer<out T>>.spread(): Flow<T> = flatMapConcat { it.asFlow() } public fun <T> Flow<Buffer<out T>>.spread(): Flow<T> = flatMapConcat { it.asFlow() }
/** /**
* Collect incoming flow into fixed size chunks * Collect incoming flow into fixed size chunks
*/ */
fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<Buffer<T>> = flow { public fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<Buffer<T>> = flow {
require(bufferSize > 0) { "Resulting chunk size must be more than zero" } require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
val list = ArrayList<T>(bufferSize) val list = ArrayList<T>(bufferSize)
var counter = 0 var counter = 0
@ -30,6 +30,7 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<
this@chunked.collect { element -> this@chunked.collect { element ->
list.add(element) list.add(element)
counter++ counter++
if (counter == bufferSize) { if (counter == bufferSize) {
val buffer = bufferFactory(bufferSize) { list[it] } val buffer = bufferFactory(bufferSize) { list[it] }
emit(buffer) emit(buffer)
@ -37,15 +38,14 @@ fun <T> Flow<T>.chunked(bufferSize: Int, bufferFactory: BufferFactory<T>): Flow<
counter = 0 counter = 0
} }
} }
if (counter > 0) {
emit(bufferFactory(counter) { list[it] }) if (counter > 0) emit(bufferFactory(counter) { list[it] })
}
} }
/** /**
* Specialized flow chunker for real buffer * Specialized flow chunker for real buffer
*/ */
fun Flow<Double>.chunked(bufferSize: Int): Flow<RealBuffer> = flow { public fun Flow<Double>.chunked(bufferSize: Int): Flow<RealBuffer> = flow {
require(bufferSize > 0) { "Resulting chunk size must be more than zero" } require(bufferSize > 0) { "Resulting chunk size must be more than zero" }
if (this@chunked is BlockingRealChain) { if (this@chunked is BlockingRealChain) {
@ -66,9 +66,8 @@ fun Flow<Double>.chunked(bufferSize: Int): Flow<RealBuffer> = flow {
counter = 0 counter = 0
} }
} }
if (counter > 0) {
emit(RealBuffer(counter) { array[it] }) if (counter > 0) emit(RealBuffer(counter) { array[it] })
}
} }
} }
@ -76,9 +75,10 @@ fun Flow<Double>.chunked(bufferSize: Int): Flow<RealBuffer> = flow {
* Map a flow to a moving window buffer. The window step is one. * Map a flow to a moving window buffer. The window step is one.
* In order to get different steps, one could use skip operation. * In order to get different steps, one could use skip operation.
*/ */
fun <T> Flow<T>.windowed(window: Int): Flow<Buffer<T>> = flow { public fun <T> Flow<T>.windowed(window: Int): Flow<Buffer<T>> = flow {
require(window > 1) { "Window size must be more than one" } require(window > 1) { "Window size must be more than one" }
val ringBuffer = RingBuffer.boxing<T>(window) val ringBuffer = RingBuffer.boxing<T>(window)
this@windowed.collect { element -> this@windowed.collect { element ->
ringBuffer.push(element) ringBuffer.push(element)
emit(ringBuffer.snapshot()) emit(ringBuffer.snapshot())

View File

@ -10,28 +10,28 @@ import scientifik.kmath.structures.VirtualBuffer
* Thread-safe ring buffer * Thread-safe ring buffer
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
class RingBuffer<T>( public class RingBuffer<T>(
private val buffer: MutableBuffer<T?>, private val buffer: MutableBuffer<T?>,
private var startIndex: Int = 0, private var startIndex: Int = 0,
size: Int = 0 size: Int = 0
) : Buffer<T> { ) : Buffer<T> {
private val mutex: Mutex = Mutex() private val mutex: Mutex = Mutex()
override var size: Int = size public override var size: Int = size
private set private set
override operator fun get(index: Int): T { public override operator fun get(index: Int): T {
require(index >= 0) { "Index must be positive" } require(index >= 0) { "Index must be positive" }
require(index < size) { "Index $index is out of circular buffer size $size" } require(index < size) { "Index $index is out of circular buffer size $size" }
return buffer[startIndex.forward(index)] as T return buffer[startIndex.forward(index)] as T
} }
fun isFull(): Boolean = size == buffer.size public fun isFull(): Boolean = size == buffer.size
/** /**
* Iterator could provide wrong results if buffer is changed in initialization (iteration is safe) * Iterator could provide wrong results if buffer is changed in initialization (iteration is safe)
*/ */
override operator fun iterator(): Iterator<T> = object : AbstractIterator<T>() { public override operator fun iterator(): Iterator<T> = object : AbstractIterator<T>() {
private var count = size private var count = size
private var index = startIndex private var index = startIndex
val copy = buffer.copy() val copy = buffer.copy()
@ -48,23 +48,17 @@ class RingBuffer<T>(
/** /**
* A safe snapshot operation * A safe snapshot operation
*/ */
suspend fun snapshot(): Buffer<T> { public suspend fun snapshot(): Buffer<T> {
mutex.withLock { mutex.withLock {
val copy = buffer.copy() val copy = buffer.copy()
return VirtualBuffer(size) { i -> return VirtualBuffer(size) { i -> copy[startIndex.forward(i)] as T }
copy[startIndex.forward(i)] as T
}
} }
} }
suspend fun push(element: T) { public suspend fun push(element: T) {
mutex.withLock { mutex.withLock {
buffer[startIndex.forward(size)] = element buffer[startIndex.forward(size)] = element
if (isFull()) { if (isFull()) startIndex++ else size++
startIndex++
} else {
size++
}
} }
} }
@ -72,8 +66,8 @@ class RingBuffer<T>(
@Suppress("NOTHING_TO_INLINE") @Suppress("NOTHING_TO_INLINE")
private inline fun Int.forward(n: Int): Int = (this + n) % (buffer.size) private inline fun Int.forward(n: Int): Int = (this + n) % (buffer.size)
companion object { public companion object {
inline fun <reified T : Any> build(size: Int, empty: T): RingBuffer<T> { public inline fun <reified T : Any> build(size: Int, empty: T): RingBuffer<T> {
val buffer = MutableBuffer.auto(size) { empty } as MutableBuffer<T?> val buffer = MutableBuffer.auto(size) { empty } as MutableBuffer<T?>
return RingBuffer(buffer) return RingBuffer(buffer)
} }
@ -81,7 +75,7 @@ class RingBuffer<T>(
/** /**
* Slow yet universal buffer * Slow yet universal buffer
*/ */
fun <T> boxing(size: Int): RingBuffer<T> { public fun <T> boxing(size: Int): RingBuffer<T> {
val buffer: MutableBuffer<T?> = MutableBuffer.boxing(size) { null } val buffer: MutableBuffer<T?> = MutableBuffer.boxing(size) { null }
return RingBuffer(buffer) return RingBuffer(buffer)
} }

View File

@ -118,39 +118,39 @@ public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: Ge
other: DMatrix<T, C1, C2> other: DMatrix<T, C1, C2>
): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce() ): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()
inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> = public inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> =
DPoint.coerceUnsafe(context { this@dot dot vector }) DPoint.coerceUnsafe(context { this@dot dot vector })
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.times(value: T): DMatrix<T, R, C> = public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.times(value: T): DMatrix<T, R, C> =
context { this@times.times(value) }.coerce() context { this@times.times(value) }.coerce()
inline operator fun <reified R : Dimension, reified C : Dimension> T.times(m: DMatrix<T, R, C>): DMatrix<T, R, C> = public inline operator fun <reified R : Dimension, reified C : Dimension> T.times(m: DMatrix<T, R, C>): DMatrix<T, R, C> =
m * this m * this
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.plus(other: DMatrix<T, C, R>): DMatrix<T, C, R> = public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.plus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
context { this@plus + other }.coerce() context { this@plus + other }.coerce()
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.minus(other: DMatrix<T, C, R>): DMatrix<T, C, R> = public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.minus(other: DMatrix<T, C, R>): DMatrix<T, C, R> =
context { this@minus + other }.coerce() context { this@minus + other }.coerce()
inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.unaryMinus(): DMatrix<T, C, R> = public inline operator fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.unaryMinus(): DMatrix<T, C, R> =
context { this@unaryMinus.unaryMinus() }.coerce() context { this@unaryMinus.unaryMinus() }.coerce()
inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> = public inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> =
context { (this@transpose as Matrix<T>).transpose() }.coerce() context { (this@transpose as Matrix<T>).transpose() }.coerce()
/** /**
* A square unit matrix * A square unit matrix
*/ */
inline fun <reified D : Dimension> one(): DMatrix<T, D, D> = produce { i, j -> public inline fun <reified D : Dimension> one(): DMatrix<T, D, D> = produce { i, j ->
if (i == j) context.elementContext.one else context.elementContext.zero if (i == j) context.elementContext.one else context.elementContext.zero
} }
inline fun <reified R : Dimension, reified C : Dimension> zero(): DMatrix<T, R, C> = produce { _, _ -> public inline fun <reified R : Dimension, reified C : Dimension> zero(): DMatrix<T, R, C> = produce { _, _ ->
context.elementContext.zero context.elementContext.zero
} }
companion object { public companion object {
val real: DMatrixContext<Double, RealField> = DMatrixContext(MatrixContext.real) public val real: DMatrixContext<Double, RealField> = DMatrixContext(MatrixContext.real)
} }
} }

View File

@ -12,39 +12,38 @@ import scientifik.kmath.structures.asBuffer
import scientifik.kmath.structures.asIterable import scientifik.kmath.structures.asIterable
import kotlin.math.sqrt import kotlin.math.sqrt
typealias RealPoint = Point<Double> public typealias RealPoint = Point<Double>
fun DoubleArray.asVector(): RealVector = RealVector(this.asBuffer()) public fun DoubleArray.asVector(): RealVector = RealVector(this.asBuffer())
fun List<Double>.asVector(): RealVector = RealVector(this.asBuffer()) public fun List<Double>.asVector(): RealVector = RealVector(this.asBuffer())
object VectorL2Norm : Norm<Point<out Number>, Double> { public object VectorL2Norm : Norm<Point<out Number>, Double> {
override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() }) override fun norm(arg: Point<out Number>): Double = sqrt(arg.asIterable().sumByDouble { it.toDouble() })
} }
inline class RealVector(private val point: Point<Double>) : public inline class RealVector(private val point: Point<Double>) :
SpaceElement<RealPoint, RealVector, VectorSpace<Double, RealField>>, RealPoint { SpaceElement<RealPoint, RealVector, VectorSpace<Double, RealField>>, RealPoint {
public override val size: Int get() = point.size
public override val context: VectorSpace<Double, RealField> get() = space(point.size)
override val context: VectorSpace<Double, RealField> get() = space(point.size) public override fun unwrap(): RealPoint = point
override fun unwrap(): RealPoint = point public override fun RealPoint.wrap(): RealVector = RealVector(this)
override fun RealPoint.wrap(): RealVector = RealVector(this)
override val size: Int get() = point.size
override operator fun get(index: Int): Double = point[index] override operator fun get(index: Int): Double = point[index]
override operator fun iterator(): Iterator<Double> = point.iterator() override operator fun iterator(): Iterator<Double> = point.iterator()
companion object { public companion object {
private val spaceCache: MutableMap<Int, BufferVectorSpace<Double, RealField>> = hashMapOf() private val spaceCache: MutableMap<Int, BufferVectorSpace<Double, RealField>> = hashMapOf()
inline operator fun invoke(dim: Int, initializer: (Int) -> Double): RealVector = public inline operator fun invoke(dim: Int, initializer: (Int) -> Double): RealVector =
RealVector(RealBuffer(dim, initializer)) RealVector(RealBuffer(dim, initializer))
operator fun invoke(vararg values: Double): RealVector = values.asVector() public operator fun invoke(vararg values: Double): RealVector = values.asVector()
fun space(dim: Int): BufferVectorSpace<Double, RealField> = spaceCache.getOrPut(dim) { public fun space(dim: Int): BufferVectorSpace<Double, RealField> = spaceCache.getOrPut(dim) {
BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) } BufferVectorSpace(dim, RealField) { size, init -> Buffer.real(size, init) }
} }
} }

View File

@ -5,4 +5,4 @@ import scientifik.kmath.structures.RealBuffer
/** /**
* Simplified [RealBuffer] to array comparison * Simplified [RealBuffer] to array comparison
*/ */
fun RealBuffer.contentEquals(vararg doubles: Double) = array.contentEquals(doubles) public fun RealBuffer.contentEquals(vararg doubles: Double): Boolean = array.contentEquals(doubles)

View File

@ -138,11 +138,11 @@ public fun Matrix<Double>.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
} }
public fun Matrix<Double>.minByColumn(): RealBuffer = RealBuffer(colNum) { j -> public fun Matrix<Double>.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
columns[j].asIterable().min() ?: error("Cannot produce min on empty column") columns[j].asIterable().minOrNull() ?: error("Cannot produce min on empty column")
} }
public fun Matrix<Double>.maxByColumn(): RealBuffer = RealBuffer(colNum) { j -> public fun Matrix<Double>.maxByColumn(): RealBuffer = RealBuffer(colNum) { j ->
columns[j].asIterable().max() ?: error("Cannot produce min on empty column") columns[j].asIterable().maxOrNull() ?: error("Cannot produce min on empty column")
} }
public fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j -> public fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j ->
@ -154,6 +154,6 @@ public fun Matrix<Double>.averageByColumn(): RealBuffer = RealBuffer(colNum) { j
*/ */
public fun Matrix<Double>.sum(): Double = elements().map { (_, value) -> value }.sum() public fun Matrix<Double>.sum(): Double = elements().map { (_, value) -> value }.sum()
public fun Matrix<Double>.min(): Double? = elements().map { (_, value) -> value }.min() public fun Matrix<Double>.min(): Double? = elements().map { (_, value) -> value }.minOrNull()
public fun Matrix<Double>.max(): Double? = elements().map { (_, value) -> value }.max() public fun Matrix<Double>.max(): Double? = elements().map { (_, value) -> value }.maxOrNull()
public fun Matrix<Double>.average(): Double = elements().map { (_, value) -> value }.average() public fun Matrix<Double>.average(): Double = elements().map { (_, value) -> value }.average()

View File

@ -36,8 +36,10 @@ public fun <T : Any, C : Ring<T>> Polynomial<T>.value(ring: C, arg: T): T = ring
/** /**
* Represent a polynomial as a context-dependent function * Represent a polynomial as a context-dependent function
*/ */
public fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, out C, T> = public fun <T : Any, C : Ring<T>> Polynomial<T>.asMathFunction(): MathFunction<T, C, T> =
MathFunction { arg -> value(this, arg) } object : MathFunction<T, C, T> {
override fun C.invoke(arg: T): T = value(this, arg)
}
/** /**
* Represent the polynomial as a regular context-less function * Represent the polynomial as a regular context-less function

View File

@ -3,13 +3,14 @@ package scientifik.kmath.functions
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
// TODO make fun interface when KT-41770 is fixed
/** /**
* A regular function that could be called only inside specific algebra context * A regular function that could be called only inside specific algebra context
* @param T source type * @param T source type
* @param C source algebra constraint * @param C source algebra constraint
* @param R result type * @param R result type
*/ */
public fun interface MathFunction<T, C : Algebra<T>, R> { public /*fun*/ interface MathFunction<T, C : Algebra<T>, R> {
public operator fun C.invoke(arg: T): R public operator fun C.invoke(arg: T): R
} }

View File

@ -15,15 +15,11 @@ public class SplineInterpolator<T : Comparable<T>>(
public override val algebra: Field<T>, public override val algebra: Field<T>,
public val bufferFactory: MutableBufferFactory<T> public val bufferFactory: MutableBufferFactory<T>
) : PolynomialInterpolator<T> { ) : PolynomialInterpolator<T> {
//TODO possibly optimize zeroed buffers //TODO possibly optimize zeroed buffers
public override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra { public override fun interpolatePolynomials(points: XYPointSet<T, T>): PiecewisePolynomial<T> = algebra {
if (points.size < 3) { require(points.size >= 3) { "Can't use spline interpolator with less than 3 points" }
error("Can't use spline interpolator with less than 3 points")
}
insureSorted(points) insureSorted(points)
// Number of intervals. The number of data points is n + 1. // Number of intervals. The number of data points is n + 1.
val n = points.size - 1 val n = points.size - 1
// Differences between knot points // Differences between knot points
@ -34,6 +30,7 @@ public class SplineInterpolator<T : Comparable<T>>(
for (i in 1 until n) { for (i in 1 until n) {
val g = 2.0 * (points.x[i + 1] - points.x[i - 1]) - h[i - 1] * mu[i - 1] val g = 2.0 * (points.x[i + 1] - points.x[i - 1]) - h[i - 1] * mu[i - 1]
mu[i] = h[i] / g mu[i] = h[i] / g
z[i] = z[i] =
(3.0 * (points.y[i + 1] * h[i - 1] - points.x[i] * (points.x[i + 1] - points.x[i - 1]) + points.y[i - 1] * h[i]) / (h[i - 1] * h[i]) (3.0 * (points.y[i + 1] * h[i - 1] - points.x[i] * (points.x[i + 1] - points.x[i - 1]) + points.y[i - 1] * h[i]) / (h[i - 1] * h[i])
- h[i - 1] * z[i - 1]) / g - h[i - 1] * z[i - 1]) / g
@ -54,7 +51,5 @@ public class SplineInterpolator<T : Comparable<T>>(
putLeft(points.x[j], polynomial) putLeft(points.x[j], polynomial)
} }
} }
} }
} }

View File

@ -5,16 +5,16 @@ package scientifik.kmath.histogram
* TODO replace with atomics * TODO replace with atomics
*/ */
expect class LongCounter() { public expect class LongCounter() {
fun decrement() public fun decrement()
fun increment() public fun increment()
fun reset() public fun reset()
fun sum(): Long public fun sum(): Long
fun add(l: Long) public fun add(l: Long)
} }
expect class DoubleCounter() { public expect class DoubleCounter() {
fun reset() public fun reset()
fun sum(): Double public fun sum(): Double
fun add(d: Double) public fun add(d: Double)
} }

View File

@ -11,49 +11,47 @@ import kotlin.contracts.contract
/** /**
* The bin in the histogram. The histogram is by definition always done in the real space * The bin in the histogram. The histogram is by definition always done in the real space
*/ */
interface Bin<T : Any> : Domain<T> { public interface Bin<T : Any> : Domain<T> {
/** /**
* The value of this bin * The value of this bin
*/ */
val value: Number public val value: Number
val center: Point<T> public val center: Point<T>
} }
interface Histogram<T : Any, out B : Bin<T>> : Iterable<B> { public interface Histogram<T : Any, out B : Bin<T>> : Iterable<B> {
/** /**
* Find existing bin, corresponding to given coordinates * Find existing bin, corresponding to given coordinates
*/ */
operator fun get(point: Point<out T>): B? public operator fun get(point: Point<out T>): B?
/** /**
* Dimension of the histogram * Dimension of the histogram
*/ */
val dimension: Int public val dimension: Int
} }
interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> { public interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
/** /**
* Increment appropriate bin * Increment appropriate bin
*/ */
fun putWithWeight(point: Point<out T>, weight: Double) public fun putWithWeight(point: Point<out T>, weight: Double)
fun put(point: Point<out T>): Unit = putWithWeight(point, 1.0) public fun put(point: Point<out T>): Unit = putWithWeight(point, 1.0)
} }
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T): Unit = put(ArrayBuffer(point)) public fun <T : Any> MutableHistogram<T, *>.put(vararg point: T): Unit = put(ArrayBuffer(point))
fun MutableHistogram<Double, *>.put(vararg point: Number): Unit = public fun MutableHistogram<Double, *>.put(vararg point: Number): Unit =
put(RealBuffer(point.map { it.toDouble() }.toDoubleArray())) put(RealBuffer(point.map { it.toDouble() }.toDoubleArray()))
fun MutableHistogram<Double, *>.put(vararg point: Double): Unit = put(RealBuffer(point)) public fun MutableHistogram<Double, *>.put(vararg point: Double): Unit = put(RealBuffer(point))
public fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>): Unit = sequence.forEach { put(it) }
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>): Unit = sequence.forEach { put(it) }
/** /**
* Pass a sequence builder into histogram * Pass a sequence builder into histogram
*/ */
fun <T : Any> MutableHistogram<T, *>.fill(block: suspend SequenceScope<Point<T>>.() -> Unit): Unit = public fun <T : Any> MutableHistogram<T, *>.fill(block: suspend SequenceScope<Point<T>>.() -> Unit): Unit =
fill(sequence(block).asIterable()) fill(sequence(block).asIterable())

View File

@ -7,9 +7,12 @@ import scientifik.kmath.real.asVector
import scientifik.kmath.structures.* import scientifik.kmath.structures.*
import kotlin.math.floor import kotlin.math.floor
public data class BinDef<T : Comparable<T>>(
data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val center: Point<T>, val sizes: Point<T>) { public val space: SpaceOperations<Point<T>>,
fun contains(vector: Point<out T>): Boolean { public val center: Point<T>,
public val sizes: Point<T>
) {
public fun contains(vector: Point<out T>): Boolean {
require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" } require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" }
val upper = space { center + sizes / 2.0 } val upper = space { center + sizes / 2.0 }
val lower = space { center - sizes / 2.0 } val lower = space { center - sizes / 2.0 }
@ -18,21 +21,20 @@ data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val c
} }
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> { public class MultivariateBin<T : Comparable<T>>(public val def: BinDef<T>, public override val value: Number) : Bin<T> {
override operator fun contains(point: Point<T>): Boolean = def.contains(point) public override val dimension: Int
override val dimension: Int
get() = def.center.size get() = def.center.size
override val center: Point<T> public override val center: Point<T>
get() = def.center get() = def.center
public override operator fun contains(point: Point<T>): Boolean = def.contains(point)
} }
/** /**
* Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions. * Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions.
*/ */
class RealHistogram( public class RealHistogram(
private val lower: Buffer<Double>, private val lower: Buffer<Double>,
private val upper: Buffer<Double>, private val upper: Buffer<Double>,
private val binNums: IntArray = IntArray(lower.size) { 20 } private val binNums: IntArray = IntArray(lower.size) { 20 }
@ -40,7 +42,7 @@ class RealHistogram(
private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 }) private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 })
private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() } private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
private val weights: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() } private val weights: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() }
override val dimension: Int get() = lower.size public override val dimension: Int get() = lower.size
private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
init { init {
@ -64,7 +66,7 @@ class RealHistogram(
private fun getValue(index: IntArray): Long = values[index].sum() private fun getValue(index: IntArray): Long = values[index].sum()
fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point)) public fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
private fun getDef(index: IntArray): BinDef<Double> { private fun getDef(index: IntArray): BinDef<Double> {
val center = index.mapIndexed { axis, i -> val center = index.mapIndexed { axis, i ->
@ -78,9 +80,9 @@ class RealHistogram(
return BinDef(RealBufferFieldOperations, center, binSize) return BinDef(RealBufferFieldOperations, center, binSize)
} }
fun getDef(point: Buffer<out Double>): BinDef<Double> = getDef(getIndex(point)) public fun getDef(point: Buffer<out Double>): BinDef<Double> = getDef(getIndex(point))
override operator fun get(point: Buffer<out Double>): MultivariateBin<Double>? { public override operator fun get(point: Buffer<out Double>): MultivariateBin<Double>? {
val index = getIndex(point) val index = getIndex(point)
return MultivariateBin(getDef(index), getValue(index)) return MultivariateBin(getDef(index), getValue(index))
} }
@ -90,27 +92,27 @@ class RealHistogram(
// values[index].increment() // values[index].increment()
// } // }
override fun putWithWeight(point: Buffer<out Double>, weight: Double) { public override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
val index = getIndex(point) val index = getIndex(point)
values[index].increment() values[index].increment()
weights[index].add(weight) weights[index].add(weight)
} }
override operator fun iterator(): Iterator<MultivariateBin<Double>> = weights.elements().map { (index, value) -> public override operator fun iterator(): Iterator<MultivariateBin<Double>> =
MultivariateBin(getDef(index), value.sum()) weights.elements().map { (index, value) -> MultivariateBin(getDef(index), value.sum()) }
}.iterator() .iterator()
/** /**
* Convert this histogram into NDStructure containing bin values but not bin descriptions * Convert this histogram into NDStructure containing bin values but not bin descriptions
*/ */
fun values(): NDStructure<Number> = NDStructure.auto(values.shape) { values[it].sum() } public fun values(): NDStructure<Number> = NDStructure.auto(values.shape) { values[it].sum() }
/** /**
* Sum of weights * Sum of weights
*/ */
fun weights(): NDStructure<Double> = NDStructure.auto(weights.shape) { weights[it].sum() } public fun weights(): NDStructure<Double> = NDStructure.auto(weights.shape) { weights[it].sum() }
companion object { public companion object {
/** /**
* Use it like * Use it like
* ``` * ```
@ -120,9 +122,9 @@ class RealHistogram(
*) *)
*``` *```
*/ */
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): RealHistogram = RealHistogram( public fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): RealHistogram = RealHistogram(
ranges.map { it.start }.asVector(), ranges.map(ClosedFloatingPointRange<Double>::start).asVector(),
ranges.map { it.endInclusive }.asVector() ranges.map(ClosedFloatingPointRange<Double>::endInclusive).asVector()
) )
/** /**
@ -134,10 +136,21 @@ class RealHistogram(
*) *)
*``` *```
*/ */
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): RealHistogram = RealHistogram( public fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): RealHistogram =
ListBuffer(ranges.map { it.first.start }), RealHistogram(
ListBuffer(ranges.map { it.first.endInclusive }), ListBuffer(
ranges.map { it.second }.toIntArray() ranges
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
.map(ClosedFloatingPointRange<Double>::start)
),
ListBuffer(
ranges
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
.map(ClosedFloatingPointRange<Double>::endInclusive)
),
ranges.map(Pair<ClosedFloatingPointRange<Double>, Int>::second).toIntArray()
) )
} }
} }

View File

@ -10,7 +10,7 @@ import kotlin.test.assertEquals
import kotlin.test.assertFalse import kotlin.test.assertFalse
import kotlin.test.assertTrue import kotlin.test.assertTrue
class MultivariateHistogramTest { internal class MultivariateHistogramTest {
@Test @Test
fun testSinglePutHistogram() { fun testSinglePutHistogram() {
val histogram = RealHistogram.fromRanges( val histogram = RealHistogram.fromRanges(

View File

@ -1,33 +1,37 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
actual class LongCounter { public actual class LongCounter {
private var sum: Long = 0 private var sum: Long = 0L
actual fun decrement() {
public actual fun decrement() {
sum-- sum--
} }
actual fun increment() { public actual fun increment() {
sum++ sum++
} }
actual fun reset() { public actual fun reset() {
sum = 0 sum = 0
} }
actual fun sum(): Long = sum public actual fun sum(): Long = sum
actual fun add(l: Long) {
public actual fun add(l: Long) {
sum += l sum += l
} }
} }
actual class DoubleCounter { public actual class DoubleCounter {
private var sum: Double = 0.0 private var sum: Double = 0.0
actual fun reset() {
public actual fun reset() {
sum = 0.0 sum = 0.0
} }
actual fun sum(): Double = sum public actual fun sum(): Double = sum
actual fun add(d: Double) {
public actual fun add(d: Double) {
sum += d sum += d
} }
} }

View File

@ -3,5 +3,5 @@ package scientifik.kmath.histogram
import java.util.concurrent.atomic.DoubleAdder import java.util.concurrent.atomic.DoubleAdder
import java.util.concurrent.atomic.LongAdder import java.util.concurrent.atomic.LongAdder
actual typealias LongCounter = LongAdder public actual typealias LongCounter = LongAdder
actual typealias DoubleCounter = DoubleAdder public actual typealias DoubleCounter = DoubleAdder

View File

@ -8,25 +8,26 @@ import kotlin.math.floor
//TODO move to common //TODO move to common
class UnivariateBin(val position: Double, val size: Double, val counter: LongCounter = LongCounter()) : Bin<Double> { public class UnivariateBin(
public val position: Double,
public val size: Double,
public val counter: LongCounter = LongCounter()
) : Bin<Double> {
//TODO add weighting //TODO add weighting
override val value: Number get() = counter.sum() public override val value: Number get() = counter.sum()
override val center: RealVector get() = doubleArrayOf(position).asVector() public override val center: RealVector get() = doubleArrayOf(position).asVector()
public override val dimension: Int get() = 1
operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2) public operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2)
public override fun contains(point: Buffer<Double>): Boolean = contains(point[0])
override fun contains(point: Buffer<Double>): Boolean = contains(point[0]) internal operator fun inc(): UnivariateBin = this.also { counter.increment() }
internal operator fun inc() = this.also { counter.increment() }
override val dimension: Int get() = 1
} }
/** /**
* Univariate histogram with log(n) bin search speed * Univariate histogram with log(n) bin search speed
*/ */
class UnivariateHistogram private constructor(private val factory: (Double) -> UnivariateBin) : public class UnivariateHistogram private constructor(private val factory: (Double) -> UnivariateBin) :
MutableHistogram<Double, UnivariateBin> { MutableHistogram<Double, UnivariateBin> {
private val bins: TreeMap<Double, UnivariateBin> = TreeMap() private val bins: TreeMap<Double, UnivariateBin> = TreeMap()
@ -46,16 +47,16 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
synchronized(this) { bins.put(it.position, it) } synchronized(this) { bins.put(it.position, it) }
} }
override operator fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0]) public override operator fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0])
override val dimension: Int get() = 1 public override val dimension: Int get() = 1
override operator fun iterator(): Iterator<UnivariateBin> = bins.values.iterator() public override operator fun iterator(): Iterator<UnivariateBin> = bins.values.iterator()
/** /**
* Thread safe put operation * Thread safe put operation
*/ */
fun put(value: Double) { public fun put(value: Double) {
(get(value) ?: createBin(value)).inc() (get(value) ?: createBin(value)).inc()
} }
@ -64,13 +65,13 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
put(point[0]) put(point[0])
} }
companion object { public companion object {
fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram = UnivariateHistogram { value -> public fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram = UnivariateHistogram { value ->
val center = start + binSize * floor((value - start) / binSize + 0.5) val center = start + binSize * floor((value - start) / binSize + 0.5)
UnivariateBin(center, binSize) UnivariateBin(center, binSize)
} }
fun custom(borders: DoubleArray): UnivariateHistogram { public fun custom(borders: DoubleArray): UnivariateHistogram {
val sorted = borders.sortedArray() val sorted = borders.sortedArray()
return UnivariateHistogram { value -> return UnivariateHistogram { value ->
@ -79,10 +80,12 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY,
Double.MAX_VALUE Double.MAX_VALUE
) )
value > sorted.last() -> UnivariateBin( value > sorted.last() -> UnivariateBin(
Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY,
Double.MAX_VALUE Double.MAX_VALUE
) )
else -> { else -> {
val index = (0 until sorted.size).first { value > sorted[it] } val index = (0 until sorted.size).first { value > sorted[it] }
val left = sorted[index] val left = sorted[index]
@ -95,4 +98,4 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
} }
} }
fun UnivariateHistogram.fill(sequence: Iterable<Double>) = sequence.forEach { put(it) } public fun UnivariateHistogram.fill(sequence: Iterable<Double>): Unit = sequence.forEach(::put)

View File

@ -5,19 +5,19 @@ import scientifik.kmath.chains.collect
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.BufferFactory import scientifik.kmath.structures.BufferFactory
interface Sampler<T : Any> { public interface Sampler<T : Any> {
fun sample(generator: RandomGenerator): Chain<T> public fun sample(generator: RandomGenerator): Chain<T>
} }
/** /**
* A distribution of typed objects * A distribution of typed objects
*/ */
interface Distribution<T : Any> : Sampler<T> { public interface Distribution<T : Any> : Sampler<T> {
/** /**
* A probability value for given argument [arg]. * A probability value for given argument [arg].
* For continuous distributions returns PDF * For continuous distributions returns PDF
*/ */
fun probability(arg: T): Double public fun probability(arg: T): Double
/** /**
* Create a chain of samples from this distribution. * Create a chain of samples from this distribution.
@ -28,20 +28,20 @@ interface Distribution<T : Any> : Sampler<T> {
/** /**
* An empty companion. Distribution factories should be written as its extensions * An empty companion. Distribution factories should be written as its extensions
*/ */
companion object public companion object
} }
interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> { public interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
/** /**
* Cumulative distribution for ordered parameter (CDF) * Cumulative distribution for ordered parameter (CDF)
*/ */
fun cumulative(arg: T): Double public fun cumulative(arg: T): Double
} }
/** /**
* Compute probability integral in an interval * Compute probability integral in an interval
*/ */
fun <T : Comparable<T>> UnivariateDistribution<T>.integral(from: T, to: T): Double { public fun <T : Comparable<T>> UnivariateDistribution<T>.integral(from: T, to: T): Double {
require(to > from) require(to > from)
return cumulative(to) - cumulative(from) return cumulative(to) - cumulative(from)
} }
@ -49,7 +49,7 @@ fun <T : Comparable<T>> UnivariateDistribution<T>.integral(from: T, to: T): Doub
/** /**
* Sample a bunch of values * Sample a bunch of values
*/ */
fun <T : Any> Sampler<T>.sampleBuffer( public fun <T : Any> Sampler<T>.sampleBuffer(
generator: RandomGenerator, generator: RandomGenerator,
size: Int, size: Int,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing bufferFactory: BufferFactory<T> = Buffer.Companion::boxing
@ -57,6 +57,7 @@ fun <T : Any> Sampler<T>.sampleBuffer(
require(size > 1) require(size > 1)
//creating temporary storage once //creating temporary storage once
val tmp = ArrayList<T>(size) val tmp = ArrayList<T>(size)
return sample(generator).collect { chain -> return sample(generator).collect { chain ->
//clear list from previous run //clear list from previous run
tmp.clear() tmp.clear()
@ -72,5 +73,5 @@ fun <T : Any> Sampler<T>.sampleBuffer(
/** /**
* Generate a bunch of samples from real distributions * Generate a bunch of samples from real distributions
*/ */
fun Sampler<Double>.sampleBuffer(generator: RandomGenerator, size: Int) = public fun Sampler<Double>.sampleBuffer(generator: RandomGenerator, size: Int) =
sampleBuffer(generator, size, Buffer.Companion::real) sampleBuffer(generator, size, Buffer.Companion::real)

View File

@ -12,33 +12,29 @@ import kotlin.math.pow
import kotlin.math.sqrt import kotlin.math.sqrt
public abstract class ContinuousSamplerDistribution : Distribution<Double> { public abstract class ContinuousSamplerDistribution : Distribution<Double> {
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingRealChain() { private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingRealChain() {
private val sampler = buildCMSampler(generator) private val sampler = buildCMSampler(generator)
override fun nextDouble(): Double = sampler.sample() public override fun nextDouble(): Double = sampler.sample()
public override fun fork(): Chain<Double> = ContinuousSamplerChain(generator.fork())
override fun fork(): Chain<Double> = ContinuousSamplerChain(generator.fork())
} }
protected abstract fun buildCMSampler(generator: RandomGenerator): ContinuousSampler protected abstract fun buildCMSampler(generator: RandomGenerator): ContinuousSampler
override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator) public override fun sample(generator: RandomGenerator): BlockingRealChain = ContinuousSamplerChain(generator)
} }
public abstract class DiscreteSamplerDistribution : Distribution<Int> { public abstract class DiscreteSamplerDistribution : Distribution<Int> {
private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingIntChain() { private inner class ContinuousSamplerChain(val generator: RandomGenerator) : BlockingIntChain() {
private val sampler = buildSampler(generator) private val sampler = buildSampler(generator)
override fun nextInt(): Int = sampler.sample() public override fun nextInt(): Int = sampler.sample()
public override fun fork(): Chain<Int> = ContinuousSamplerChain(generator.fork())
override fun fork(): Chain<Int> = ContinuousSamplerChain(generator.fork())
} }
protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler protected abstract fun buildSampler(generator: RandomGenerator): DiscreteSampler
override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator) public override fun sample(generator: RandomGenerator): BlockingIntChain = ContinuousSamplerChain(generator)
} }
public enum class NormalSamplerMethod { public enum class NormalSamplerMethod {
@ -58,7 +54,7 @@ public fun Distribution.Companion.normal(
method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat method: NormalSamplerMethod = NormalSamplerMethod.Ziggurat
): Distribution<Double> = object : ContinuousSamplerDistribution() { ): Distribution<Double> = object : ContinuousSamplerDistribution() {
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler { override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
val provider: UniformRandomProvider = generator.asUniformRandomProvider() val provider = generator.asUniformRandomProvider()
return normalSampler(method, provider) return normalSampler(method, provider)
} }
@ -76,34 +72,27 @@ public fun Distribution.Companion.normal(
private val norm = sigma * sqrt(PI * 2) private val norm = sigma * sqrt(PI * 2)
override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler { override fun buildCMSampler(generator: RandomGenerator): ContinuousSampler {
val provider: UniformRandomProvider = generator.asUniformRandomProvider() val provider = generator.asUniformRandomProvider()
val normalizedSampler = normalSampler(method, provider) val normalizedSampler = normalSampler(method, provider)
return GaussianSampler(normalizedSampler, mean, sigma) return GaussianSampler(normalizedSampler, mean, sigma)
} }
override fun probability(arg: Double): Double { override fun probability(arg: Double): Double = exp(-(arg - mean).pow(2) / 2 / sigma2) / norm
return exp(-(arg - mean).pow(2) / 2 / sigma2) / norm
}
} }
public fun Distribution.Companion.poisson( public fun Distribution.Companion.poisson(lambda: Double): DiscreteSamplerDistribution =
lambda: Double object : DiscreteSamplerDistribution() {
): DiscreteSamplerDistribution = object : DiscreteSamplerDistribution() { private val computedProb: MutableMap<Int, Double> = hashMapOf(0 to exp(-lambda))
override fun buildSampler(generator: RandomGenerator): DiscreteSampler { override fun buildSampler(generator: RandomGenerator): DiscreteSampler =
return PoissonSampler.of(generator.asUniformRandomProvider(), lambda) PoissonSampler.of(generator.asUniformRandomProvider(), lambda)
}
private val computedProb: HashMap<Int, Double> = hashMapOf(0 to exp(-lambda))
override fun probability(arg: Int): Double { override fun probability(arg: Int): Double {
require(arg >= 0) { "The argument must be >= 0" } require(arg >= 0) { "The argument must be >= 0" }
return if (arg > 40) {
return if (arg > 40)
exp(-(arg - lambda).pow(2) / 2 / lambda) / sqrt(2 * PI * lambda) exp(-(arg - lambda).pow(2) / 2 / lambda) / sqrt(2 * PI * lambda)
} else { else
computedProb.getOrPut(arg) { computedProb.getOrPut(arg) { probability(arg - 1) * lambda / arg }
probability(arg - 1) * lambda / arg
} }
} }
}
}

View File

@ -27,7 +27,6 @@ include(
":kmath-memory", ":kmath-memory",
":kmath-core", ":kmath-core",
":kmath-functions", ":kmath-functions",
// ":kmath-io",
":kmath-coroutines", ":kmath-coroutines",
":kmath-histograms", ":kmath-histograms",
":kmath-commons", ":kmath-commons",
@ -38,6 +37,6 @@ include(
":kmath-dimensions", ":kmath-dimensions",
":kmath-for-real", ":kmath-for-real",
":kmath-geometry", ":kmath-geometry",
":kmath-ast", // ":kmath-ast",
":examples" ":examples"
) )