RealHistogram counts weights
This commit is contained in:
parent
9c21f69ec0
commit
2e057c409b
@ -130,6 +130,21 @@ fun <T, R> Chain<T>.map(func: suspend (T) -> R): Chain<R> = object : Chain<R> {
|
|||||||
override fun fork(): Chain<R> = this@map.fork().map(func)
|
override fun fork(): Chain<R> = this@map.fork().map(func)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* [block] must be a pure function or at least not use external random variables, otherwise fork could be broken
|
||||||
|
*/
|
||||||
|
fun <T> Chain<T>.filter(block: (T) -> Boolean): Chain<T> = object : Chain<T> {
|
||||||
|
override suspend fun next(): T {
|
||||||
|
var next: T
|
||||||
|
do {
|
||||||
|
next = this@filter.next()
|
||||||
|
} while (!block(next))
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun fork(): Chain<T> = this@filter.fork().filter(block)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map the whole chain
|
* Map the whole chain
|
||||||
*/
|
*/
|
||||||
|
@ -45,8 +45,7 @@ class RealHistogram(
|
|||||||
|
|
||||||
private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
|
private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
|
||||||
|
|
||||||
//private val weight: NDStructure<DoubleCounter?> = ndStructure(strides){null}
|
private val weights: NDStructure<DoubleCounter> = NDStructure.auto(strides) { DoubleCounter() }
|
||||||
|
|
||||||
|
|
||||||
override val dimension: Int get() = lower.size
|
override val dimension: Int get() = lower.size
|
||||||
|
|
||||||
@ -102,21 +101,33 @@ class RealHistogram(
|
|||||||
return MultivariateBin(getDef(index), getValue(index))
|
return MultivariateBin(getDef(index), getValue(index))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fun put(point: Point<out Double>){
|
||||||
|
// val index = getIndex(point)
|
||||||
|
// values[index].increment()
|
||||||
|
// }
|
||||||
|
|
||||||
override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
|
override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
|
||||||
if (weight != 1.0) TODO("Implement weighting")
|
|
||||||
val index = getIndex(point)
|
val index = getIndex(point)
|
||||||
values[index].increment()
|
values[index].increment()
|
||||||
|
weights[index].add(weight)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<MultivariateBin<Double>> = values.elements().map { (index, value) ->
|
override fun iterator(): Iterator<MultivariateBin<Double>> = weights.elements().map { (index, value) ->
|
||||||
MultivariateBin(getDef(index), value.sum())
|
MultivariateBin(getDef(index), value.sum())
|
||||||
}.iterator()
|
}.iterator()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert this histogram into NDStructure containing bin values but not bin descriptions
|
* Convert this histogram into NDStructure containing bin values but not bin descriptions
|
||||||
*/
|
*/
|
||||||
fun asNDStructure(): NDStructure<Number> {
|
fun values(): NDStructure<Number> {
|
||||||
return NDStructure.auto(this.values.shape) { values[it].sum() }
|
return NDStructure.auto(values.shape) { values[it].sum() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sum of weights
|
||||||
|
*/
|
||||||
|
fun weights():NDStructure<Double>{
|
||||||
|
return NDStructure.auto(weights.shape) { weights[it].sum() }
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
Loading…
Reference in New Issue
Block a user