WIP histograms refactor

This commit is contained in:
Alexander Nozik 2021-02-11 11:18:16 +03:00
parent a384b323c3
commit 2fdba53911
10 changed files with 298 additions and 182 deletions

View File

@ -1,4 +1,10 @@
plugins { id("ru.mipt.npm.mpp") } plugins {
id("ru.mipt.npm.mpp")
}
kscience {
useAtomic()
}
kotlin.sourceSets { kotlin.sourceSets {
commonMain { commonMain {

View File

@ -1,20 +1,47 @@
package kscience.kmath.histogram package kscience.kmath.histogram
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate
import kscience.kmath.operations.Space
/* /*
* Common representation for atomic counters * Common representation for atomic counters
* TODO replace with atomics * TODO replace with atomics
*/ */
public expect class LongCounter() { public interface Counter<T : Any> {
public fun decrement() public fun add(delta: T)
public fun increment() public val value: T
public fun reset()
public fun sum(): Long
public fun add(l: Long)
} }
public expect class DoubleCounter() { public class IntCounter : Counter<Int> {
public fun reset() private val innerValue = atomic(0)
public fun sum(): Double
public fun add(d: Double) override fun add(delta: Int) {
innerValue += delta
} }
override val value: Int get() = innerValue.value
}
public class LongCounter : Counter<Long> {
private val innerValue = atomic(0L)
override fun add(delta: Long) {
innerValue += delta
}
override val value: Long get() = innerValue.value
}
public class ObjectCounter<T : Any>(public val space: Space<T>) : Counter<T> {
private val innerValue = atomic(space.zero)
override fun add(delta: T) {
innerValue.getAndUpdate { space.run { it + delta } }
}
override val value: T get() = innerValue.value
}

View File

@ -6,7 +6,7 @@ import kscience.kmath.structures.ArrayBuffer
import kscience.kmath.structures.RealBuffer import kscience.kmath.structures.RealBuffer
/** /**
* The bin in the histogram. The histogram is by definition always done in the real space * The binned data element. Could be a histogram bin with a number of counts or an artificial construct
*/ */
public interface Bin<T : Any> : Domain<T> { public interface Bin<T : Any> : Domain<T> {
/** /**
@ -17,7 +17,7 @@ public interface Bin<T : Any> : Domain<T> {
public val center: Point<T> public val center: Point<T>
} }
public interface Histogram<T : Any, out B : Bin<T>> : Iterable<B> { public interface Histogram<T : Any, out B : Bin<T>> {
/** /**
* Find existing bin, corresponding to given coordinates * Find existing bin, corresponding to given coordinates
*/ */
@ -27,9 +27,11 @@ public interface Histogram<T : Any, out B : Bin<T>> : Iterable<B> {
* Dimension of the histogram * Dimension of the histogram
*/ */
public val dimension: Int public val dimension: Int
public val bins: Collection<B>
} }
public interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> { public interface HistogramBuilder<T : Any, out B : Bin<T>> : Histogram<T, B> {
/** /**
* Increment appropriate bin * Increment appropriate bin
@ -39,16 +41,16 @@ public interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
public fun put(point: Point<out T>): Unit = putWithWeight(point, 1.0) public fun put(point: Point<out T>): Unit = putWithWeight(point, 1.0)
} }
public fun <T : Any> MutableHistogram<T, *>.put(vararg point: T): Unit = put(ArrayBuffer(point)) public fun <T : Any> HistogramBuilder<T, *>.put(vararg point: T): Unit = put(ArrayBuffer(point))
public fun MutableHistogram<Double, *>.put(vararg point: Number): Unit = public fun HistogramBuilder<Double, *>.put(vararg point: Number): Unit =
put(RealBuffer(point.map { it.toDouble() }.toDoubleArray())) put(RealBuffer(point.map { it.toDouble() }.toDoubleArray()))
public fun MutableHistogram<Double, *>.put(vararg point: Double): Unit = put(RealBuffer(point)) public fun HistogramBuilder<Double, *>.put(vararg point: Double): Unit = put(RealBuffer(point))
public fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>): Unit = sequence.forEach { put(it) } public fun <T : Any> HistogramBuilder<T, *>.fill(sequence: Iterable<Point<T>>): Unit = sequence.forEach { put(it) }
/** /**
* Pass a sequence builder into histogram * Pass a sequence builder into histogram
*/ */
public fun <T : Any> MutableHistogram<T, *>.fill(block: suspend SequenceScope<Point<T>>.() -> Unit): Unit = public fun <T : Any> HistogramBuilder<T, *>.fill(block: suspend SequenceScope<Point<T>>.() -> Unit): Unit =
fill(sequence(block).asIterable()) fill(sequence(block).asIterable())

View File

@ -43,7 +43,7 @@ public class RealHistogram(
private val lower: Buffer<Double>, private val lower: Buffer<Double>,
private val upper: Buffer<Double>, private val upper: Buffer<Double>,
private val binNums: IntArray = IntArray(lower.size) { 20 }, private val binNums: IntArray = IntArray(lower.size) { 20 },
) : MutableHistogram<Double, MultivariateBin<Double>> { ) : HistogramBuilder<Double, MultivariateBin<Double>> {
private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 }) private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 })
private val counts: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() } private val counts: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
private val values: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() } private val values: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() }
@ -88,7 +88,8 @@ public class RealHistogram(
return MultivariateBinDefinition(RealBufferFieldOperations, center, binSize) return MultivariateBinDefinition(RealBufferFieldOperations, center, binSize)
} }
public fun getBinDefinition(point: Buffer<out Double>): MultivariateBinDefinition<Double> = getBinDefinition(getIndex(point)) public fun getBinDefinition(point: Buffer<out Double>): MultivariateBinDefinition<Double> =
getBinDefinition(getIndex(point))
public override operator fun get(point: Buffer<out Double>): MultivariateBin<Double>? { public override operator fun get(point: Buffer<out Double>): MultivariateBin<Double>? {
val index = getIndex(point) val index = getIndex(point)
@ -106,10 +107,10 @@ public class RealHistogram(
values[index].add(weight) values[index].add(weight)
} }
public override operator fun iterator(): Iterator<MultivariateBin<Double>> = override val bins: Collection<MultivariateBin<Double>>
strides.indices().map { index-> get() = strides.indices().map { index ->
MultivariateBin(getBinDefinition(index), counts[index].sum(), values[index].sum()) MultivariateBin(getBinDefinition(index), counts[index].sum(), values[index].sum())
}.iterator() }.toList()
/** /**
* NDStructure containing number of events in bins without weights * NDStructure containing number of events in bins without weights

View File

@ -16,7 +16,7 @@ internal class MultivariateHistogramTest {
(-1.0..1.0) (-1.0..1.0)
) )
histogram.put(0.55, 0.55) histogram.put(0.55, 0.55)
val bin = histogram.find { it.value.toInt() > 0 } ?: fail() val bin = histogram.bins.find { it.value.toInt() > 0 } ?: fail()
assertTrue { bin.contains(RealVector(0.55, 0.55)) } assertTrue { bin.contains(RealVector(0.55, 0.55)) }
assertTrue { bin.contains(RealVector(0.6, 0.5)) } assertTrue { bin.contains(RealVector(0.6, 0.5)) }
assertFalse { bin.contains(RealVector(-0.55, 0.55)) } assertFalse { bin.contains(RealVector(-0.55, 0.55)) }
@ -40,6 +40,6 @@ internal class MultivariateHistogramTest {
yield(RealVector(nextDouble(), nextDouble(), nextDouble())) yield(RealVector(nextDouble(), nextDouble(), nextDouble()))
} }
} }
assertEquals(n, histogram.sumBy { it.value.toInt() }) assertEquals(n, histogram.bins.sumBy { it.value.toInt() })
} }
} }

View File

@ -1,37 +0,0 @@
package kscience.kmath.histogram
public actual class LongCounter {
private var sum: Long = 0L
public actual fun decrement() {
sum--
}
public actual fun increment() {
sum++
}
public actual fun reset() {
sum = 0
}
public actual fun sum(): Long = sum
public actual fun add(l: Long) {
sum += l
}
}
public actual class DoubleCounter {
private var sum: Double = 0.0
public actual fun reset() {
sum = 0.0
}
public actual fun sum(): Double = sum
public actual fun add(d: Double) {
sum += d
}
}

View File

@ -0,0 +1,33 @@
package kscience.kmath.histogram
/**
* Univariate histogram with log(n) bin search speed
*/
//private abstract class AbstractUnivariateHistogram<B: UnivariateBin>{
//
// public abstract val bins: TreeMap<Double, B>
//
// public open operator fun get(value: Double): B? {
// // check ceiling entry and return it if it is what needed
// val ceil = bins.ceilingEntry(value)?.value
// if (ceil != null && value in ceil) return ceil
// //check floor entry
// val floor = bins.floorEntry(value)?.value
// if (floor != null && value in floor) return floor
// //neither is valid, not found
// return null
// }
// public override operator fun get(point: Buffer<out Double>): B? = get(point[0])
//
// public override val dimension: Int get() = 1
//
// public override operator fun iterator(): Iterator<B> = bins.values.iterator()
//
// public companion object {
// }
//}

View File

@ -1,7 +0,0 @@
package kscience.kmath.histogram
import java.util.concurrent.atomic.DoubleAdder
import java.util.concurrent.atomic.LongAdder
public actual typealias LongCounter = LongAdder
public actual typealias DoubleCounter = DoubleAdder

View File

@ -6,71 +6,47 @@ import kscience.kmath.operations.SpaceElement
import kscience.kmath.structures.Buffer import kscience.kmath.structures.Buffer
import kscience.kmath.structures.asBuffer import kscience.kmath.structures.asBuffer
import kscience.kmath.structures.asSequence import kscience.kmath.structures.asSequence
import java.util.*
import kotlin.math.floor
//TODO move to common public data class UnivariateHistogramBinDefinition(
val position: Double,
val size: Double,
) : Comparable<UnivariateHistogramBinDefinition> {
override fun compareTo(other: UnivariateHistogramBinDefinition): Int = this.position.compareTo(other.position)
}
public class UnivariateBin( public interface UnivariateBin : Bin<Double> {
public val position: Double, public val def: UnivariateHistogramBinDefinition
public val size: Double,
) : Bin<Double> {
//internal mutation operations
internal val counter: LongCounter = LongCounter()
internal val weightCounter: DoubleCounter = DoubleCounter()
/** public val position: Double get() = def.position
* The precise number of events ignoring weighting public val size: Double get() = def.size
*/
public val count: Long get() = counter.sum()
/** /**
* The value of histogram including weighting * The value of histogram including weighting
*/ */
public override val value: Double get() = weightCounter.sum() public override val value: Double
/**
* Standard deviation of the bin value. Zero if not applicable
*/
public val standardDeviation: Double
public override val center: Point<Double> get() = doubleArrayOf(position).asBuffer() public override val center: Point<Double> get() = doubleArrayOf(position).asBuffer()
public override val dimension: Int get() = 1 public override val dimension: Int get() = 1
public operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2)
public override fun contains(point: Buffer<Double>): Boolean = contains(point[0]) public override fun contains(point: Buffer<Double>): Boolean = contains(point[0])
} }
/** public operator fun UnivariateBin.contains(value: Double): Boolean =
* Univariate histogram with log(n) bin search speed value in (position - size / 2)..(position + size / 2)
*/
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public abstract class UnivariateHistogram protected constructor( public interface UnivariateHistogram : Histogram<Double, UnivariateBin>,
protected val bins: TreeMap<Double, UnivariateBin> = TreeMap(), SpaceElement<UnivariateHistogram, UnivariateHistogramSpace> {
) : Histogram<Double, UnivariateBin>, SpaceElement<UnivariateHistogram, UnivariateHistogramSpace> { public operator fun get(value: Double): UnivariateBin?
public operator fun get(value: Double): UnivariateBin? {
// check ceiling entry and return it if it is what needed
val ceil = bins.ceilingEntry(value)?.value
if (ceil != null && value in ceil) return ceil
//check floor entry
val floor = bins.floorEntry(value)?.value
if (floor != null && value in floor) return floor
//neither is valid, not found
return null
}
public override operator fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0]) public override operator fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0])
public override val dimension: Int get() = 1
public override operator fun iterator(): Iterator<UnivariateBin> = bins.values.iterator()
public companion object { public companion object {
/**
* Build a histogram with a uniform binning with a start at [start] and a bin size of [binSize]
*/
public fun uniformBuilder(binSize: Double, start: Double = 0.0): UnivariateHistogramBuilder =
UnivariateHistogramSpace { value ->
val center = start + binSize * floor((value - start) / binSize + 0.5)
UnivariateBin(center, binSize)
}.builder()
/** /**
* Build and fill a [UnivariateHistogram]. Returns a read-only histogram. * Build and fill a [UnivariateHistogram]. Returns a read-only histogram.
*/ */
@ -78,35 +54,7 @@ public abstract class UnivariateHistogram protected constructor(
binSize: Double, binSize: Double,
start: Double = 0.0, start: Double = 0.0,
builder: UnivariateHistogramBuilder.() -> Unit, builder: UnivariateHistogramBuilder.() -> Unit,
): UnivariateHistogram = uniformBuilder(binSize, start).apply(builder) ): UnivariateHistogram = UnivariateHistogramSpace.uniform(binSize, start).produce(builder)
/**
* Create a histogram with custom cell borders
*/
public fun customBuilder(borders: DoubleArray): UnivariateHistogramBuilder {
val sorted = borders.sortedArray()
return UnivariateHistogramSpace { value ->
when {
value < sorted.first() -> UnivariateBin(
Double.NEGATIVE_INFINITY,
Double.MAX_VALUE
)
value > sorted.last() -> UnivariateBin(
Double.POSITIVE_INFINITY,
Double.MAX_VALUE
)
else -> {
val index = sorted.indices.first { value > sorted[it] }
val left = sorted[index]
val right = sorted[index + 1]
UnivariateBin((left + right) / 2, (right - left))
}
}
}.builder()
}
/** /**
* Build and fill a histogram with custom borders. Returns a read-only histogram. * Build and fill a histogram with custom borders. Returns a read-only histogram.
@ -114,41 +62,24 @@ public abstract class UnivariateHistogram protected constructor(
public fun custom( public fun custom(
borders: DoubleArray, borders: DoubleArray,
builder: UnivariateHistogramBuilder.() -> Unit, builder: UnivariateHistogramBuilder.() -> Unit,
): UnivariateHistogram = customBuilder(borders).apply(builder) ): UnivariateHistogram = UnivariateHistogramSpace.custom(borders).produce(builder)
} }
} }
public class UnivariateHistogramBuilder internal constructor( public interface UnivariateHistogramBuilder {
override val context: UnivariateHistogramSpace,
) : UnivariateHistogram(), MutableHistogram<Double, UnivariateBin> {
private fun createBin(value: Double): UnivariateBin = context.binFactory(value).also {
synchronized(this) { bins[it.position] = it }
}
/** /**
* Thread safe put operation * Thread safe put operation
*/ */
public fun put(value: Double, weight: Double = 1.0) { public fun put(value: Double, weight: Double = 1.0)
(get(value) ?: createBin(value)).apply { public fun putWithWeight(point: Buffer<out Double>, weight: Double)
counter.increment()
weightCounter.add(weight)
}
}
override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
put(point[0], weight)
}
/** /**
* Put several items into a single bin * Put several items into a single bin
*/ */
public fun putMany(value: Double, count: Int, weight: Double = count.toDouble()) { public fun putMany(value: Double, count: Int, weight: Double = count.toDouble())
(get(value) ?: createBin(value)).apply {
counter.add(count.toLong()) public fun build(): UnivariateHistogram
weightCounter.add(weight)
}
}
} }
@UnstableKMathAPI @UnstableKMathAPI

