Feature/tensors performance #497

Closed
margarita0303 wants to merge 91 commits from feature/tensors-performance into feature/tensors-performance
2 changed files with 39 additions and 15 deletions
Showing only changes of commit 3a2faa7da4 - Show all commits

View File

@ -37,37 +37,53 @@ public class UniformHistogram1D<V : Any>(
} }
} }
/**
* An algebra for uniform histograms in 1D real space
*/
public class UniformHistogram1DGroup<V : Any, A>( public class UniformHistogram1DGroup<V : Any, A>(
public val valueAlgebra: A, public val valueAlgebra: A,
public val binSize: Double, public val binSize: Double,
public val startPoint: Double = 0.0, public val startPoint: Double = 0.0,
) : Group<UniformHistogram1D<V>>, ScaleOperations<UniformHistogram1D<V>> where A : Ring<V>, A : ScaleOperations<V> { ) : Group<Histogram1D<Double, V>>, ScaleOperations<Histogram1D<Double, V>> where A : Ring<V>, A : ScaleOperations<V> {
override val zero: UniformHistogram1D<V> by lazy { UniformHistogram1D(this, emptyMap()) } override val zero: UniformHistogram1D<V> by lazy { UniformHistogram1D(this, emptyMap()) }
public fun getIndex(at: Double): Int = floor((at - startPoint) / binSize).toInt() /**
* Get index of a bin
*/
@PublishedApi
internal fun getIndex(at: Double): Int = floor((at - startPoint) / binSize).toInt()
override fun add(left: UniformHistogram1D<V>, right: UniformHistogram1D<V>): UniformHistogram1D<V> = valueAlgebra { override fun add(
require(left.group == this@UniformHistogram1DGroup) left: Histogram1D<Double, V>,
require(right.group == this@UniformHistogram1DGroup) right: Histogram1D<Double, V>,
val keys = left.values.keys + right.values.keys ): UniformHistogram1D<V> = valueAlgebra {
val leftUniform = produceFrom(left)
val rightUniform = produceFrom(right)
val keys = leftUniform.values.keys + rightUniform.values.keys
UniformHistogram1D( UniformHistogram1D(
this@UniformHistogram1DGroup, this@UniformHistogram1DGroup,
keys.associateWith { (left.values[it] ?: valueAlgebra.zero) + (right.values[it] ?: valueAlgebra.zero) } keys.associateWith {
(leftUniform.values[it] ?: valueAlgebra.zero) + (rightUniform.values[it] ?: valueAlgebra.zero)
}
) )
} }
override fun UniformHistogram1D<V>.unaryMinus(): UniformHistogram1D<V> = valueAlgebra { override fun Histogram1D<Double, V>.unaryMinus(): UniformHistogram1D<V> = valueAlgebra {
UniformHistogram1D(this@UniformHistogram1DGroup, values.mapValues { -it.value }) UniformHistogram1D(this@UniformHistogram1DGroup, produceFrom(this@unaryMinus).values.mapValues { -it.value })
} }
override fun scale( override fun scale(
a: UniformHistogram1D<V>, a: Histogram1D<Double, V>,
value: Double, value: Double,
): UniformHistogram1D<V> = UniformHistogram1D( ): UniformHistogram1D<V> = UniformHistogram1D(
this@UniformHistogram1DGroup, this@UniformHistogram1DGroup,
a.values.mapValues { valueAlgebra.scale(it.value, value) } produceFrom(a).values.mapValues { valueAlgebra.scale(it.value, value) }
) )
/**
*
*/
public inline fun produce(block: Histogram1DBuilder<Double, V>.() -> Unit): UniformHistogram1D<V> { public inline fun produce(block: Histogram1DBuilder<Double, V>.() -> Unit): UniformHistogram1D<V> {
val map = HashMap<Int, V>() val map = HashMap<Int, V>()
val builder = object : Histogram1DBuilder<Double, V> { val builder = object : Histogram1DBuilder<Double, V> {
@ -87,7 +103,7 @@ public class UniformHistogram1DGroup<V : Any, A>(
* is increased by one. If not, all bins including values from this bin are increased by fraction * is increased by one. If not, all bins including values from this bin are increased by fraction
* (conserving the norming). * (conserving the norming).
*/ */
@UnstableKMathAPI @OptIn(UnstableKMathAPI::class)
public fun produceFrom(histogram: Histogram1D<Double, V>): UniformHistogram1D<V> = public fun produceFrom(histogram: Histogram1D<Double, V>): UniformHistogram1D<V> =
if ((histogram as? UniformHistogram1D)?.group == this) histogram if ((histogram as? UniformHistogram1D)?.group == this) histogram
else { else {

View File

@ -34,9 +34,17 @@ internal class UniformHistogram1DTest {
} }
@Test @Test
fun rebin() = runTest { fun rebinDown() = runTest {
val h1 = Histogram.uniform1D(DoubleField, 0.1).produce(generator.nextDoubleBuffer(10000)) val h1 = Histogram.uniform1D(DoubleField, 0.01).produce(generator.nextDoubleBuffer(10000))
val h2 = Histogram.uniform1D(DoubleField,0.3).produceFrom(h1) val h2 = Histogram.uniform1D(DoubleField,0.03).produceFrom(h1)
assertEquals(10000, h2.bins.sumOf { it.binValue }.toInt())
}
@Test
fun rebinUp() = runTest {
val h1 = Histogram.uniform1D(DoubleField, 0.03).produce(generator.nextDoubleBuffer(10000))
val h2 = Histogram.uniform1D(DoubleField,0.01).produceFrom(h1)
assertEquals(10000, h2.bins.sumOf { it.binValue }.toInt()) assertEquals(10000, h2.bins.sumOf { it.binValue }.toInt())
} }