Final fixes for FastHistogram and NDStructure performance

This commit is contained in:
Alexander Nozik 2018-11-24 21:49:42 +03:00
parent d799668083
commit bdd9dccd4f
3 changed files with 20 additions and 37 deletions

View File

@ -8,23 +8,6 @@ private operator fun RealPoint.minus(other: RealPoint) = ListBuffer((0 until siz
private inline fun <T> Buffer<out Double>.mapIndexed(crossinline mapper: (Int, Double) -> T): Sequence<T> = (0 until size).asSequence().map { mapper(it, get(it)) } private inline fun <T> Buffer<out Double>.mapIndexed(crossinline mapper: (Int, Double) -> T): Sequence<T> = (0 until size).asSequence().map { mapper(it, get(it)) }
//class MultivariateBin(override val center: RealPoint, val sizes: RealPoint, var counter: Long = 0) : Bin<Double> {
// init {
// if (center.size != sizes.size) error("Dimension mismatch in bin creation. Expected ${center.size}, but found ${sizes.size}")
// }
//
// override fun contains(vector: Buffer<out Double>): Boolean {
// if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
// return vector.mapIndexed { i, value -> value in (center[i] - sizes[i] / 2)..(center[i] + sizes[i] / 2) }.all { it }
// }
//
// override val value get() = counter
// internal operator fun inc() = this.also { counter++ }
//
// override val dimension: Int get() = center.size
//}
/** /**
* Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions. * Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions.
*/ */
@ -109,18 +92,17 @@ class FastHistogram(
/** /**
* Convert this histogram into NDStructure containing bin values but not bin descriptions * Convert this histogram into NDStructure containing bin values but not bin descriptions
*/ */
fun asND(): NDStructure<Number> { fun asNDStructure(): NDStructure<Number> {
return ndStructure(this.values.shape) { values[it].sum() } return ndStructure(this.values.shape) { values[it].sum() }
} }
// /** /**
// * Create a phantom lightweight immutable copy of this histogram * Create a phantom lightweight immutable copy of this histogram
// */ */
// fun asPhantom(): PhantomHistogram<Double> { fun asPhantomHistogram(): PhantomHistogram<Double> {
// val center = val binTemplates = values.associate { (index, _) -> getTemplate(index) to index }
// val binTemplates = bins.associate { (index, bin) -> BinTemplate<Double>(bin.center, bin.sizes) to index } return PhantomHistogram(binTemplates, asNDStructure())
// return PhantomHistogram(binTemplates, asND()) }
// }
companion object { companion object {
@ -148,8 +130,8 @@ class FastHistogram(
*/ */
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram { fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram {
return FastHistogram( return FastHistogram(
ranges.map { it.first.start }.toVector(), ListBuffer(ranges.map { it.first.start }),
ranges.map { it.first.endInclusive }.toVector(), ListBuffer(ranges.map { it.first.endInclusive }),
ranges.map { it.second }.toIntArray() ranges.map { it.second }.toIntArray()
) )
} }

View File

@ -53,6 +53,7 @@ interface MutableHistogram<T: Any, out B : Bin<T>>: Histogram<T,B>{
fun <T: Any> MutableHistogram<T,*>.put(vararg point: T) = put(ArrayBuffer(point)) fun <T: Any> MutableHistogram<T,*>.put(vararg point: T) = put(ArrayBuffer(point))
fun MutableHistogram<Double,*>.put(vararg point: Number) = put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray())) fun MutableHistogram<Double,*>.put(vararg point: Number) = put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray()))
fun MutableHistogram<Double,*>.put(vararg point: Double) = put(DoubleBuffer(point))
fun <T: Any> MutableHistogram<T,*>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) } fun <T: Any> MutableHistogram<T,*>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) }

View File

@ -83,15 +83,15 @@ class DefaultStrides(override val shape: IntArray) : Strides {
} }
override fun index(offset: Int): IntArray { override fun index(offset: Int): IntArray {
return sequence { val res = IntArray(shape.size)
var current = offset var current = offset
var strideIndex = strides.size - 2 var strideIndex = strides.size - 2
while (strideIndex >= 0) { while (strideIndex >= 0) {
yield(current / strides[strideIndex]) res[ strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex] current %= strides[strideIndex]
strideIndex-- strideIndex--
} }
}.toList().reversed().toIntArray() return res
} }
override val linearSize: Int override val linearSize: Int