Feature/tensors performance #497

Closed
margarita0303 wants to merge 91 commits from feature/tensors-performance into feature/tensors-performance
7 changed files with 49 additions and 39 deletions
Showing only changes of commit 3a3a5bd77f - Show all commits

View File

@ -1,5 +1,5 @@
distributionBase=GRADLE_USER_HOME distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.3.3-bin.zip distributionUrl=https\://services.gradle.org/distributions/gradle-7.4.1-bin.zip
zipStoreBase=GRADLE_USER_HOME zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists zipStorePath=wrapper/dists

View File

@ -38,5 +38,5 @@ public class DoubleDomain1D(
} }
@UnstableKMathAPI @UnstableKMathAPI
public val DoubleDomain1D.center: Double public val Domain1D<Double>.center: Double
get() = (range.endInclusive + range.start) / 2 get() = (range.endInclusive + range.start) / 2

View File

@ -17,7 +17,7 @@ public interface Bin<in T : Any, out V> : Domain<T> {
/** /**
* The value of this bin. * The value of this bin.
*/ */
public val value: V public val binValue: V
} }
public interface Histogram<in T : Any, out V, out B : Bin<T, V>> { public interface Histogram<in T : Any, out V, out B : Bin<T, V>> {

View File

@ -14,13 +14,13 @@ import space.kscience.kmath.structures.Buffer
/** /**
* A univariate bin based on a range * A univariate bin based on a range
* *
* @property value The value of histogram including weighting * @property binValue The value of histogram including weighting
* @property standardDeviation Standard deviation of the bin value. Zero or negative if not applicable * @property standardDeviation Standard deviation of the bin value. Zero or negative if not applicable
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public class Bin1D<T : Comparable<T>, out V>( public class Bin1D<T : Comparable<T>, out V>(
public val domain: Domain1D<T>, public val domain: Domain1D<T>,
override val value: V, override val binValue: V,
) : Bin<T, V>, ClosedRange<T> by domain.range { ) : Bin<T, V>, ClosedRange<T> by domain.range {
override val dimension: Int get() = 1 override val dimension: Int get() = 1

View File

@ -20,7 +20,7 @@ import space.kscience.kmath.operations.invoke
*/ */
public data class DomainBin<in T : Comparable<T>, out V>( public data class DomainBin<in T : Comparable<T>, out V>(
public val domain: Domain<T>, public val domain: Domain<T>,
override val value: V, override val binValue: V,
) : Bin<T, V>, Domain<T> by domain ) : Bin<T, V>, Domain<T> by domain
/** /**

View File

@ -21,7 +21,7 @@ internal class MultivariateHistogramTest {
val histogram = hSpace.produce { val histogram = hSpace.produce {
put(0.55, 0.55) put(0.55, 0.55)
} }
val bin = histogram.bins.find { it.value.toInt() > 0 } ?: fail() val bin = histogram.bins.find { it.binValue.toInt() > 0 } ?: fail()
assertTrue { bin.contains(DoubleVector(0.55, 0.55)) } assertTrue { bin.contains(DoubleVector(0.55, 0.55)) }
assertTrue { bin.contains(DoubleVector(0.6, 0.5)) } assertTrue { bin.contains(DoubleVector(0.6, 0.5)) }
assertFalse { bin.contains(DoubleVector(-0.55, 0.55)) } assertFalse { bin.contains(DoubleVector(-0.55, 0.55)) }
@ -44,7 +44,7 @@ internal class MultivariateHistogramTest {
put(nextDouble(), nextDouble(), nextDouble()) put(nextDouble(), nextDouble(), nextDouble())
} }
} }
assertEquals(n, histogram.bins.sumOf { it.value.toInt() }) assertEquals(n, histogram.bins.sumOf { it.binValue.toInt() })
} }
@Test @Test
@ -77,7 +77,7 @@ internal class MultivariateHistogramTest {
assertTrue { assertTrue {
res.bins.count() >= histogram1.bins.count() res.bins.count() >= histogram1.bins.count()
} }
assertEquals(0.0, res.bins.sumOf { it.value.toDouble() }) assertEquals(0.0, res.bins.sumOf { it.binValue.toDouble() })
} }
} }
} }

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.histogram package space.kscience.kmath.histogram
import space.kscience.kmath.domains.DoubleDomain1D import space.kscience.kmath.domains.DoubleDomain1D
import space.kscience.kmath.domains.center
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.Group import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.operations.ScaleOperations
@ -26,18 +27,24 @@ private fun <B : ClosedRange<Double>> TreeMap<Double, B>.getBin(value: Double):
return null return null
} }
public data class ValueAndError(val value: Double, val error: Double)
public typealias WeightedBin1D = Bin1D<Double, ValueAndError>
@UnstableKMathAPI @UnstableKMathAPI
public class TreeHistogram( public class TreeHistogram(
private val binMap: TreeMap<Double, out Bin1D<Double, Double>>, private val binMap: TreeMap<Double, WeightedBin1D>,
) : Histogram1D<Double, Double> { ) : Histogram1D<Double, ValueAndError> {
override fun get(value: Double): Bin1D<Double, Double>? = binMap.getBin(value) override fun get(value: Double): WeightedBin1D? = binMap.getBin(value)
override val bins: Collection<Bin1D<Double, Double>> get() = binMap.values override val bins: Collection<WeightedBin1D> get() = binMap.values
} }
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
@PublishedApi @PublishedApi
internal class TreeHistogramBuilder(val binFactory: (Double) -> DoubleDomain1D) : Histogram1DBuilder<Double, Double> { internal class TreeHistogramBuilder(val binFactory: (Double) -> DoubleDomain1D) : Histogram1DBuilder<Double, Double> {
override val defaultValue: Double get() = 1.0
internal class BinCounter(val domain: DoubleDomain1D, val counter: Counter<Double> = Counter.double()) : internal class BinCounter(val domain: DoubleDomain1D, val counter: Counter<Double> = Counter.double()) :
ClosedRange<Double> by domain.range ClosedRange<Double> by domain.range
@ -46,7 +53,7 @@ internal class TreeHistogramBuilder(val binFactory: (Double) -> DoubleDomain1D)
fun get(value: Double): BinCounter? = bins.getBin(value) fun get(value: Double): BinCounter? = bins.getBin(value)
fun createBin(value: Double): BinCounter { fun createBin(value: Double): BinCounter {
val binDefinition = binFactory(value) val binDefinition: DoubleDomain1D = binFactory(value)
val newBin = BinCounter(binDefinition) val newBin = BinCounter(binDefinition)
synchronized(this) { synchronized(this) {
bins[binDefinition.center] = newBin bins[binDefinition.center] = newBin
@ -69,9 +76,9 @@ internal class TreeHistogramBuilder(val binFactory: (Double) -> DoubleDomain1D)
} }
fun build(): TreeHistogram { fun build(): TreeHistogram {
val map = bins.mapValuesTo(TreeMap<Double, Bin1D<Double,Double>>()) { (_, binCounter) -> val map = bins.mapValuesTo(TreeMap<Double, WeightedBin1D>()) { (_, binCounter) ->
val count = binCounter.counter.value val count: Double = binCounter.counter.value
Bin1D(binCounter.domain, count, sqrt(count)) WeightedBin1D(binCounter.domain, ValueAndError(count, sqrt(count)))
} }
return TreeHistogram(map) return TreeHistogram(map)
} }
@ -83,26 +90,27 @@ internal class TreeHistogramBuilder(val binFactory: (Double) -> DoubleDomain1D)
@UnstableKMathAPI @UnstableKMathAPI
public class TreeHistogramSpace( public class TreeHistogramSpace(
@PublishedApi internal val binFactory: (Double) -> DoubleDomain1D, @PublishedApi internal val binFactory: (Double) -> DoubleDomain1D,
) : Group<Histogram1D<Double,Double>>, ScaleOperations<Histogram1D<Double,Double>> { ) : Group<TreeHistogram>, ScaleOperations<TreeHistogram> {
public inline fun fill(block: Histogram1DBuilder<Double,Double>.() -> Unit): Histogram1D<Double,Double> = public inline fun fill(block: Histogram1DBuilder<Double, Double>.() -> Unit): TreeHistogram =
TreeHistogramBuilder(binFactory).apply(block).build() TreeHistogramBuilder(binFactory).apply(block).build()
override fun add( override fun add(
left: Histogram1D<Double,Double>, left: TreeHistogram,
right: Histogram1D<Double,Double>, right: TreeHistogram,
): Histogram1D<Double,Double> { ): TreeHistogram {
// require(a.context == this) { "Histogram $a does not belong to this context" } // require(a.context == this) { "Histogram $a does not belong to this context" }
// require(b.context == this) { "Histogram $b does not belong to this context" } // require(b.context == this) { "Histogram $b does not belong to this context" }
val bins = TreeMap<Double, Bin1D<Double,Double>>().apply { val bins = TreeMap<Double, WeightedBin1D>().apply {
(left.bins.map { it.domain } union right.bins.map { it.domain }).forEach { def -> (left.bins.map { it.domain } union right.bins.map { it.domain }).forEach { def ->
put( put(
def.center, def.center,
Bin1D( WeightedBin1D(
def, def,
value = (left[def.center]?.value ?: 0.0) + (right[def.center]?.value ?: 0.0), ValueAndError(
standardDeviation = (left[def.center]?.standardDeviation (left[def.center]?.binValue?.value ?: 0.0) + (right[def.center]?.binValue?.value ?: 0.0),
?: 0.0) + (right[def.center]?.standardDeviation ?: 0.0) (left[def.center]?.binValue?.error ?: 0.0) + (right[def.center]?.binValue?.error ?: 0.0)
)
) )
) )
} }
@ -110,15 +118,17 @@ public class TreeHistogramSpace(
return TreeHistogram(bins) return TreeHistogram(bins)
} }
override fun scale(a: Histogram1D<Double,Double>, value: Double): Histogram1D<Double,Double> { override fun scale(a: TreeHistogram, value: Double): TreeHistogram {
val bins = TreeMap<Double, Bin1D<Double,Double>>().apply { val bins = TreeMap<Double, WeightedBin1D>().apply {
a.bins.forEach { bin -> a.bins.forEach { bin ->
put( put(
bin.domain.center, bin.domain.center,
Bin1D( WeightedBin1D(
bin.domain, bin.domain,
value = bin.value * value, ValueAndError(
standardDeviation = abs(bin.standardDeviation * value) bin.binValue.value * value,
abs(bin.binValue.error * value)
)
) )
) )
} }
@ -127,19 +137,19 @@ public class TreeHistogramSpace(
return TreeHistogram(bins) return TreeHistogram(bins)
} }
override fun Histogram1D<Double,Double>.unaryMinus(): Histogram1D<Double,Double> = this * (-1) override fun TreeHistogram.unaryMinus(): TreeHistogram = this * (-1)
override val zero: Histogram1D<Double,Double> by lazy { fill { } } override val zero: TreeHistogram by lazy { fill { } }
public companion object { public companion object {
/** /**
* Build and fill a [DoubleHistogram1D]. Returns a read-only histogram. * Build and fill a [TreeHistogram]. Returns a read-only histogram.
*/ */
public inline fun uniform( public inline fun uniform(
binSize: Double, binSize: Double,
start: Double = 0.0, start: Double = 0.0,
builder: Histogram1DBuilder<Double, Double>.() -> Unit, builder: Histogram1DBuilder<Double, Double>.() -> Unit,
): Histogram1D<Double,Double> = uniform(binSize, start).fill(builder) ): TreeHistogram = uniform(binSize, start).fill(builder)
/** /**
* Build and fill a histogram with custom borders. Returns a read-only histogram. * Build and fill a histogram with custom borders. Returns a read-only histogram.
@ -147,7 +157,7 @@ public class TreeHistogramSpace(
public inline fun custom( public inline fun custom(
borders: DoubleArray, borders: DoubleArray,
builder: Histogram1DBuilder<Double, Double>.() -> Unit, builder: Histogram1DBuilder<Double, Double>.() -> Unit,
): Histogram1D<Double,Double> = custom(borders).fill(builder) ): TreeHistogram = custom(borders).fill(builder)
/** /**