From 2e057c409b29f6a2c7d125169d9d315dca794a03 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 30 Apr 2020 11:20:28 +0300 Subject: [PATCH] RealHistogram counts weights --- .../kotlin/scientifik/kmath/chains/Chain.kt | 15 ++++++++++++ .../kmath/histogram/RealHistogram.kt | 23 ++++++++++++++----- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt index 467bbe978..1c2872d17 100644 --- a/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt +++ b/kmath-coroutines/src/commonMain/kotlin/scientifik/kmath/chains/Chain.kt @@ -130,6 +130,21 @@ fun Chain.map(func: suspend (T) -> R): Chain = object : Chain { override fun fork(): Chain = this@map.fork().map(func) } +/** + * [block] must be a pure function or at least not use external random variables, otherwise fork could be broken + */ +fun Chain.filter(block: (T) -> Boolean): Chain = object : Chain { + override suspend fun next(): T { + var next: T + do { + next = this@filter.next() + } while (!block(next)) + return next + } + + override fun fork(): Chain = this@filter.fork().filter(block) +} + /** * Map the whole chain */ diff --git a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt index a2f0fc516..85f078fda 100644 --- a/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt +++ b/kmath-histograms/src/commonMain/kotlin/scientifik/kmath/histogram/RealHistogram.kt @@ -45,8 +45,7 @@ class RealHistogram( private val values: NDStructure = NDStructure.auto(strides) { LongCounter() } - //private val weight: NDStructure = ndStructure(strides){null} - + private val weights: NDStructure = NDStructure.auto(strides) { DoubleCounter() } override val dimension: Int get() = lower.size @@ -102,21 +101,33 @@ class RealHistogram( return MultivariateBin(getDef(index), getValue(index)) } +// fun put(point: Point){ +// val index = getIndex(point) +// values[index].increment() +// } + override fun putWithWeight(point: Buffer, weight: Double) { - if (weight != 1.0) TODO("Implement weighting") val index = getIndex(point) values[index].increment() + weights[index].add(weight) } - override fun iterator(): Iterator> = values.elements().map { (index, value) -> + override fun iterator(): Iterator> = weights.elements().map { (index, value) -> MultivariateBin(getDef(index), value.sum()) }.iterator() /** * Convert this histogram into NDStructure containing bin values but not bin descriptions */ - fun asNDStructure(): NDStructure { - return NDStructure.auto(this.values.shape) { values[it].sum() } + fun values(): NDStructure { + return NDStructure.auto(values.shape) { values[it].sum() } + } + + /** + * Sum of weights + */ + fun weights():NDStructure{ + return NDStructure.auto(weights.shape) { weights[it].sum() } } companion object {