Fixed some performance issues with histogram

This commit is contained in:
Alexander Nozik 2018-11-23 21:44:42 +03:00
parent 5e3e0eb09d
commit 15cc4a22e2
5 changed files with 74 additions and 47 deletions

View File

@ -1,24 +1,31 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
import scientifik.kmath.linear.RealVector
import scientifik.kmath.linear.toVector import scientifik.kmath.linear.toVector
import scientifik.kmath.structures.Buffer import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.ListBuffer
import scientifik.kmath.structures.NDStructure import scientifik.kmath.structures.NDStructure
import scientifik.kmath.structures.ndStructure import scientifik.kmath.structures.ndStructure
import kotlin.math.floor import kotlin.math.floor
class MultivariateBin(override val center: RealVector, val sizes: RealVector, val counter: LongCounter = LongCounter()) : Bin<Double> { typealias RealPoint = Point<Double>
private operator fun RealPoint.minus(other: RealPoint) = ListBuffer((0 until size).map { get(it) - other[it] })
private inline fun <T> Buffer<out Double>.mapIndexed(crossinline mapper: (Int, Double) -> T): Sequence<T> = (0 until size).asSequence().map { mapper(it, get(it)) }
class MultivariateBin(override val center: RealPoint, val sizes: RealPoint, var counter: Long = 0) : Bin<Double> {
init { init {
if (center.size != sizes.size) error("Dimension mismatch in bin creation. Expected ${center.size}, but found ${sizes.size}") if (center.size != sizes.size) error("Dimension mismatch in bin creation. Expected ${center.size}, but found ${sizes.size}")
} }
override fun contains(vector: Buffer<out Double>): Boolean { override fun contains(vector: Buffer<out Double>): Boolean {
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}") if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
return vector.asSequence().mapIndexed { i, value -> value in (center[i] - sizes[i] / 2)..(center[i] + sizes[i] / 2) }.all { it } return vector.mapIndexed { i, value -> value in (center[i] - sizes[i] / 2)..(center[i] + sizes[i] / 2) }.all { it }
} }
override val value get() = counter.sum() override val value get() = counter
internal operator fun inc() = this.also { counter.increment() } internal operator fun inc() = this.also { counter++ }
override val dimension: Int get() = center.size override val dimension: Int get() = center.size
} }
@ -28,8 +35,8 @@ class MultivariateBin(override val center: RealVector, val sizes: RealVector, va
* The histogram is optimized for speed, but have large size in memory * The histogram is optimized for speed, but have large size in memory
*/ */
class FastHistogram( class FastHistogram(
private val lower: RealVector, private val lower: RealPoint,
private val upper: RealVector, private val upper: RealPoint,
private val binNums: IntArray = IntArray(lower.size) { 20 } private val binNums: IntArray = IntArray(lower.size) { 20 }
) : MutableHistogram<Double, MultivariateBin> { ) : MutableHistogram<Double, MultivariateBin> {
@ -37,25 +44,27 @@ class FastHistogram(
// argument checks // argument checks
if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.") if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.")
if (lower.size != binNums.size) error("Dimension mismatch in bin count.") if (lower.size != binNums.size) error("Dimension mismatch in bin count.")
if ((upper - lower).any { it <= 0 }) error("Range for one of axis is not strictly positive") if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one of axis is not strictly positive")
} }
override val dimension: Int get() = lower.size override val dimension: Int get() = lower.size
//TODO optimize binSize performance if needed //TODO optimize binSize performance if needed
private val binSize = (upper - lower).mapIndexed { index, value -> value / binNums[index] }.toVector() private val binSize: RealPoint = ListBuffer((upper - lower).mapIndexed { index, value -> value / binNums[index] }.toList())
private val bins: NDStructure<MultivariateBin> by lazy { private val bins: NDStructure<MultivariateBin> by lazy {
val actualSizes = IntArray(binNums.size) { binNums[it] + 2 } val actualSizes = IntArray(binNums.size) { binNums[it] + 2 }
ndStructure(actualSizes) { indexArray -> ndStructure(actualSizes) { indexArray ->
val center = indexArray.mapIndexed { axis, index -> val center = ListBuffer(
when (index) { indexArray.mapIndexed { axis, index ->
0 -> Double.NEGATIVE_INFINITY when (index) {
actualSizes[axis] - 1 -> Double.POSITIVE_INFINITY 0 -> Double.NEGATIVE_INFINITY
else -> lower[axis] + (index.toDouble() - 0.5) * binSize[axis] actualSizes[axis] - 1 -> Double.POSITIVE_INFINITY
} else -> lower[axis] + (index.toDouble() - 0.5) * binSize[axis]
}.toVector() }
}
)
MultivariateBin(center, binSize) MultivariateBin(center, binSize)
} }
} }
@ -91,13 +100,14 @@ class FastHistogram(
return ndStructure(this.bins.shape) { bins[it].value } return ndStructure(this.bins.shape) { bins[it].value }
} }
/** // /**
* Create a phantom lightweight immutable copy of this histogram // * Create a phantom lightweight immutable copy of this histogram
*/ // */
fun asPhantom(): PhantomHistogram<Double> { // fun asPhantom(): PhantomHistogram<Double> {
val binTemplates = bins.associate { (index, bin) -> BinTemplate<Double>(bin.center, bin.sizes) to index } // val center =
return PhantomHistogram(binTemplates, asND()) // val binTemplates = bins.associate { (index, bin) -> BinTemplate<Double>(bin.center, bin.sizes) to index }
} // return PhantomHistogram(binTemplates, asND())
// }
companion object { companion object {

View File

@ -1,9 +1,7 @@
package scientifik.kmath.linear package scientifik.kmath.linear
import scientifik.kmath.operations.* import scientifik.kmath.operations.*
import scientifik.kmath.structures.ExtendedNDField import scientifik.kmath.structures.*
import scientifik.kmath.structures.GenericNDField
import scientifik.kmath.structures.NDField
/** /**
* The space for linear elements. Supports scalar product alongside with standard linear operations. * The space for linear elements. Supports scalar product alongside with standard linear operations.
@ -169,3 +167,21 @@ class ArrayMatrixSpace<T : Any, F : Field<T>>(
return ArrayMatrixSpace(rows, columns, field, ndFactory) return ArrayMatrixSpace(rows, columns, field, ndFactory)
} }
} }
/**
* Member of [ArrayMatrixSpace] which wraps 2-D array
*/
class ArrayMatrix<T : Any, F : Field<T>> internal constructor(override val context: ArrayMatrixSpace<T, F>, val element: NDElement<T, F>) : Matrix<T, F> {
constructor(context: ArrayMatrixSpace<T, F>, initializer: (Int, Int) -> T) : this(context, context.ndField.produce { list -> initializer(list[0], list[1]) })
override val rows: Int get() = context.rows
override val columns: Int get() = context.columns
override fun get(i: Int, j: Int): T {
return element[i, j]
}
override val self: ArrayMatrix<T, F> get() = this
}

View File

@ -73,23 +73,7 @@ class ArrayVectorSpace<T : Any, F : Field<T>>(
override fun produce(initializer: (Int) -> T): Vector<T, F> = ArrayVector(this, initializer) override fun produce(initializer: (Int) -> T): Vector<T, F> = ArrayVector(this, initializer)
} }
/**
* Member of [ArrayMatrixSpace] which wraps 2-D array
*/
class ArrayMatrix<T : Any, F : Field<T>> internal constructor(override val context: ArrayMatrixSpace<T, F>, val element: NDElement<T, F>) : Matrix<T, F> {
constructor(context: ArrayMatrixSpace<T, F>, initializer: (Int, Int) -> T) : this(context, context.ndField.produce { list -> initializer(list[0], list[1]) })
override val rows: Int get() = context.rows
override val columns: Int get() = context.columns
override fun get(i: Int, j: Int): T {
return element[i, j]
}
override val self: ArrayMatrix<T, F> get() = this
}
class ArrayVector<T : Any, F : Field<T>> internal constructor(override val context: VectorSpace<T, F>, val element: NDElement<T, F>) : Vector<T, F> { class ArrayVector<T : Any, F : Field<T>> internal constructor(override val context: VectorSpace<T, F>, val element: NDElement<T, F>) : Vector<T, F> {

View File

@ -4,12 +4,16 @@ package scientifik.kmath.structures
/** /**
* A generic random access structure for both primitives and objects * A generic random access structure for both primitives and objects
*/ */
interface Buffer<T> : Iterable<T> { interface Buffer<T> {
val size: Int val size: Int
operator fun get(index: Int): T operator fun get(index: Int): T
operator fun iterator(): Iterator<T>
fun asSequence(): Sequence<T> = iterator().asSequence()
/** /**
* A shallow copy of the buffer * A shallow copy of the buffer
*/ */
@ -25,7 +29,20 @@ interface MutableBuffer<T> : Buffer<T> {
override fun copy(): MutableBuffer<T> override fun copy(): MutableBuffer<T>
} }
inline class ListBuffer<T>(private val list: MutableList<T>) : MutableBuffer<T> {
inline class ListBuffer<T>(private val list: List<T>) : Buffer<T> {
override val size: Int
get() = list.size
override fun get(index: Int): T = list[index]
override fun iterator(): Iterator<T> = list.iterator()
override fun copy(): ListBuffer<T> = ListBuffer(ArrayList(list))
}
inline class MutableListBuffer<T>(private val list: MutableList<T>) : MutableBuffer<T> {
override val size: Int override val size: Int
get() = list.size get() = list.size
@ -38,7 +55,7 @@ inline class ListBuffer<T>(private val list: MutableList<T>) : MutableBuffer<T>
override fun iterator(): Iterator<T> = list.iterator() override fun iterator(): Iterator<T> = list.iterator()
override fun copy(): MutableBuffer<T> = ListBuffer(ArrayList(list)) override fun copy(): MutableBuffer<T> = MutableListBuffer(ArrayList(list))
} }
class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> { class ArrayBuffer<T>(private val array: Array<T>) : MutableBuffer<T> {

View File

@ -76,7 +76,7 @@ class DefaultStrides(override val shape: IntArray) : Strides {
override fun offset(index: IntArray): Int { override fun offset(index: IntArray): Int {
return index.mapIndexed { i, value -> return index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) { if (value < 0 || value >= shape[i]) {
throw RuntimeException("Index $value out of shape bounds: (0,${shape[i]})") throw RuntimeException("Index $value out of shape bounds: (0,${this.shape[i]})")
} }
value * strides[i] value * strides[i]
}.sum() }.sum()
@ -169,6 +169,6 @@ fun <T> genericNdStructure(shape: IntArray, initializer: (IntArray) -> T): Mutab
yield(initializer(it)) yield(initializer(it))
} }
} }
val buffer = ListBuffer<T>(sequence.toMutableList()) val buffer = MutableListBuffer<T>(sequence.toMutableList())
return MutableBufferNDStructure(strides, buffer) return MutableBufferNDStructure(strides, buffer)
} }