Added limited polimorphism to Histogram

This commit is contained in:
Alexander Nozik 2018-10-28 21:53:40 +03:00
parent 6a0ef6a235
commit 881ee9f688
5 changed files with 32 additions and 27 deletions

View File

@ -2,6 +2,7 @@ package scientifik.kmath.histogram
import scientifik.kmath.linear.RealVector import scientifik.kmath.linear.RealVector
import scientifik.kmath.linear.toVector import scientifik.kmath.linear.toVector
import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.NDStructure import scientifik.kmath.structures.NDStructure
import scientifik.kmath.structures.ndStructure import scientifik.kmath.structures.ndStructure
import kotlin.math.floor import kotlin.math.floor
@ -11,7 +12,7 @@ class MultivariateBin(override val center: RealVector, val sizes: RealVector, va
if (center.size != sizes.size) error("Dimension mismatch in bin creation. Expected ${center.size}, but found ${sizes.size}") if (center.size != sizes.size) error("Dimension mismatch in bin creation. Expected ${center.size}, but found ${sizes.size}")
} }
override fun contains(vector: RealVector): Boolean { override fun contains(vector: Buffer<out Double>): Boolean {
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}") if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
return vector.asSequence().mapIndexed { i, value -> value in (center[i] - sizes[i] / 2)..(center[i] + sizes[i] / 2) }.all { it } return vector.asSequence().mapIndexed { i, value -> value in (center[i] - sizes[i] / 2)..(center[i] + sizes[i] / 2) }.all { it }
} }
@ -70,12 +71,12 @@ class FastHistogram(
} }
override fun get(point: RealVector): MultivariateBin? { override fun get(point: Buffer<out Double>): MultivariateBin? {
val index = IntArray(dimension) { getIndex(it, point[it]) } val index = IntArray(dimension) { getIndex(it, point[it]) }
return bins[index] return bins[index]
} }
override fun put(point: RealVector) { override fun put(point: Buffer<out Double>) {
this[point]?.inc() ?: error("Could not find appropriate bin (should not be possible)") this[point]?.inc() ?: error("Could not find appropriate bin (should not be possible)")
} }

View File

