Feature/tensors performance #497

Closed
margarita0303 wants to merge 91 commits from feature/tensors-performance into feature/tensors-performance
3 changed files with 11 additions and 10 deletions
Showing only changes of commit 27a252b637 - Show all commits

View File

@ -11,16 +11,17 @@ import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.indices import space.kscience.kmath.structures.indices
/** /**
* * A hyper-square (or hyper-cube) real-space domain. It is formed by a [Buffer] of [lower] boundaries
* HyperSquareDomain class. * and a [Buffer] of upper boundaries. Upper should be greater or equals than lower.
*
* @author Alexander Nozik
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public class HyperSquareDomain(public val lower: Buffer<Double>, public val upper: Buffer<Double>) : DoubleDomain { public class HyperSquareDomain(public val lower: Buffer<Double>, public val upper: Buffer<Double>) : DoubleDomain {
init { init {
require(lower.size == upper.size) { require(lower.size == upper.size) {
"Domain borders size mismatch. Lower borders size is ${lower.size}, but upper borders size is ${upper.size}" "Domain borders size mismatch. Lower borders size is ${lower.size}, but upper borders size is ${upper.size}."
}
require(lower.indices.all { lower[it] <= upper[it] }) {
"Domain borders order mismatch. Lower borders must be less or equals than upper borders."
} }
} }

View File

@ -45,7 +45,7 @@ public class DoubleHistogramGroup(
else -> floor((value - lower[axis]) / binSize[axis]).toInt() else -> floor((value - lower[axis]) / binSize[axis]).toInt()
} }
override fun getIndex(point: Buffer<Double>): IntArray = IntArray(dimension) { override fun getIndexOrNull(point: Buffer<Double>): IntArray = IntArray(dimension) {
getIndex(it, point[it]) getIndex(it, point[it])
} }
@ -82,7 +82,7 @@ public class DoubleHistogramGroup(
override val defaultValue: Double get() = 1.0 override val defaultValue: Double get() = 1.0
override fun putValue(point: Point<out Double>, value: Double) { override fun putValue(point: Point<out Double>, value: Double) {
val index = getIndex(point) val index = getIndexOrNull(point)
ndCounter[index].add(value) ndCounter[index].add(value)
} }
} }

View File

@ -33,7 +33,7 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
) : Histogram<T, V, DomainBin<T, V>> { ) : Histogram<T, V, DomainBin<T, V>> {
override fun get(point: Point<T>): DomainBin<T, V>? { override fun get(point: Point<T>): DomainBin<T, V>? {
val index = histogramGroup.getIndex(point) ?: return null val index = histogramGroup.getIndexOrNull(point) ?: return null
return histogramGroup.produceBin(index, values[index]) return histogramGroup.produceBin(index, values[index])
} }
@ -54,9 +54,9 @@ public interface IndexedHistogramGroup<T : Comparable<T>, V : Any> : Group<Index
public val histogramValueAlgebra: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape), public val histogramValueAlgebra: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape),
/** /**
* Resolve index of the bin including given [point] * Resolve index of the bin including given [point]. Return null if point is outside histogram area
*/ */
public fun getIndex(point: Point<T>): IntArray? public fun getIndexOrNull(point: Point<T>): IntArray?
/** /**
* Get a bin domain represented by given index * Get a bin domain represented by given index