Faster implementation for FastHistogram

This commit is contained in:
Alexander Nozik 2018-11-24 20:55:35 +03:00
parent 15cc4a22e2
commit d799668083
5 changed files with 77 additions and 57 deletions

View File

@ -1,44 +1,48 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
import scientifik.kmath.linear.toVector import scientifik.kmath.linear.toVector
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.*
import scientifik.kmath.structures.ListBuffer
import scientifik.kmath.structures.NDStructure
import scientifik.kmath.structures.ndStructure
import kotlin.math.floor import kotlin.math.floor
typealias RealPoint = Point<Double>
private operator fun RealPoint.minus(other: RealPoint) = ListBuffer((0 until size).map { get(it) - other[it] }) private operator fun RealPoint.minus(other: RealPoint) = ListBuffer((0 until size).map { get(it) - other[it] })
private inline fun <T> Buffer<out Double>.mapIndexed(crossinline mapper: (Int, Double) -> T): Sequence<T> = (0 until size).asSequence().map { mapper(it, get(it)) } private inline fun <T> Buffer<out Double>.mapIndexed(crossinline mapper: (Int, Double) -> T): Sequence<T> = (0 until size).asSequence().map { mapper(it, get(it)) }
class MultivariateBin(override val center: RealPoint, val sizes: RealPoint, var counter: Long = 0) : Bin<Double> { //class MultivariateBin(override val center: RealPoint, val sizes: RealPoint, var counter: Long = 0) : Bin<Double> {
init { // init {
if (center.size != sizes.size) error("Dimension mismatch in bin creation. Expected ${center.size}, but found ${sizes.size}") // if (center.size != sizes.size) error("Dimension mismatch in bin creation. Expected ${center.size}, but found ${sizes.size}")
} // }
//
override fun contains(vector: Buffer<out Double>): Boolean { // override fun contains(vector: Buffer<out Double>): Boolean {
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}") // if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
return vector.mapIndexed { i, value -> value in (center[i] - sizes[i] / 2)..(center[i] + sizes[i] / 2) }.all { it } // return vector.mapIndexed { i, value -> value in (center[i] - sizes[i] / 2)..(center[i] + sizes[i] / 2) }.all { it }
} // }
//
override val value get() = counter // override val value get() = counter
internal operator fun inc() = this.also { counter++ } // internal operator fun inc() = this.also { counter++ }
//
override val dimension: Int get() = center.size // override val dimension: Int get() = center.size
} //}
/** /**
* Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions. * Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions.
* The histogram is optimized for speed, but have large size in memory
*/ */
class FastHistogram( class FastHistogram(
private val lower: RealPoint, private val lower: RealPoint,
private val upper: RealPoint, private val upper: RealPoint,
private val binNums: IntArray = IntArray(lower.size) { 20 } private val binNums: IntArray = IntArray(lower.size) { 20 }
) : MutableHistogram<Double, MultivariateBin> { ) : MutableHistogram<Double, PhantomBin<Double>> {
private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 })
private val values: NDStructure<LongCounter> = ndStructure(strides) { LongCounter() }
//private val weight: NDStructure<DoubleCounter?> = ndStructure(strides){null}
//TODO optimize binSize performance if needed
private val binSize: RealPoint = ListBuffer((upper - lower).mapIndexed { index, value -> value / binNums[index] }.toList())
init { init {
// argument checks // argument checks
@ -50,24 +54,6 @@ class FastHistogram(
override val dimension: Int get() = lower.size override val dimension: Int get() = lower.size
//TODO optimize binSize performance if needed
private val binSize: RealPoint = ListBuffer((upper - lower).mapIndexed { index, value -> value / binNums[index] }.toList())
private val bins: NDStructure<MultivariateBin> by lazy {
val actualSizes = IntArray(binNums.size) { binNums[it] + 2 }
ndStructure(actualSizes) { indexArray ->
val center = ListBuffer(
indexArray.mapIndexed { axis, index ->
when (index) {
0 -> Double.NEGATIVE_INFINITY
actualSizes[axis] - 1 -> Double.POSITIVE_INFINITY
else -> lower[axis] + (index.toDouble() - 0.5) * binSize[axis]
}
}
)
MultivariateBin(center, binSize)
}
}
/** /**
* Get internal [NDStructure] bin index for given axis * Get internal [NDStructure] bin index for given axis
@ -80,24 +66,51 @@ class FastHistogram(
} }
} }
private fun getIndex(point: Buffer<out Double>): IntArray = IntArray(dimension) { getIndex(it, point[it]) }
override fun get(point: Buffer<out Double>): MultivariateBin? { private fun getValue(index: IntArray): Long {
val index = IntArray(dimension) { getIndex(it, point[it]) } return values[index].sum()
return bins[index] }
fun getValue(point: Buffer<out Double>): Long {
return getValue(getIndex(point))
}
private fun getTemplate(index: IntArray): BinTemplate<Double> {
val center = index.mapIndexed { axis, i ->
when (i) {
0 -> Double.NEGATIVE_INFINITY
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY
else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis]
}
}.toVector()
return BinTemplate(center, binSize)
}
fun getTemplate(point: Buffer<out Double>): BinTemplate<Double> {
return getTemplate(getIndex(point))
}
override fun get(point: Buffer<out Double>): PhantomBin<Double>? {
val index = getIndex(point)
return PhantomBin(getTemplate(index), getValue(index))
} }
override fun put(point: Buffer<out Double>, weight: Double) { override fun put(point: Buffer<out Double>, weight: Double) {
if (weight != 1.0) TODO("Implement weighting") if (weight != 1.0) TODO("Implement weighting")
this[point]?.inc() ?: error("Could not find appropriate bin (should not be possible)") val index = getIndex(point)
values[index].increment()
} }
override fun iterator(): Iterator<MultivariateBin> = bins.asSequence().map { it.second }.iterator() override fun iterator(): Iterator<PhantomBin<Double>> = values.asSequence().map { (index, value) ->
PhantomBin(getTemplate(index), value.sum())
}.iterator()
/** /**
* Convert this histogram into NDStructure containing bin values but not bin descriptions * Convert this histogram into NDStructure containing bin values but not bin descriptions
*/ */
fun asND(): NDStructure<Number> { fun asND(): NDStructure<Number> {
return ndStructure(this.bins.shape) { bins[it].value } return ndStructure(this.values.shape) { values[it].sum() }
} }
// /** // /**

View File

@ -6,6 +6,8 @@ import scientifik.kmath.structures.DoubleBuffer
typealias Point<T> = Buffer<T> typealias Point<T> = Buffer<T>
typealias RealPoint = Buffer<Double>
/** /**
* A simple geometric domain * A simple geometric domain
* TODO move to geometry module * TODO move to geometry module

View File

@ -3,12 +3,13 @@ package scientifik.kmath.histogram
import scientifik.kmath.linear.Vector import scientifik.kmath.linear.Vector
import scientifik.kmath.operations.Space import scientifik.kmath.operations.Space
import scientifik.kmath.structures.NDStructure import scientifik.kmath.structures.NDStructure
import scientifik.kmath.structures.asSequence
data class BinTemplate<T : Comparable<T>>(val center: Vector<T, *>, val sizes: Vector<T, *>) { data class BinTemplate<T : Comparable<T>>(val center: Vector<T, *>, val sizes: Point<T>) {
fun contains(vector: Point<out T>): Boolean { fun contains(vector: Point<out T>): Boolean {
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}") if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
val upper = center + sizes / 2.0 val upper = center.context.run { center + sizes / 2.0}
val lower = center - sizes / 2.0 val lower = center.context.run {center - sizes / 2.0}
return vector.asSequence().mapIndexed { i, value -> return vector.asSequence().mapIndexed { i, value ->
value in lower[i]..upper[i] value in lower[i]..upper[i]
}.all { it } }.all { it }

View File

@ -39,18 +39,24 @@ interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, VectorSpace<T,
/** /**
* Create vector with custom field * Create vector with custom field
*/ */
fun <T : Any, F: Field<T>> of(size: Int, field: F, initializer: (Int) -> T) = fun <T : Any, F : Field<T>> of(size: Int, field: F, initializer: (Int) -> T) =
ArrayVector(ArrayVectorSpace(size, field), initializer) ArrayVector(ArrayVectorSpace(size, field), initializer)
private val realSpaceCache = HashMap<Int, ArrayVectorSpace<Double, DoubleField>>()
private fun getRealSpace(size: Int): ArrayVectorSpace<Double, DoubleField> {
return realSpaceCache.getOrPut(size){ArrayVectorSpace(size, DoubleField, realNDFieldFactory)}
}
/** /**
* Create vector of [Double] * Create vector of [Double]
*/ */
fun ofReal(size: Int, initializer: (Int) -> Double) = fun ofReal(size: Int, initializer: (Int) -> Double) =
ArrayVector(ArrayVectorSpace(size, DoubleField, realNDFieldFactory), initializer) ArrayVector(getRealSpace(size), initializer)
fun ofReal(vararg point: Double) = point.toVector() fun ofReal(vararg point: Double) = point.toVector()
fun equals(v1: Vector<*,*>, v2: Vector<*,*>): Boolean { fun equals(v1: Vector<*, *>, v2: Vector<*, *>): Boolean {
if (v1 === v2) return true if (v1 === v2) return true
if (v1.context != v2.context) return false if (v1.context != v2.context) return false
for (i in 0 until v2.size) { for (i in 0 until v2.size) {
@ -74,8 +80,6 @@ class ArrayVectorSpace<T : Any, F : Field<T>>(
} }
class ArrayVector<T : Any, F : Field<T>> internal constructor(override val context: VectorSpace<T, F>, val element: NDElement<T, F>) : Vector<T, F> { class ArrayVector<T : Any, F : Field<T>> internal constructor(override val context: VectorSpace<T, F>, val element: NDElement<T, F>) : Vector<T, F> {
constructor(context: ArrayVectorSpace<T, F>, initializer: (Int) -> T) : this(context, context.ndField.produce { list -> initializer(list[0]) }) constructor(context: ArrayVectorSpace<T, F>, initializer: (Int) -> T) : this(context, context.ndField.produce { list -> initializer(list[0]) })

View File

@ -12,14 +12,14 @@ interface Buffer<T> {
operator fun iterator(): Iterator<T> operator fun iterator(): Iterator<T>
fun asSequence(): Sequence<T> = iterator().asSequence()
/** /**
* A shallow copy of the buffer * A shallow copy of the buffer
*/ */
fun copy(): Buffer<T> fun copy(): Buffer<T>
} }
fun <T> Buffer<T>.asSequence(): Sequence<T> = iterator().asSequence()
interface MutableBuffer<T> : Buffer<T> { interface MutableBuffer<T> : Buffer<T> {
operator fun set(index: Int, value: T) operator fun set(index: Int, value: T)