Add inc/dec to counters

This commit is contained in:
Alexander Nozik 2021-11-23 10:32:51 +03:00
parent 4e16cf1c98
commit 24799a691f
5 changed files with 28 additions and 7 deletions

View File

@ -8,7 +8,7 @@ package space.kscience.kmath.histogram
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.Group
/**
* Common representation for atomic counters
@ -18,7 +18,7 @@ public interface Counter<T : Any> {
public val value: T
public companion object {
public fun real(): ObjectCounter<Double> = ObjectCounter(DoubleField)
public fun double(): ObjectCounter<Double> = ObjectCounter(DoubleField)
}
}
@ -32,6 +32,16 @@ public class IntCounter : Counter<Int> {
override val value: Int get() = innerValue.value
}
public operator fun IntCounter.inc(): IntCounter {
add(1)
return this
}
public operator fun IntCounter.dec(): IntCounter {
add(-1)
return this
}
public class LongCounter : Counter<Long> {
private val innerValue = atomic(0L)
@ -42,7 +52,17 @@ public class LongCounter : Counter<Long> {
override val value: Long get() = innerValue.value
}
public class ObjectCounter<T : Any>(public val group: Ring<T>) : Counter<T> {
public operator fun LongCounter.inc(): LongCounter {
add(1L)
return this
}
public operator fun LongCounter.dec(): LongCounter {
add(-1L)
return this
}
public class ObjectCounter<T : Any>(private val group: Group<T>) : Counter<T> {
private val innerValue = atomic(group.zero)
override fun add(delta: T) {

View File

@ -74,7 +74,7 @@ public class DoubleHistogramSpace(
}
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
val ndCounter = StructureND.auto(shape) { Counter.real() }
val ndCounter = StructureND.auto(shape) { Counter.double() }
val hBuilder = HistogramBuilder<Double> { point, value ->
val index = getIndex(point)
ndCounter[index].add(value.toDouble())

View File

@ -41,7 +41,6 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
get() = DefaultStrides(context.shape).asSequence().map {
context.produceBin(it, values[it])
}.asIterable()
}
/**

View File

@ -39,7 +39,7 @@ public class TreeHistogram(
@PublishedApi
internal class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder {
internal class BinCounter(val domain: UnivariateDomain, val counter: Counter<Double> = Counter.real()) :
internal class BinCounter(val domain: UnivariateDomain, val counter: Counter<Double> = Counter.double()) :
ClosedFloatingPointRange<Double> by domain.range
private val bins: TreeMap<Double, BinCounter> = TreeMap()

View File

@ -232,4 +232,6 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal con
@OptIn(UnstableKMathAPI::class)
override fun export(arg: StructureND<T>): StructureND<T> =
if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg
}
}
//TODO add TensorFlow expressions