diff --git a/benchmarks/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt b/benchmarks/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt index a2be2aa9c..e5147b941 100644 --- a/benchmarks/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt +++ b/benchmarks/src/main/kotlin/scientifik/kmath/structures/StructureWriteBenchmark.kt @@ -8,7 +8,7 @@ fun main(args: Array) { val n = 6000 - val structure = ndStructure(intArrayOf(n, n), DoubleBufferFactory) { 1.0 } + val structure = NDStructure.build(intArrayOf(n, n), DoubleBufferFactory) { 1.0 } structure.mapToBuffer { it + 1 } // warm-up diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/expressions/DiffExpression.kt index 79cf92c94..6c1759d54 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/expressions/DiffExpression.kt @@ -2,6 +2,7 @@ package scientifik.kmath.expressions import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import scientifik.kmath.operations.ExtendedField +import scientifik.kmath.operations.ExtendedFieldOperations import scientifik.kmath.operations.Field import kotlin.properties.ReadOnlyProperty import kotlin.reflect.KProperty diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index a7d9ef7f4..3003ce1bc 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -5,7 +5,7 @@ import kotlin.math.* /** * A field for complex numbers */ -object ComplexField : ExtendedField { +object ComplexField : ExtendedFieldOperations, Field { override val zero: Complex = Complex(0.0, 0.0) override val one: Complex = Complex(1.0, 0.0) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index a57688ad8..b80212375 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -5,12 +5,13 @@ import kotlin.math.pow /** * Advanced Number-like field that implements basic operations */ -interface ExtendedField : - Field, +interface ExtendedFieldOperations : + FieldOperations, TrigonometricOperations, PowerOperations, ExponentialOperations +interface ExtendedField : ExtendedFieldOperations, Field /** * Real field element wrapping double. @@ -31,7 +32,7 @@ inline class Real(val value: Double) : FieldElement { * A field for double without boxing. Does not produce appropriate field element */ @Suppress("EXTENSION_SHADOWED_BY_MEMBER") -object RealField : ExtendedField, Norm { +object RealField : Field, ExtendedFieldOperations, Norm { override val zero: Double = 0.0 override fun add(a: Double, b: Double): Double = a + b override fun multiply(a: Double, b: Double): Double = a * b diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index 66ca205a1..ffbbf69dd 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -10,7 +10,7 @@ package scientifik.kmath.operations * It also allows to override behavior for optional operations * */ -interface TrigonometricOperations : Field { +interface TrigonometricOperations : FieldOperations { fun sin(arg: T): T fun cos(arg: T): T diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt index 8256b596d..225c00c82 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt @@ -109,6 +109,9 @@ inline class ListBuffer(val list: List) : Buffer { fun List.asBuffer() = ListBuffer(this) +@Suppress("FunctionName") +inline fun ListBuffer(size: Int, init: (Int) -> T) = List(size, init).asBuffer() + inline class MutableListBuffer(val list: MutableList) : MutableBuffer { override val size: Int @@ -154,9 +157,12 @@ inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer { override fun iterator(): Iterator = array.iterator() override fun copy(): MutableBuffer = DoubleBuffer(array.copyOf()) - } +@Suppress("FunctionName") +inline fun DoubleBuffer(size: Int, init: (Int) -> Double) = DoubleBuffer(DoubleArray(size) { init(it) }) + + /** * Transform buffer of doubles into array for high performance operations */ diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt index 0a8f1de88..370dc646e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt @@ -1,16 +1,14 @@ package scientifik.kmath.structures -import scientifik.kmath.operations.ExponentialOperations -import scientifik.kmath.operations.ExtendedField -import scientifik.kmath.operations.PowerOperations -import scientifik.kmath.operations.TrigonometricOperations +import scientifik.kmath.operations.* -interface ExtendedNDField, N : NDStructure> : +interface ExtendedNDField> : NDField, TrigonometricOperations, PowerOperations, ExponentialOperations + where F : ExtendedFieldOperations, F : Field ///** diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt index 5f16f0444..2cf9af6fb 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt @@ -22,6 +22,33 @@ interface NDStructure { else -> st1.elements().all { (index, value) -> value == st2[index] } } } + + /** + * Create a NDStructure with explicit buffer factory + * + * Strides should be reused if possible + */ + fun build( + strides: Strides, + bufferFactory: BufferFactory = Buffer.Companion::boxing, + initializer: (IntArray) -> T + ) = + BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) + + /** + * Inline create NDStructure with non-boxing buffer implementation if it is possible + */ + inline fun auto(strides: Strides, crossinline initializer: (IntArray) -> T) = + BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) + + fun build( + shape: IntArray, + bufferFactory: BufferFactory = Buffer.Companion::boxing, + initializer: (IntArray) -> T + ) = build(DefaultStrides(shape), bufferFactory, initializer) + + inline fun auto(shape: IntArray, crossinline initializer: (IntArray) -> T) = + auto(DefaultStrides(shape), initializer) } } @@ -200,34 +227,6 @@ inline fun NDStructure.mapToBuffer( } } -/** - * Create a NDStructure with explicit buffer factory - * - * Strides should be reused if possible - */ -fun ndStructure( - strides: Strides, - bufferFactory: BufferFactory = Buffer.Companion::boxing, - initializer: (IntArray) -> T -) = - BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) - -/** - * Inline create NDStructure with non-boxing buffer implementation if it is possible - */ -inline fun inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) = - BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) - -fun ndStructure( - shape: IntArray, - bufferFactory: BufferFactory = Buffer.Companion::boxing, - initializer: (IntArray) -> T -) = - ndStructure(DefaultStrides(shape), bufferFactory, initializer) - -inline fun inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) = - inlineNDStructure(DefaultStrides(shape), initializer) - /** * Mutable ND buffer based on linear [autoBuffer] */ @@ -245,33 +244,10 @@ class MutableBufferNDStructure( override fun set(index: IntArray, value: T) = buffer.set(strides.offset(index), value) } -/** - * The same as [inlineNDStructure], but mutable - */ -fun mutableNdStructure( - strides: Strides, - bufferFactory: MutableBufferFactory = MutableBuffer.Companion::boxing, - initializer: (IntArray) -> T -) = - MutableBufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) - -inline fun inlineMutableNdStructure(strides: Strides, crossinline initializer: (IntArray) -> T) = - MutableBufferNDStructure(strides, MutableBuffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) }) - -fun mutableNdStructure( - shape: IntArray, - bufferFactory: MutableBufferFactory = MutableBuffer.Companion::boxing, - initializer: (IntArray) -> T -) = - mutableNdStructure(DefaultStrides(shape), bufferFactory, initializer) - -inline fun inlineMutableNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) = - inlineMutableNdStructure(DefaultStrides(shape), initializer) - inline fun NDStructure.combine( struct: NDStructure, crossinline block: (T, T) -> T ): NDStructure { if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination") - return inlineNdStructure(shape) { block(this[it], struct[it]) } + return NDStructure.auto(shape) { block(this[it], struct[it]) } } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt index 42e74a27f..b56f42717 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt @@ -1,110 +1,153 @@ package scientifik.kmath.structures -import scientifik.kmath.operations.ExtendedField +import scientifik.kmath.operations.ExtendedFieldOperations +import scientifik.kmath.operations.Field import kotlin.math.* + /** * A simple field over linear buffers of [Double] */ -class RealBufferField(val size: Int) : ExtendedField> { - override val zero: Buffer = Buffer.DoubleBufferFactory(size) { 0.0 } - - override val one: Buffer = Buffer.DoubleBufferFactory(size) { 1.0 } - +object RealBufferFieldOperations : ExtendedFieldOperations> { override fun add(a: Buffer, b: Buffer): DoubleBuffer { - require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " } - require(b.size == size) { "The size of buffer is ${b.size} but context requires $size " } + require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } return if (a is DoubleBuffer && b is DoubleBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(size) { aArray[it] + bArray[it] }) + DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) } else { - DoubleBuffer(DoubleArray(size) { a[it] + b[it] }) + DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] }) } } override fun multiply(a: Buffer, k: Number): DoubleBuffer { - require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " } val kValue = k.toDouble() return if (a is DoubleBuffer) { val aArray = a.array - DoubleBuffer(DoubleArray(size) { aArray[it] * kValue }) + DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue }) } else { - DoubleBuffer(DoubleArray(size) { a[it] * kValue }) + DoubleBuffer(DoubleArray(a.size) { a[it] * kValue }) } } override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { - require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " } - require(b.size == size) { "The size of buffer is ${b.size} but context requires $size " } + require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } return if (a is DoubleBuffer && b is DoubleBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(size) { aArray[it] * bArray[it] }) + DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) } else { - DoubleBuffer(DoubleArray(size) { a[it] * b[it] }) + DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] }) } } override fun divide(a: Buffer, b: Buffer): DoubleBuffer { - require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " } - require(b.size == size) { "The size of buffer is ${b.size} but context requires $size " } + require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } return if (a is DoubleBuffer && b is DoubleBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(size) { aArray[it] / bArray[it] }) + DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) } else { - DoubleBuffer(DoubleArray(size) { a[it] / b[it] }) + DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] }) } } - override fun sin(arg: Buffer): Buffer { - require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " } + override fun sin(arg: Buffer): DoubleBuffer { return if (arg is DoubleBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(size) { sin(array[it]) }) + DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) } else { - DoubleBuffer(DoubleArray(size) { sin(arg[it]) }) + DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) } } - override fun cos(arg: Buffer): Buffer { - require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " } + override fun cos(arg: Buffer): DoubleBuffer { return if (arg is DoubleBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(size) { cos(array[it]) }) + DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) } else { - DoubleBuffer(DoubleArray(size) { cos(arg[it]) }) + DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) } } - override fun power(arg: Buffer, pow: Number): Buffer { - require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " } + override fun power(arg: Buffer, pow: Number): DoubleBuffer { return if (arg is DoubleBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(size) { array[it].pow(pow.toDouble()) }) + DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) } else { - DoubleBuffer(DoubleArray(size) { arg[it].pow(pow.toDouble()) }) + DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) } } - override fun exp(arg: Buffer): Buffer { - require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " } + override fun exp(arg: Buffer): DoubleBuffer { return if (arg is DoubleBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(size) { exp(array[it]) }) + DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) } else { - DoubleBuffer(DoubleArray(size) { exp(arg[it]) }) + DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) } } - override fun ln(arg: Buffer): Buffer { - require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " } + override fun ln(arg: Buffer): DoubleBuffer { return if (arg is DoubleBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(size) { ln(array[it]) }) + DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) } else { - DoubleBuffer(DoubleArray(size) { ln(arg[it]) }) + DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } } +} + +class RealBufferField(val size: Int) : Field>, ExtendedFieldOperations> { + + override val zero: Buffer by lazy { DoubleBuffer(size) { 0.0 } } + + override val one: Buffer by lazy { DoubleBuffer(size) { 1.0 } } + + override fun add(a: Buffer, b: Buffer): DoubleBuffer { + require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } + return RealBufferFieldOperations.add(a, b) + } + + override fun multiply(a: Buffer, k: Number): DoubleBuffer { + require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } + return RealBufferFieldOperations.multiply(a, k) + } + + override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } + return RealBufferFieldOperations.multiply(a, b) + } + + + override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } + return RealBufferFieldOperations.divide(a, b) + } + + override fun sin(arg: Buffer): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.sin(arg) + } + + override fun cos(arg: Buffer): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.cos(arg) + } + + override fun power(arg: Buffer, pow: Number): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.power(arg, pow) + } + + override fun exp(arg: Buffer): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.exp(arg) + } + + override fun ln(arg: Buffer): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.ln(arg) + } + } \ No newline at end of file diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Counters.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Counters.kt index 9a470014a..9c7de3303 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Counters.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Counters.kt @@ -2,9 +2,9 @@ package scientifik.kmath.histogram /* * Common representation for atomic counters + * TODO replace with atomics */ - expect class LongCounter() { fun decrement() fun increment() diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt index c5fc67b5c..329af72a1 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt @@ -2,12 +2,8 @@ package scientifik.kmath.histogram import scientifik.kmath.linear.Point import scientifik.kmath.structures.ArrayBuffer -import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.DoubleBuffer - -typealias RealPoint = Buffer - /** * A simple geometric domain * TODO move to geometry module @@ -42,13 +38,14 @@ interface Histogram> : Iterable { } -interface MutableHistogram> : - Histogram { +interface MutableHistogram> : Histogram { /** * Increment appropriate bin */ - fun put(point: Point, weight: Double = 1.0) + fun putWithWeight(point: Point, weight: Double) + + fun put(point: Point) = putWithWeight(point, 1.0) } fun MutableHistogram.put(vararg point: T) = put(ArrayBuffer(point)) diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/PhantomHistogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/PhantomHistogram.kt deleted file mode 100644 index 6c6614851..000000000 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/PhantomHistogram.kt +++ /dev/null @@ -1,64 +0,0 @@ -package scientifik.kmath.histogram - -import scientifik.kmath.linear.Point -import scientifik.kmath.linear.Vector -import scientifik.kmath.operations.Space -import scientifik.kmath.structures.NDStructure -import scientifik.kmath.structures.asSequence - -data class BinTemplate>(val center: Vector, val sizes: Point) { - fun contains(vector: Point): Boolean { - if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}") - val upper = center.context.run { center + sizes / 2.0 } - val lower = center.context.run { center - sizes / 2.0 } - return vector.asSequence().mapIndexed { i, value -> - value in lower[i]..upper[i] - }.all { it } - } -} - -/** - * A space to perform arithmetic operations on histograms - */ -interface HistogramSpace, H : Histogram> : Space { - /** - * Rules for performing operations on bins - */ - val binSpace: Space> -} - -class PhantomBin>(val template: BinTemplate, override val value: Number) : - Bin { - - override fun contains(vector: Point): Boolean = template.contains(vector) - - override val dimension: Int - get() = template.center.size - - override val center: Point - get() = template.center - -} - -/** - * Immutable histogram with explicit structure for content and additional external bin description. - * Bin search is slow, but full histogram algebra is supported. - * @param bins transform a template into structure index - */ -class PhantomHistogram>( - val bins: Map, IntArray>, - val data: NDStructure -) : Histogram> { - - override val dimension: Int - get() = data.dimension - - override fun iterator(): Iterator> = - bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator() - - override fun get(point: Point): PhantomBin? { - val template = bins.keys.find { it.contains(point) } - return template?.let { PhantomBin(it, data[bins[it]!!]) } - } - -} \ No newline at end of file diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/FastHistogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt similarity index 57% rename from kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/FastHistogram.kt rename to kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt index 869f9a02e..75296ba27 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/FastHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt @@ -1,45 +1,66 @@ package scientifik.kmath.histogram +import scientifik.kmath.linear.Point import scientifik.kmath.linear.toVector +import scientifik.kmath.operations.SpaceOperations import scientifik.kmath.structures.* import kotlin.math.floor -private operator fun RealPoint.minus(other: RealPoint) = ListBuffer((0 until size).map { get(it) - other[it] }) -private inline fun Buffer.mapIndexed(crossinline mapper: (Int, Double) -> T): Sequence = - (0 until size).asSequence().map { mapper(it, get(it)) } +data class BinDef>(val space: SpaceOperations>, val center: Point, val sizes: Point) { + fun contains(vector: Point): Boolean { + if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}") + val upper = space.run { center + sizes / 2.0 } + val lower = space.run { center - sizes / 2.0 } + return vector.asSequence().mapIndexed { i, value -> + value in lower[i]..upper[i] + }.all { it } + } +} + + +class MultivariateBin>(val def: BinDef, override val value: Number) : Bin { + + override fun contains(vector: Point): Boolean = def.contains(vector) + + override val dimension: Int + get() = def.center.size + + override val center: Point + get() = def.center + +} /** * Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions. */ -class FastHistogram( - private val lower: RealPoint, - private val upper: RealPoint, +class RealHistogram( + private val lower: Buffer, + private val upper: Buffer, private val binNums: IntArray = IntArray(lower.size) { 20 } -) : MutableHistogram> { +) : MutableHistogram> { private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 }) - private val values: NDStructure = inlineNDStructure(strides) { LongCounter() } + private val values: NDStructure = NDStructure.auto(strides) { LongCounter() } //private val weight: NDStructure = ndStructure(strides){null} - //TODO optimize binSize performance if needed - private val binSize: RealPoint = - ListBuffer((upper - lower).mapIndexed { index, value -> value / binNums[index] }.toList()) + + override val dimension: Int get() = lower.size + + + private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } init { // argument checks if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.") if (lower.size != binNums.size) error("Dimension mismatch in bin count.") - if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one of axis is not strictly positive") + if ((0 until dimension).any { upper[it] - lower[it] < 0 }) error("Range for one of axis is not strictly positive") } - override val dimension: Int get() = lower.size - - /** * Get internal [NDStructure] bin index for given axis */ @@ -61,49 +82,41 @@ class FastHistogram( return getValue(getIndex(point)) } - private fun getTemplate(index: IntArray): BinTemplate { + private fun getDef(index: IntArray): BinDef { val center = index.mapIndexed { axis, i -> when (i) { 0 -> Double.NEGATIVE_INFINITY strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis] } - }.toVector() - return BinTemplate(center, binSize) + }.asBuffer() + return BinDef(RealBufferFieldOperations, center, binSize) } - fun getTemplate(point: Buffer): BinTemplate { - return getTemplate(getIndex(point)) + fun getDef(point: Buffer): BinDef { + return getDef(getIndex(point)) } - override fun get(point: Buffer): PhantomBin? { + override fun get(point: Buffer): MultivariateBin? { val index = getIndex(point) - return PhantomBin(getTemplate(index), getValue(index)) + return MultivariateBin(getDef(index), getValue(index)) } - override fun put(point: Buffer, weight: Double) { + override fun putWithWeight(point: Buffer, weight: Double) { if (weight != 1.0) TODO("Implement weighting") val index = getIndex(point) values[index].increment() } - override fun iterator(): Iterator> = values.elements().map { (index, value) -> - PhantomBin(getTemplate(index), value.sum()) + override fun iterator(): Iterator> = values.elements().map { (index, value) -> + MultivariateBin(getDef(index), value.sum()) }.iterator() /** * Convert this histogram into NDStructure containing bin values but not bin descriptions */ fun asNDStructure(): NDStructure { - return inlineNdStructure(this.values.shape) { values[it].sum() } - } - - /** - * Create a phantom lightweight immutable copy of this histogram - */ - fun asPhantomHistogram(): PhantomHistogram { - val binTemplates = values.elements().associate { (index, _) -> getTemplate(index) to index } - return PhantomHistogram(binTemplates, asNDStructure()) + return NDStructure.auto(this.values.shape) { values[it].sum() } } companion object { @@ -117,8 +130,8 @@ class FastHistogram( *) *``` */ - fun fromRanges(vararg ranges: ClosedFloatingPointRange): FastHistogram { - return FastHistogram( + fun fromRanges(vararg ranges: ClosedFloatingPointRange): RealHistogram { + return RealHistogram( ranges.map { it.start }.toVector(), ranges.map { it.endInclusive }.toVector() ) @@ -133,8 +146,8 @@ class FastHistogram( *) *``` */ - fun fromRanges(vararg ranges: Pair, Int>): FastHistogram { - return FastHistogram( + fun fromRanges(vararg ranges: Pair, Int>): RealHistogram { + return RealHistogram( ListBuffer(ranges.map { it.first.start }), ListBuffer(ranges.map { it.first.endInclusive }), ranges.map { it.second }.toIntArray() diff --git a/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt b/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt index d59409e76..678b8eba7 100644 --- a/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt +++ b/kmath-histograms/src/commonTest/kotlin/scietifik/kmath/histogram/MultivariateHistogramTest.kt @@ -1,6 +1,6 @@ package scietifik.kmath.histogram -import scientifik.kmath.histogram.FastHistogram +import scientifik.kmath.histogram.RealHistogram import scientifik.kmath.histogram.fill import scientifik.kmath.histogram.put import scientifik.kmath.linear.Vector @@ -13,7 +13,7 @@ import kotlin.test.assertTrue class MultivariateHistogramTest { @Test fun testSinglePutHistogram() { - val histogram = FastHistogram.fromRanges( + val histogram = RealHistogram.fromRanges( (-1.0..1.0), (-1.0..1.0) ) @@ -26,7 +26,7 @@ class MultivariateHistogramTest { @Test fun testSequentialPut() { - val histogram = FastHistogram.fromRanges( + val histogram = RealHistogram.fromRanges( (-1.0..1.0), (-1.0..1.0), (-1.0..1.0) diff --git a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt index f7c4e7fa4..53c2da641 100644 --- a/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt +++ b/kmath-histograms/src/jvmMain/kotlin/scientifik/kmath/histogram/UnivariateHistogram.kt @@ -59,7 +59,7 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U (get(value) ?: createBin(value)).inc() } - override fun put(point: Buffer, weight: Double) { + override fun putWithWeight(point: Buffer, weight: Double) { if (weight != 1.0) TODO("Implement weighting") put(point[0]) }