@ -1,15 +1,15 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
import scientifik.kmath.linear.Vector
import scientifik.kmath.linear.toVector
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.structures.ArrayBuffer
import scientifik.kmath.structures.Buffer
/** /**
* A simple geometric domain * A simple geometric domain
* TODO move to geometry module * TODO move to geometry module
*/ */
interface Domain<T: Any> { interface Domain<T: Any> {
operator fun contains(vector: Vector<T>): Boolean operator fun contains(vector: Buffer<out T>): Boolean
val dimension: Int val dimension: Int
} }
@ -21,7 +21,7 @@ interface Bin<T: Any> : Domain<T> {
* The value of this bin * The value of this bin
*/ */
val value: Number val value: Number
val center: Vector<T> val center: Buffer<T>
} }
interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> { interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> {
@ -29,7 +29,7 @@ interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> {
/** /**
* Find existing bin, corresponding to given coordinates * Find existing bin, corresponding to given coordinates
*/ */
operator fun get(point: Vector<T>): B? operator fun get(point: Buffer<out T>): B?
/** /**
* Dimension of the histogram * Dimension of the histogram
@ -39,17 +39,17 @@ interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> {
/** /**
* Increment appropriate bin * Increment appropriate bin
*/ */
fun put(point: Vector<T>) fun put(point: Buffer<out T>)
} }
fun Histogram<Double,*>.put(vararg point: Double) = put(point.toVector()) fun <T: Any> Histogram<T,*>.put(vararg point: T) = put(ArrayBuffer(point))
fun <T: Any> Histogram<T,*>.fill(sequence: Iterable<Vector<T>>) = sequence.forEach { put(it) } fun <T: Any> Histogram<T,*>.fill(sequence: Iterable<Buffer<T>>) = sequence.forEach { put(it) }
/** /**
* Pass a sequence builder into histogram * Pass a sequence builder into histogram
*/ */
fun <T: Any> Histogram<T, *>.fill(buider: suspend SequenceScope<Vector<T>>.() -> Unit) = fill(sequence(buider).asIterable()) fun <T: Any> Histogram<T, *>.fill(buider: suspend SequenceScope<Buffer<T>>.() -> Unit) = fill(sequence(buider).asIterable())
/** /**
* A space to perform arithmetic operations on histograms * A space to perform arithmetic operations on histograms

View File

@ -4,10 +4,7 @@ import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.operations.SpaceElement import scientifik.kmath.operations.SpaceElement
import scientifik.kmath.structures.GenericNDField import scientifik.kmath.structures.*
import scientifik.kmath.structures.NDArray
import scientifik.kmath.structures.NDField
import scientifik.kmath.structures.get
/** /**
* The space for linear elements. Supports scalar product alongside with standard linear operations. * The space for linear elements. Supports scalar product alongside with standard linear operations.
@ -162,10 +159,8 @@ abstract class VectorSpace<T : Any>(val size: Int, val field: Field<T>) : Space<
} }
interface Vector<T : Any> : SpaceElement<Vector<T>, VectorSpace<T>>, Iterable<T> { interface Vector<T : Any> : SpaceElement<Vector<T>, VectorSpace<T>>, Buffer<T>, Iterable<T> {
val size: Int get() = context.size override val size: Int get() = context.size
operator fun get(i: Int): T
companion object { companion object {
/** /**
@ -261,14 +256,16 @@ class ArrayVector<T : Any> internal constructor(override val context: ArrayVecto
} }
} }
override fun get(i: Int): T { override fun get(index: Int): T {
return array[i] return array[index]
} }
override val self: ArrayVector<T> get() = this override val self: ArrayVector<T> get() = this
override fun iterator(): Iterator<T> = (0 until size).map { array[it] }.iterator() override fun iterator(): Iterator<T> = (0 until size).map { array[it] }.iterator()
override fun copy(): ArrayVector<T> = ArrayVector(context, array)
override fun toString(): String = this.joinToString(prefix = "[", postfix = "]", separator = ", ") { it.toString() } override fun toString(): String = this.joinToString(prefix = "[", postfix = "]", separator = ", ") { it.toString() }
} }

View File

@ -4,7 +4,7 @@ package scientifik.kmath.structures
/** /**
* A generic linear buffer for both primitives and objects * A generic linear buffer for both primitives and objects
*/ */
interface Buffer<T> { interface Buffer<T> : Iterable<T> {
val size: Int val size: Int
@ -36,6 +36,8 @@ inline class ListBuffer<T>(private val list: MutableList<T>) : MutableBuffer<T>
list[index] = value list[index] = value
} }
override fun iterator(): Iterator<T> = list.iterator()
override fun copy(): MutableBuffer<T> = ListBuffer(ArrayList(list)) override fun copy(): MutableBuffer<T> = ListBuffer(ArrayList(list))
} }
@ -49,6 +51,8 @@ class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
array[index] = value array[index] = value
} }
override fun iterator(): Iterator<T> = array.iterator()
override fun copy(): MutableBuffer<T> = ArrayBuffer(array.copyOf()) override fun copy(): MutableBuffer<T> = ArrayBuffer(array.copyOf())
} }
@ -62,6 +66,8 @@ class DoubleBuffer(private val array: DoubleArray) : MutableBuffer<Double> {
array[index] = value array[index] = value
} }
override fun iterator(): Iterator<Double> = array.iterator()
override fun copy(): MutableBuffer<Double> = DoubleBuffer(array.copyOf()) override fun copy(): MutableBuffer<Double> = DoubleBuffer(array.copyOf())
} }

View File

@ -2,6 +2,7 @@ package scientifik.kmath.histogram
import scientifik.kmath.linear.RealVector import scientifik.kmath.linear.RealVector
import scientifik.kmath.linear.toVector import scientifik.kmath.linear.toVector
import scientifik.kmath.structures.Buffer
import java.util.* import java.util.*
import kotlin.math.floor import kotlin.math.floor
@ -15,7 +16,7 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou
operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2) operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2)
override fun contains(vector: RealVector): Boolean = contains(vector[0]) override fun contains(vector: Buffer<out Double>): Boolean = contains(vector[0])
internal operator fun inc() = this.also { counter.increment()} internal operator fun inc() = this.also { counter.increment()}
@ -44,7 +45,7 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
synchronized(this) { bins.put(it.position, it) } synchronized(this) { bins.put(it.position, it) }
} }
override fun get(point: RealVector): UnivariateBin? = get(point[0]) override fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0])
override val dimension: Int get() = 1 override val dimension: Int get() = 1
@ -57,7 +58,7 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
(get(value) ?: createBin(value)).inc() (get(value) ?: createBin(value)).inc()
} }
override fun put(point: RealVector) = put(point[0]) override fun put(point: Buffer<out Double>) = put(point[0])
companion object { companion object {
fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram { fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram {