Added limited polimorphism to Histogram

This commit is contained in:
Alexander Nozik 2018-10-28 21:53:40 +03:00
parent 6a0ef6a235
commit 881ee9f688
5 changed files with 32 additions and 27 deletions

View File

@ -2,6 +2,7 @@ package scientifik.kmath.histogram
import scientifik.kmath.linear.RealVector
import scientifik.kmath.linear.toVector
import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.NDStructure
import scientifik.kmath.structures.ndStructure
import kotlin.math.floor
@ -11,7 +12,7 @@ class MultivariateBin(override val center: RealVector, val sizes: RealVector, va
if (center.size != sizes.size) error("Dimension mismatch in bin creation. Expected ${center.size}, but found ${sizes.size}")
}
override fun contains(vector: RealVector): Boolean {
override fun contains(vector: Buffer<out Double>): Boolean {
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
return vector.asSequence().mapIndexed { i, value -> value in (center[i] - sizes[i] / 2)..(center[i] + sizes[i] / 2) }.all { it }
}
@ -70,12 +71,12 @@ class FastHistogram(
}
override fun get(point: RealVector): MultivariateBin? {
override fun get(point: Buffer<out Double>): MultivariateBin? {
val index = IntArray(dimension) { getIndex(it, point[it]) }
return bins[index]
}
override fun put(point: RealVector) {
override fun put(point: Buffer<out Double>) {
this[point]?.inc() ?: error("Could not find appropriate bin (should not be possible)")
}

View File

@ -1,15 +1,15 @@
package scientifik.kmath.histogram
import scientifik.kmath.linear.Vector
import scientifik.kmath.linear.toVector
import scientifik.kmath.operations.Space
import scientifik.kmath.structures.ArrayBuffer
import scientifik.kmath.structures.Buffer
/**
* A simple geometric domain
* TODO move to geometry module
*/
interface Domain<T: Any> {
operator fun contains(vector: Vector<T>): Boolean
operator fun contains(vector: Buffer<out T>): Boolean
val dimension: Int
}
@ -21,7 +21,7 @@ interface Bin<T: Any> : Domain<T> {
* The value of this bin
*/
val value: Number
val center: Vector<T>
val center: Buffer<T>
}
interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> {
@ -29,7 +29,7 @@ interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> {
/**
* Find existing bin, corresponding to given coordinates
*/
operator fun get(point: Vector<T>): B?
operator fun get(point: Buffer<out T>): B?
/**
* Dimension of the histogram
@ -39,17 +39,17 @@ interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> {
/**
* Increment appropriate bin
*/
fun put(point: Vector<T>)
fun put(point: Buffer<out T>)
}
fun Histogram<Double,*>.put(vararg point: Double) = put(point.toVector())
fun <T: Any> Histogram<T,*>.put(vararg point: T) = put(ArrayBuffer(point))
fun <T: Any> Histogram<T,*>.fill(sequence: Iterable<Vector<T>>) = sequence.forEach { put(it) }
fun <T: Any> Histogram<T,*>.fill(sequence: Iterable<Buffer<T>>) = sequence.forEach { put(it) }
/**
* Pass a sequence builder into histogram
*/
fun <T: Any> Histogram<T, *>.fill(buider: suspend SequenceScope<Vector<T>>.() -> Unit) = fill(sequence(buider).asIterable())
fun <T: Any> Histogram<T, *>.fill(buider: suspend SequenceScope<Buffer<T>>.() -> Unit) = fill(sequence(buider).asIterable())
/**
* A space to perform arithmetic operations on histograms

View File

@ -4,10 +4,7 @@ import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Space
import scientifik.kmath.operations.SpaceElement
import scientifik.kmath.structures.GenericNDField
import scientifik.kmath.structures.NDArray
import scientifik.kmath.structures.NDField
import scientifik.kmath.structures.get
import scientifik.kmath.structures.*
/**
* The space for linear elements. Supports scalar product alongside with standard linear operations.
@ -162,10 +159,8 @@ abstract class VectorSpace<T : Any>(val size: Int, val field: Field<T>) : Space<
}
interface Vector<T : Any> : SpaceElement<Vector<T>, VectorSpace<T>>, Iterable<T> {
val size: Int get() = context.size
operator fun get(i: Int): T
interface Vector<T : Any> : SpaceElement<Vector<T>, VectorSpace<T>>, Buffer<T>, Iterable<T> {
override val size: Int get() = context.size
companion object {
/**
@ -261,14 +256,16 @@ class ArrayVector<T : Any> internal constructor(override val context: ArrayVecto
}
}
override fun get(i: Int): T {
return array[i]
override fun get(index: Int): T {
return array[index]
}
override val self: ArrayVector<T> get() = this
override fun iterator(): Iterator<T> = (0 until size).map { array[it] }.iterator()
override fun copy(): ArrayVector<T> = ArrayVector(context, array)
override fun toString(): String = this.joinToString(prefix = "[", postfix = "]", separator = ", ") { it.toString() }
}

View File

@ -4,7 +4,7 @@ package scientifik.kmath.structures
/**
* A generic linear buffer for both primitives and objects
*/
interface Buffer<T> {
interface Buffer<T> : Iterable<T> {
val size: Int
@ -36,6 +36,8 @@ inline class ListBuffer<T>(private val list: MutableList<T>) : MutableBuffer<T>
list[index] = value
}
override fun iterator(): Iterator<T> = list.iterator()
override fun copy(): MutableBuffer<T> = ListBuffer(ArrayList(list))
}
@ -49,6 +51,8 @@ class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {
array[index] = value
}
override fun iterator(): Iterator<T> = array.iterator()
override fun copy(): MutableBuffer<T> = ArrayBuffer(array.copyOf())
}
@ -62,6 +66,8 @@ class DoubleBuffer(private val array: DoubleArray) : MutableBuffer<Double> {
array[index] = value
}
override fun iterator(): Iterator<Double> = array.iterator()
override fun copy(): MutableBuffer<Double> = DoubleBuffer(array.copyOf())
}

View File

@ -2,6 +2,7 @@ package scientifik.kmath.histogram
import scientifik.kmath.linear.RealVector
import scientifik.kmath.linear.toVector
import scientifik.kmath.structures.Buffer
import java.util.*
import kotlin.math.floor
@ -15,7 +16,7 @@ class UnivariateBin(val position: Double, val size: Double, val counter: LongCou
operator fun contains(value: Double): Boolean = value in (position - size / 2)..(position + size / 2)
override fun contains(vector: RealVector): Boolean = contains(vector[0])
override fun contains(vector: Buffer<out Double>): Boolean = contains(vector[0])
internal operator fun inc() = this.also { counter.increment()}
@ -44,7 +45,7 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
synchronized(this) { bins.put(it.position, it) }
}
override fun get(point: RealVector): UnivariateBin? = get(point[0])
override fun get(point: Buffer<out Double>): UnivariateBin? = get(point[0])
override val dimension: Int get() = 1
@ -57,7 +58,7 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
(get(value) ?: createBin(value)).inc()
}
override fun put(point: RealVector) = put(point[0])
override fun put(point: Buffer<out Double>) = put(point[0])
companion object {
fun uniform(binSize: Double, start: Double = 0.0): UnivariateHistogram {