View File

@ -1,12 +1,109 @@
package kscience.kmath.histogram package kscience.kmath.histogram
import kscience.kmath.operations.Space import kscience.kmath.operations.Space
import kscience.kmath.structures.Buffer
import java.util.*
import kotlin.math.abs
import kotlin.math.sqrt
public class UnivariateHistogramSpace(public val binFactory: (Double) -> UnivariateBin) : Space<UnivariateHistogram> { private fun <B : UnivariateBin> TreeMap<Double, B>.getBin(value: Double): B? {
// check ceiling entry and return it if it is what needed
val ceil = ceilingEntry(value)?.value
if (ceil != null && value in ceil) return ceil
//check floor entry
val floor = floorEntry(value)?.value
if (floor != null && value in floor) return floor
//neither is valid, not found
return null
}
public fun builder(): UnivariateHistogramBuilder = UnivariateHistogramBuilder(this)
public fun produce(builder: UnivariateHistogramBuilder.() -> Unit): UnivariateHistogram = builder().apply(builder) private class UnivariateHistogramImpl(
override val context: UnivariateHistogramSpace,
val binMap: TreeMap<Double, out UnivariateBin>,
) : UnivariateHistogram {
override fun get(value: Double): UnivariateBin? = binMap.getBin(value)
override val dimension: Int get() = 1
override val bins: Collection<UnivariateBin> get() = binMap.values
}
private class UnivariateBinCounter(
override val def: UnivariateHistogramBinDefinition,
) : UnivariateBin {
val counter: LongCounter = LongCounter()
val valueCounter: DoubleCounter = DoubleCounter()
/**
* The precise number of events ignoring weighting
*/
val count: Long get() = counter.sum()
override val standardDeviation: Double get() = sqrt(count.toDouble()) / count * value
/**
* The value of histogram including weighting
*/
override val value: Double get() = valueCounter.sum()
public fun increment(count: Long, value: Double) {
counter.add(count)
valueCounter.add(value)
}
}
private class UnivariateBinValue(
override val def: UnivariateHistogramBinDefinition,
override val value: Double,
override val standardDeviation: Double,
) : UnivariateBin
public class UnivariateHistogramSpace(
public val binFactory: (Double) -> UnivariateHistogramBinDefinition,
) : Space<UnivariateHistogram> {
private inner class UnivariateHistogramBuilderImpl : UnivariateHistogramBuilder {
val bins: TreeMap<Double, UnivariateBinCounter> = TreeMap()
fun get(value: Double): UnivariateBinCounter? = bins.getBin(value)
private fun createBin(value: Double): UnivariateBinCounter {
val binDefinition = binFactory(value)
val newBin = UnivariateBinCounter(binDefinition)
synchronized(this) { bins[binDefinition.position] = newBin }
return newBin
}
/**
* Thread safe put operation
*/
override fun put(value: Double, weight: Double) {
(get(value) ?: createBin(value)).apply {
increment(1, weight)
}
}
override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
put(point[0], weight)
}
/**
* Put several items into a single bin
*/
override fun putMany(value: Double, count: Int, weight: Double) {
(get(value) ?: createBin(value)).apply {
increment(count.toLong(), weight)
}
}
override fun build(): UnivariateHistogram = UnivariateHistogramImpl(this@UnivariateHistogramSpace, bins)
}
public fun builder(): UnivariateHistogramBuilder = UnivariateHistogramBuilderImpl()
public fun produce(builder: UnivariateHistogramBuilder.() -> Unit): UnivariateHistogram =
UnivariateHistogramBuilderImpl().apply(builder).build()
override fun add( override fun add(
a: UnivariateHistogram, a: UnivariateHistogram,
@ -14,12 +111,75 @@ public class UnivariateHistogramSpace(public val binFactory: (Double) -> Univari
): UnivariateHistogram { ): UnivariateHistogram {
require(a.context == this) { "Histogram $a does not belong to this context" } require(a.context == this) { "Histogram $a does not belong to this context" }
require(b.context == this) { "Histogram $b does not belong to this context" } require(b.context == this) { "Histogram $b does not belong to this context" }
TODO() val bins = TreeMap<Double, UnivariateBin>().apply {
(a.bins.map { it.def } union b.bins.map { it.def }).forEach { def ->
val newBin = UnivariateBinValue(
def,
value = (a[def.position]?.value ?: 0.0) + (b[def.position]?.value ?: 0.0),
standardDeviation = (a[def.position]?.standardDeviation
?: 0.0) + (b[def.position]?.standardDeviation ?: 0.0)
)
}
}
return UnivariateHistogramImpl(this, bins)
} }
override fun multiply(a: UnivariateHistogram, k: Number): UnivariateHistogram { override fun multiply(a: UnivariateHistogram, k: Number): UnivariateHistogram {
TODO("Not yet implemented") val bins = TreeMap<Double, UnivariateBin>().apply {
a.bins.forEach { bin ->
put(bin.position,
UnivariateBinValue(
bin.def,
value = bin.value * k.toDouble(),
standardDeviation = abs(bin.standardDeviation * k.toDouble())
)
)
}
}
return UnivariateHistogramImpl(this, bins)
} }
override val zero: UnivariateHistogram = produce { } override val zero: UnivariateHistogram = produce { }
public companion object {
/**
* Build and fill a [UnivariateHistogram]. Returns a read-only histogram.
*/
public fun uniform(
binSize: Double,
start: Double = 0.0
): UnivariateHistogramSpace = UnivariateHistogramSpace { value ->
val center = start + binSize * Math.floor((value - start) / binSize + 0.5)
UnivariateHistogramBinDefinition(center, binSize)
}
/**
* Create a histogram with custom cell borders
*/
public fun custom(borders: DoubleArray): UnivariateHistogramSpace {
val sorted = borders.sortedArray()
return UnivariateHistogramSpace { value ->
when {
value < sorted.first() -> UnivariateHistogramBinDefinition(
Double.NEGATIVE_INFINITY,
Double.MAX_VALUE
)
value > sorted.last() -> UnivariateHistogramBinDefinition(
Double.POSITIVE_INFINITY,
Double.MAX_VALUE
)
else -> {
val index = sorted.indices.first { value > sorted[it] }
val left = sorted[index]
val right = sorted[index + 1]
UnivariateHistogramBinDefinition((left + right) / 2, (right - left))
}
}
}
}
}
} }