forked from kscience/kmath
Add inc/dec to counters
This commit is contained in:
parent
4e16cf1c98
commit
24799a691f
@ -8,7 +8,7 @@ package space.kscience.kmath.histogram
|
|||||||
import kotlinx.atomicfu.atomic
|
import kotlinx.atomicfu.atomic
|
||||||
import kotlinx.atomicfu.getAndUpdate
|
import kotlinx.atomicfu.getAndUpdate
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.Ring
|
import space.kscience.kmath.operations.Group
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Common representation for atomic counters
|
* Common representation for atomic counters
|
||||||
@ -18,7 +18,7 @@ public interface Counter<T : Any> {
|
|||||||
public val value: T
|
public val value: T
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public fun real(): ObjectCounter<Double> = ObjectCounter(DoubleField)
|
public fun double(): ObjectCounter<Double> = ObjectCounter(DoubleField)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -32,6 +32,16 @@ public class IntCounter : Counter<Int> {
|
|||||||
override val value: Int get() = innerValue.value
|
override val value: Int get() = innerValue.value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public operator fun IntCounter.inc(): IntCounter {
|
||||||
|
add(1)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun IntCounter.dec(): IntCounter {
|
||||||
|
add(-1)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
public class LongCounter : Counter<Long> {
|
public class LongCounter : Counter<Long> {
|
||||||
private val innerValue = atomic(0L)
|
private val innerValue = atomic(0L)
|
||||||
|
|
||||||
@ -42,7 +52,17 @@ public class LongCounter : Counter<Long> {
|
|||||||
override val value: Long get() = innerValue.value
|
override val value: Long get() = innerValue.value
|
||||||
}
|
}
|
||||||
|
|
||||||
public class ObjectCounter<T : Any>(public val group: Ring<T>) : Counter<T> {
|
public operator fun LongCounter.inc(): LongCounter {
|
||||||
|
add(1L)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun LongCounter.dec(): LongCounter {
|
||||||
|
add(-1L)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
public class ObjectCounter<T : Any>(private val group: Group<T>) : Counter<T> {
|
||||||
private val innerValue = atomic(group.zero)
|
private val innerValue = atomic(group.zero)
|
||||||
|
|
||||||
override fun add(delta: T) {
|
override fun add(delta: T) {
|
||||||
|
@ -74,7 +74,7 @@ public class DoubleHistogramSpace(
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
|
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
|
||||||
val ndCounter = StructureND.auto(shape) { Counter.real() }
|
val ndCounter = StructureND.auto(shape) { Counter.double() }
|
||||||
val hBuilder = HistogramBuilder<Double> { point, value ->
|
val hBuilder = HistogramBuilder<Double> { point, value ->
|
||||||
val index = getIndex(point)
|
val index = getIndex(point)
|
||||||
ndCounter[index].add(value.toDouble())
|
ndCounter[index].add(value.toDouble())
|
||||||
|
@ -41,7 +41,6 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
|
|||||||
get() = DefaultStrides(context.shape).asSequence().map {
|
get() = DefaultStrides(context.shape).asSequence().map {
|
||||||
context.produceBin(it, values[it])
|
context.produceBin(it, values[it])
|
||||||
}.asIterable()
|
}.asIterable()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -39,7 +39,7 @@ public class TreeHistogram(
|
|||||||
@PublishedApi
|
@PublishedApi
|
||||||
internal class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder {
|
internal class TreeHistogramBuilder(val binFactory: (Double) -> UnivariateDomain) : UnivariateHistogramBuilder {
|
||||||
|
|
||||||
internal class BinCounter(val domain: UnivariateDomain, val counter: Counter<Double> = Counter.real()) :
|
internal class BinCounter(val domain: UnivariateDomain, val counter: Counter<Double> = Counter.double()) :
|
||||||
ClosedFloatingPointRange<Double> by domain.range
|
ClosedFloatingPointRange<Double> by domain.range
|
||||||
|
|
||||||
private val bins: TreeMap<Double, BinCounter> = TreeMap()
|
private val bins: TreeMap<Double, BinCounter> = TreeMap()
|
||||||
|
@ -233,3 +233,5 @@ public abstract class TensorFlowAlgebra<T, TT : TType, A : Ring<T>> internal con
|
|||||||
override fun export(arg: StructureND<T>): StructureND<T> =
|
override fun export(arg: StructureND<T>): StructureND<T> =
|
||||||
if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg
|
if (arg is TensorFlowOutput<T, *>) arg.actualTensor else arg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//TODO add TensorFlow expressions
|
Loading…
Reference in New Issue
Block a user