Add inc/dec to counters

This commit is contained in:
Alexander Nozik 2021-11-23 10:32:51 +03:00
parent 4e16cf1c98
commit 24799a691f
5 changed files with 28 additions and 7 deletions

View File

@ -8,7 +8,7 @@ package space.kscience.kmath.histogram
import kotlinx.atomicfu.atomic import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate import kotlinx.atomicfu.getAndUpdate
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.Group
/** /**
* Common representation for atomic counters * Common representation for atomic counters
@ -18,7 +18,7 @@ public interface Counter<T : Any> {
public val value: T public val value: T
public companion object { public companion object {
public fun real(): ObjectCounter<Double> = ObjectCounter(DoubleField) public fun double(): ObjectCounter<Double> = ObjectCounter(DoubleField)
} }
} }
@ -32,6 +32,16 @@ public class IntCounter : Counter<Int> {
override val value: Int get() = innerValue.value override val value: Int get() = innerValue.value
} }
public operator fun IntCounter.inc(): IntCounter {
add(1)
return this
}
public operator fun IntCounter.dec(): IntCounter {
add(-1)
return this
}
public class LongCounter : Counter<Long> { public class LongCounter : Counter<Long> {
private val innerValue = atomic(0L) private val innerValue = atomic(0L)
@ -42,7 +52,17 @@ public class LongCounter : Counter<Long> {
override val value: Long get() = innerValue.value override val value: Long get() = innerValue.value
} }
public class ObjectCounter<T : Any>(public val group: Ring<T>) : Counter<T> { public operator fun LongCounter.inc(): LongCounter {
add(1L)
return this
}
public operator fun LongCounter.dec(): LongCounter {
add(-1L)
return this
}
public class ObjectCounter<T : Any>(private val group: Group<T>) : Counter<T> {
private val innerValue = atomic(group.zero) private val innerValue = atomic(group.zero)
override fun add(delta: T) { override fun add(delta: T) {

View File

@ -74,7 +74,7 @@ public class DoubleHistogramSpace(
} }
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> { override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
val ndCounter = StructureND.auto(shape) { Counter.real() } val ndCounter = StructureND.auto(shape) { Counter.double() }
val hBuilder = HistogramBuilder<Double> { point, value -> val hBuilder = HistogramBuilder<Double> { point, value ->
val index = getIndex(point) val index = getIndex(point)
ndCounter[index].add(value.toDouble()) ndCounter[index].add(value.toDouble())

View File

@ -41,7 +41,6 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
get() = DefaultStrides(context.shape).asSequence().map { get() = DefaultStrides(context.shape).asSequence().map {
context.produceBin(it, values[it]) context.produceBin(it, values[it])
}.asIterable() }.asIterable()
} }
/** /**

View File

@ -39,7 +39,7 @@ public class TreeHistogram(
@PublishedApi @PublishedApi
internal class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder { internal class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder {
internal class BinCounter(val domain: UnivariateDomain, val counter: Counter<Double> = Counter.real()) : internal class BinCounter(val domain: UnivariateDomain, val counter: Counter<Double> = Counter.double()) :
ClosedFloatingPointRange<Double> by domain.range ClosedFloatingPointRange<Double> by domain.range
private val bins: TreeMap<Double, BinCounter> = TreeMap() private val bins: TreeMap<Double, BinCounter> = TreeMap()

View File

@ -233,3 +233,5 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal con
override fun export(arg: StructureND<T>): StructureND<T> = override fun export(arg: StructureND<T>): StructureND<T> =
if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg
} }
//TODO add TensorFlow expressions