RealHistogram counts weights

This commit is contained in:
Alexander Nozik 2020-04-30 11:20:28 +03:00
parent 9c21f69ec0
commit 2e057c409b
2 changed files with 32 additions and 6 deletions

View File

@ -130,6 +130,21 @@ fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> = object : Chain<R> {
override fun fork(): Chain<R> = this@map.fork().map(func) override fun fork(): Chain<R> = this@map.fork().map(func)
} }
/**
* [block] must be a pure function or at least not use external random variables, otherwise fork could be broken
*/
fun <T> Chain<T>.filter(block: (T) -> Boolean): Chain<T> = object : Chain<T> {
override suspend fun next(): T {
var next: T
do {
next = this@filter.next()
} while (!block(next))
return next
}
override fun fork(): Chain<T> = this@filter.fork().filter(block)
}
/** /**
* Map the whole chain * Map the whole chain
*/ */

View File

@ -45,8 +45,7 @@ class RealHistogram(
private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() } private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
//private val weight: NDStructure<DoubleCounter?> = ndStructure(strides){null} private val weights: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() }
override val dimension: Int get() = lower.size override val dimension: Int get() = lower.size
@ -102,21 +101,33 @@ class RealHistogram(
return MultivariateBin(getDef(index), getValue(index)) return MultivariateBin(getDef(index), getValue(index))
} }
// fun put(point: Point<out Double>){
// val index = getIndex(point)
// values[index].increment()
// }
override fun putWithWeight(point: Buffer<out Double>, weight: Double) { override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
if (weight != 1.0) TODO("Implement weighting")
val index = getIndex(point) val index = getIndex(point)
values[index].increment() values[index].increment()
weights[index].add(weight)
} }
override fun iterator(): Iterator<MultivariateBin<Double>> = values.elements().map { (index, value) -> override fun iterator(): Iterator<MultivariateBin<Double>> = weights.elements().map { (index, value) ->
MultivariateBin(getDef(index), value.sum()) MultivariateBin(getDef(index), value.sum())
}.iterator() }.iterator()
/** /**
* Convert this histogram into NDStructure containing bin values but not bin descriptions * Convert this histogram into NDStructure containing bin values but not bin descriptions
*/ */
fun asNDStructure(): NDStructure<Number> { fun values(): NDStructure<Number> {
return NDStructure.auto(this.values.shape) { values[it].sum() } return NDStructure.auto(values.shape) { values[it].sum() }
}
/**
* Sum of weights
*/
fun weights():NDStructure<Double>{
return NDStructure.auto(weights.shape) { weights[it].sum() }
} }
companion object { companion object {