Histogram refactor

This commit is contained in:
Alexander Nozik 2021-01-30 11:24:34 +03:00
parent e1ca02dced
commit 5f6c133550
3 changed files with 117 additions and 55 deletions

View File

@ -37,6 +37,7 @@
- Features moved to NDStructure and became transparent. - Features moved to NDStructure and became transparent.
- Capitalization of LUP in many names changed to Lup. - Capitalization of LUP in many names changed to Lup.
- Refactored `NDStructure` algebra to be more simple, preferring under-the-hood conversion to explicit NDStructure types - Refactored `NDStructure` algebra to be more simple, preferring under-the-hood conversion to explicit NDStructure types
- Refactored histograms. They are now marked as a prototype.
### Deprecated ### Deprecated

View File

@ -8,10 +8,10 @@ import kscience.kmath.operations.invoke
import kscience.kmath.structures.* import kscience.kmath.structures.*
import kotlin.math.floor import kotlin.math.floor
public data class BinDef<T : Comparable<T>>( public data class BinDefinition<T : Comparable<T>>(
public val space: SpaceOperations<Point<T>>, public val space: SpaceOperations<Point<T>>,
public val center: Point<T>, public val center: Point<T>,
public val sizes: Point<T> public val sizes: Point<T>,
) { ) {
public fun contains(vector: Point<out T>): Boolean { public fun contains(vector: Point<out T>): Boolean {
require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" } require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" }
@ -22,14 +22,17 @@ public data class BinDef<T : Comparable<T>>(
} }
public class MultivariateBin<T : Comparable<T>>(public val def: BinDef<T>, public override val value: Number) : Bin<T> { public class MultivariateBin<T : Comparable<T>>(
public val definition: BinDefinition<T>,
public override val value: Number,
) : Bin<T> {
public override val dimension: Int public override val dimension: Int
get() = def.center.size get() = definition.center.size
public override val center: Point<T> public override val center: Point<T>
get() = def.center get() = definition.center
public override operator fun contains(point: Point<T>): Boolean = def.contains(point) public override operator fun contains(point: Point<T>): Boolean = definition.contains(point)
} }
/** /**
@ -38,11 +41,11 @@ public class MultivariateBin<T : Comparable<T>>(public val def: BinDef<T>, publi
public class RealHistogram( public class RealHistogram(
private val lower: Buffer<Double>, private val lower: Buffer<Double>,
private val upper: Buffer<Double>, private val upper: Buffer<Double>,
private val binNums: IntArray = IntArray(lower.size) { 20 } private val binNums: IntArray = IntArray(lower.size) { 20 },
) : MutableHistogram<Double, MultivariateBin<Double>> { ) : MutableHistogram<Double, MultivariateBin<Double>> {
private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 }) private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 })
private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() } private val counts: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
private val weights: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() } private val values: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() }
public override val dimension: Int get() = lower.size public override val dimension: Int get() = lower.size
private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } private val binSize = RealBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
@ -65,11 +68,11 @@ public class RealHistogram(
private fun getIndex(point: Buffer<out Double>): IntArray = IntArray(dimension) { getIndex(it, point[it]) } private fun getIndex(point: Buffer<out Double>): IntArray = IntArray(dimension) { getIndex(it, point[it]) }
private fun getValue(index: IntArray): Long = values[index].sum() private fun getValue(index: IntArray): Long = counts[index].sum()
public fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point)) public fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
private fun getDef(index: IntArray): BinDef<Double> { private fun getBinDefinition(index: IntArray): BinDefinition<Double> {
val center = index.mapIndexed { axis, i -> val center = index.mapIndexed { axis, i ->
when (i) { when (i) {
0 -> Double.NEGATIVE_INFINITY 0 -> Double.NEGATIVE_INFINITY
@ -78,14 +81,14 @@ public class RealHistogram(
} }
}.asBuffer() }.asBuffer()
return BinDef(RealBufferFieldOperations, center, binSize) return BinDefinition(RealBufferFieldOperations, center, binSize)
} }
public fun getDef(point: Buffer<out Double>): BinDef<Double> = getDef(getIndex(point)) public fun getBinDefinition(point: Buffer<out Double>): BinDefinition<Double> = getBinDefinition(getIndex(point))
public override operator fun get(point: Buffer<out Double>): MultivariateBin<Double>? { public override operator fun get(point: Buffer<out Double>): MultivariateBin<Double>? {
val index = getIndex(point) val index = getIndex(point)
return MultivariateBin(getDef(index), getValue(index)) return MultivariateBin(getBinDefinition(index), getValue(index))
} }
// fun put(point: Point<out Double>){ // fun put(point: Point<out Double>){
@ -95,23 +98,24 @@ public class RealHistogram(
public override fun putWithWeight(point: Buffer<out Double>, weight: Double) { public override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
val index = getIndex(point) val index = getIndex(point)
values[index].increment() counts[index].increment()
weights[index].add(weight) values[index].add(weight)
} }
public override operator fun iterator(): Iterator<MultivariateBin<Double>> = public override operator fun iterator(): Iterator<MultivariateBin<Double>> =
weights.elements().map { (index, value) -> MultivariateBin(getDef(index), value.sum()) } values.elements().map { (index, value) ->
.iterator() MultivariateBin(getBinDefinition(index), value.sum())
}.iterator()
/** /**
* Convert this histogram into NDStructure containing bin values but not bin descriptions * NDStructure containing number of events in bins without weights
*/ */
public fun values(): NDStructure<Number> = NDStructure.auto(values.shape) { values[it].sum() } public fun counts(): NDStructure<Long> = NDStructure.auto(counts.shape) { counts[it].sum() }
/** /**
* Sum of weights * NDStructure containing values of bins including weights
*/ */
public fun weights(): NDStructure<Double> = NDStructure.auto(weights.shape) { weights[it].sum() } public fun values(): NDStructure<Double> = NDStructure.auto(values.shape) { values[it].sum() }
public companion object { public companion object {
/** /**

View File

@ -1,8 +1,10 @@
package kscience.kmath.histogram package kscience.kmath.histogram
import kscience.kmath.linear.Point import kscience.kmath.linear.Point
import kscience.kmath.misc.UnstableKMathAPI
import kscience.kmath.structures.Buffer import kscience.kmath.structures.Buffer
import kscience.kmath.structures.asBuffer import kscience.kmath.structures.asBuffer
import kscience.kmath.structures.asSequence
import java.util.* import java.util.*
import kotlin.math.floor import kotlin.math.floor
@ -11,29 +13,36 @@ import kotlin.math.floor
public class UnivariateBin( public class UnivariateBin(
public val position: Double, public val position: Double,
public val size: Double, public val size: Double,
public val counter: LongCounter = LongCounter(),
) : Bin<Double> { ) : Bin<Double> {
//TODO add weighting //internal mutation operations
public override val value: Number get() = counter.sum() internal val counter: LongCounter = LongCounter()
internal val weightCounter: DoubleCounter = DoubleCounter()
/**
* The precise number of events ignoring weighting
*/
public val count: Long get() = counter.sum()
/**
* The value of histogram including weighting
*/
public override val value: Double get() = weightCounter.sum()
public override val center: Point<Double> get() = doubleArrayOf(position).asBuffer() public override val center: Point<Double> get() = doubleArrayOf(position).asBuffer()
public override val dimension: Int get() = 1 public override val dimension: Int get() = 1
public operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2) public operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2)
public override fun contains(point: Buffer<Double>): Boolean = contains(point[0]) public override fun contains(point: Buffer<Double>): Boolean = contains(point[0])
internal operator fun inc(): UnivariateBin = this.also { counter.increment() }
} }
/** /**
* Univariate histogram with log(n) bin search speed * Univariate histogram with log(n) bin search speed
*/ */
public class UnivariateHistogram private constructor( public abstract class UnivariateHistogram(
private val factory: (Double) -> UnivariateBin, protected val bins: TreeMap<Double, UnivariateBin> = TreeMap(),
) : MutableHistogram<Double, UnivariateBin> { ) : Histogram<Double, UnivariateBin> {
private val bins: TreeMap<Double, UnivariateBin> = TreeMap() public operator fun get(value: Double): UnivariateBin? {
private operator fun get(value: Double): UnivariateBin? {
// check ceiling entry and return it if it is what needed // check ceiling entry and return it if it is what needed
val ceil = bins.ceilingEntry(value)?.value val ceil = bins.ceilingEntry(value)?.value
if (ceil != null && value in ceil) return ceil if (ceil != null && value in ceil) return ceil
@ -44,38 +53,38 @@ public class UnivariateHistogram private constructor(
return null return null
} }
private fun createBin(value: Double): UnivariateBin = factory(value).also {
synchronized(this) { bins[it.position] = it }
}
public override operator fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0]) public override operator fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0])
public override val dimension: Int get() = 1 public override val dimension: Int get() = 1
public override operator fun iterator(): Iterator<UnivariateBin> = bins.values.iterator() public override operator fun iterator(): Iterator<UnivariateBin> = bins.values.iterator()
/**
* Thread safe put operation
*/
public fun put(value: Double) {
(get(value) ?: createBin(value)).inc()
}
override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
if (weight != 1.0) TODO("Implement weighting")
put(point[0])
}
public companion object { public companion object {
public fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram = UnivariateHistogram { value -> /**
* Create a builder for a histogram with uniform binning that starts at [start] and uses a bin size of [binSize]
*/
public fun uniformBuilder(binSize: Double, start: Double = 0.0): UnivariateHistogramBuilder =
UnivariateHistogramBuilder { value ->
val center = start + binSize * floor((value - start) / binSize + 0.5) val center = start + binSize * floor((value - start) / binSize + 0.5)
UnivariateBin(center, binSize) UnivariateBin(center, binSize)
} }
public fun custom(borders: DoubleArray): UnivariateHistogram { /**
* Build and fill a [UnivariateHistogram]. Returns a read-only histogram.
*/
public fun uniform(
    binSize: Double,
    start: Double = 0.0,
    builder: UnivariateHistogramBuilder.() -> Unit,
): UnivariateHistogram {
    // Start from an empty uniform builder, let the caller's block populate it,
    // then hand the very same object back typed as UnivariateHistogram.
    val histogram = uniformBuilder(binSize, start)
    histogram.builder()
    return histogram
}
/**
* Create a builder for a histogram with custom bin borders given by [borders]
*/
public fun customBuilder(borders: DoubleArray): UnivariateHistogramBuilder {
val sorted = borders.sortedArray() val sorted = borders.sortedArray()
return UnivariateHistogram { value -> return UnivariateHistogramBuilder { value ->
when { when {
value < sorted.first() -> UnivariateBin( value < sorted.first() -> UnivariateBin(
Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY,
@ -96,7 +105,55 @@ public class UnivariateHistogram private constructor(
} }
} }
} }
/**
 * Create a histogram with custom [borders] and populate it by running [builder] on it.
 *
 * The populated builder is returned typed as [UnivariateHistogram].
 */
public fun custom(
    borders: DoubleArray,
    builder: UnivariateHistogramBuilder.() -> Unit,
): UnivariateHistogram {
    val histogram = customBuilder(borders)
    histogram.builder()
    return histogram
}
} }
} }
public fun UnivariateHistogram.fill(sequence: Iterable<Double>): Unit = sequence.forEach(::put) public class UnivariateHistogramBuilder(
private val factory: (Double) -> UnivariateBin,
) : UnivariateHistogram(), MutableHistogram<Double, UnivariateBin> {
/**
 * Ask the bin factory for the bin that should hold [value] and register it in the bin map
 * under its position. The map write is done while holding the lock on this builder.
 */
private fun createBin(value: Double): UnivariateBin {
    val bin = factory(value)
    synchronized(this) { bins[bin.position] = bin }
    return bin
}
/**
 * Register a single event at [value], creating the bin that contains it on first use.
 *
 * The event count of the bin is incremented by one and [weight] is added to its weighted value.
 * Bin lookup and bin creation are performed under one lock on this builder, so concurrent calls
 * that miss the same bin cannot each create their own bin and overwrite one another in the map.
 *
 * NOTE(review): the counter updates run outside the lock and rely on the counter types being safe
 * for concurrent updates; reads of the bin map through other members are not synchronized here —
 * confirm both against the rest of the module.
 *
 * @param value the coordinate of the event.
 * @param weight the weight of the event, `1.0` by default.
 */
public fun put(value: Double, weight: Double = 1.0) {
    // The lookup must share the lock with creation: guarding only the map write leaves a
    // check-then-act race where two threads both see "no bin" and one bin's events are lost.
    val bin = synchronized(this) { get(value) ?: createBin(value) }
    bin.counter.increment()
    bin.weightCounter.add(weight)
}
/**
 * Put a weighted event described by a buffer; only the first element of [point] is used
 * as the coordinate, and the call is forwarded to [put].
 */
override fun putWithWeight(point: Buffer<out Double>, weight: Double): Unit =
    put(value = point[0], weight = weight)
/**
 * Put several items into a single bin.
 *
 * Adds [count] to the event count of the bin that contains [value] and adds [weight] once to its
 * weighted value, creating the bin on first use. Bin lookup and bin creation are performed under
 * one lock on this builder, so concurrent calls that miss the same bin cannot each create their
 * own bin and overwrite one another in the map.
 *
 * NOTE(review): the counter updates run outside the lock and rely on the counter types being safe
 * for concurrent updates — confirm against their declarations.
 *
 * @param value the coordinate shared by all added events.
 * @param count the number of events to add.
 * @param weight the total weight added for all [count] events together; defaults to [count].
 */
public fun putMany(value: Double, count: Int, weight: Double = count.toDouble()) {
    // Same atomic lookup-or-create as a single put, to avoid the check-then-act race on the bin map.
    val bin = synchronized(this) { get(value) ?: createBin(value) }
    bin.counter.add(count.toLong())
    bin.weightCounter.add(weight)
}
}
/**
 * Put every element of [items] into this builder as a separate event with the default weight.
 */
@UnstableKMathAPI
public fun UnivariateHistogramBuilder.fill(items: Iterable<Double>) {
    for (item in items) put(item)
}
/**
 * Put every element of [array] into this builder as a separate event with the default weight.
 */
@UnstableKMathAPI
public fun UnivariateHistogramBuilder.fill(array: DoubleArray) {
    for (item in array) put(item)
}
/**
 * Put every element of [buffer] into this builder as a separate event with the default weight.
 * The buffer is traversed through its sequence view.
 */
@UnstableKMathAPI
public fun UnivariateHistogramBuilder.fill(buffer: Buffer<Double>) {
    for (item in buffer.asSequence()) put(item)
}