[WIP] Another histogram refactor

This commit is contained in:
Alexander Nozik 2022-04-05 23:23:29 +03:00
parent ce82d2d076
commit 73f72f12bc
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357
16 changed files with 258 additions and 67 deletions

View File

@ -48,6 +48,7 @@
- Operations -> Ops - Operations -> Ops
- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes. - Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
- Tensor algebra takes read-only structures as input and inherits AlgebraND - Tensor algebra takes read-only structures as input and inherits AlgebraND
- `UnivariateDistribution` renamed to `Distribution1D`
### Deprecated ### Deprecated
- Specialized `DoubleBufferAlgebra` - Specialized `DoubleBufferAlgebra`

View File

@ -35,6 +35,23 @@ public class DoubleDomain1D(
} }
override fun volume(): Double = range.endInclusive - range.start override fun volume(): Double = range.endInclusive - range.start
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as DoubleDomain1D
if (doubleRange != other.doubleRange) return false
return true
}
override fun hashCode(): Int = doubleRange.hashCode()
override fun toString(): String = doubleRange.toString()
} }
@UnstableKMathAPI @UnstableKMathAPI

View File

@ -3,43 +3,71 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/ */
package space.kscience.kmath.misc package space.kscience.kmath.misc
import kotlin.comparisons.*
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.VirtualBuffer
/** /**
* Return a new list filled with buffer indices. Indice order is defined by sorting associated buffer value. * Return a new array filled with buffer indices. Indices order is defined by sorting associated buffer value.
* This feature allows to sort buffer values without reordering its content. * This feature allows sorting buffer values without reordering its content.
* *
* @return List of buffer indices, sorted by associated value. * @return Buffer indices, sorted by associated value.
*/ */
@PerformancePitfall
@UnstableKMathAPI @UnstableKMathAPI
public fun <V: Comparable<V>> Buffer<V>.permSort() : IntArray = _permSortWith(compareBy<Int> { get(it) }) public fun <V : Comparable<V>> Buffer<V>.indicesSorted(): IntArray = permSortIndicesWith(compareBy { get(it) })
/**
* Create a zero-copy virtual buffer that contains the same elements but in ascending order
*/
@OptIn(UnstableKMathAPI::class)
public fun <V : Comparable<V>> Buffer<V>.sorted(): Buffer<V> {
val permutations = indicesSorted()
return VirtualBuffer(size) { this[permutations[it]] }
}
@PerformancePitfall
@UnstableKMathAPI @UnstableKMathAPI
public fun <V: Comparable<V>> Buffer<V>.permSortDescending() : IntArray = _permSortWith(compareByDescending<Int> { get(it) }) public fun <V : Comparable<V>> Buffer<V>.indicesSortedDescending(): IntArray =
permSortIndicesWith(compareByDescending { get(it) })
/**
* Create a zero-copy virtual buffer that contains the same elements but in descending order
*/
@OptIn(UnstableKMathAPI::class)
public fun <V : Comparable<V>> Buffer<V>.sortedDescending(): Buffer<V> {
val permutations = indicesSortedDescending()
return VirtualBuffer(size) { this[permutations[it]] }
}
@PerformancePitfall
@UnstableKMathAPI @UnstableKMathAPI
public fun <V, C: Comparable<C>> Buffer<V>.permSortBy(selector: (V) -> C) : IntArray = _permSortWith(compareBy<Int> { selector(get(it)) }) public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedBy(selector: (V) -> C): IntArray =
permSortIndicesWith(compareBy { selector(get(it)) })
@OptIn(UnstableKMathAPI::class)
public fun <V, C : Comparable<C>> Buffer<V>.sortedBy(selector: (V) -> C): Buffer<V> {
val permutations = indicesSortedBy(selector)
return VirtualBuffer(size) { this[permutations[it]] }
}
@PerformancePitfall
@UnstableKMathAPI @UnstableKMathAPI
public fun <V, C: Comparable<C>> Buffer<V>.permSortByDescending(selector: (V) -> C) : IntArray = _permSortWith(compareByDescending<Int> { selector(get(it)) }) public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedByDescending(selector: (V) -> C): IntArray =
permSortIndicesWith(compareByDescending { selector(get(it)) })
@OptIn(UnstableKMathAPI::class)
public fun <V, C : Comparable<C>> Buffer<V>.sortedByDescending(selector: (V) -> C): Buffer<V> {
val permutations = indicesSortedByDescending(selector)
return VirtualBuffer(size) { this[permutations[it]] }
}
@PerformancePitfall
@UnstableKMathAPI @UnstableKMathAPI
public fun <V> Buffer<V>.permSortWith(comparator : Comparator<V>) : IntArray = _permSortWith { i1, i2 -> comparator.compare(get(i1), get(i2)) } public fun <V> Buffer<V>.indicesSortedWith(comparator: Comparator<V>): IntArray =
permSortIndicesWith { i1, i2 -> comparator.compare(get(i1), get(i2)) }
@PerformancePitfall private fun <V> Buffer<V>.permSortIndicesWith(comparator: Comparator<Int>): IntArray {
@UnstableKMathAPI if (size < 2) return IntArray(size) { 0 }
private fun <V> Buffer<V>._permSortWith(comparator : Comparator<Int>) : IntArray {
if (size < 2) return IntArray(size)
/* TODO: optimisation : keep a constant big array of indices (Ex: from 0 to 4096), then create indice /* TODO: optimisation : keep a constant big array of indices (Ex: from 0 to 4096), then create indices
* arrays more efficiently by copying subpart of cached one. For bigger needs, we could copy entire * arrays more efficiently by copying subpart of cached one. For bigger needs, we could copy entire
* cached array, then fill remaining indices manually. Not done for now, because: * cached array, then fill remaining indices manually. Not done for now, because:
* 1. doing it right would require some statistics about common used buffer sizes. * 1. doing it right would require some statistics about common used buffer sizes.
@ -53,3 +81,12 @@ private fun <V> Buffer<V>._permSortWith(comparator : Comparator<Int>) : IntArray
*/ */
return packedIndices.sortedWith(comparator).toIntArray() return packedIndices.sortedWith(comparator).toIntArray()
} }
/**
* Checks that the [Buffer] is sorted (ascending) and throws [IllegalArgumentException] if it is not.
*/
public fun <T : Comparable<T>> Buffer<T>.requireSorted() {
for (i in 0..(size - 2)) {
require(get(i + 1) >= get(i)) { "The buffer is not sorted at index $i" }
}
}

View File

@ -6,14 +6,13 @@
package space.kscience.kmath.misc package space.kscience.kmath.misc
import space.kscience.kmath.misc.PermSortTest.Platform.* import space.kscience.kmath.misc.PermSortTest.Platform.*
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
import space.kscience.kmath.structures.IntBuffer import space.kscience.kmath.structures.IntBuffer
import space.kscience.kmath.structures.asBuffer import space.kscience.kmath.structures.asBuffer
import kotlin.random.Random
import kotlin.test.Test
import kotlin.test.assertContentEquals import kotlin.test.assertContentEquals
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class PermSortTest { class PermSortTest {
@ -29,9 +28,9 @@ class PermSortTest {
@Test @Test
fun testOnEmptyBuffer() { fun testOnEmptyBuffer() {
val emptyBuffer = IntBuffer(0) {it} val emptyBuffer = IntBuffer(0) {it}
var permutations = emptyBuffer.permSort() var permutations = emptyBuffer.indicesSorted()
assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result") assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result")
permutations = emptyBuffer.permSortDescending() permutations = emptyBuffer.indicesSortedDescending()
assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result") assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result")
} }
@ -47,25 +46,25 @@ class PermSortTest {
@Test @Test
fun testPermSortBy() { fun testPermSortBy() {
val permutations = platforms.permSortBy { it.name } val permutations = platforms.indicesSortedBy { it.name }
val expected = listOf(ANDROID, JS, JVM, NATIVE, WASM) val expected = listOf(ANDROID, JS, JVM, NATIVE, WASM)
assertContentEquals(expected, permutations.map { platforms[it] }, "Ascending PermSort by name") assertContentEquals(expected, permutations.map { platforms[it] }, "Ascending PermSort by name")
} }
@Test @Test
fun testPermSortByDescending() { fun testPermSortByDescending() {
val permutations = platforms.permSortByDescending { it.name } val permutations = platforms.indicesSortedByDescending { it.name }
val expected = listOf(WASM, NATIVE, JVM, JS, ANDROID) val expected = listOf(WASM, NATIVE, JVM, JS, ANDROID)
assertContentEquals(expected, permutations.map { platforms[it] }, "Descending PermSort by name") assertContentEquals(expected, permutations.map { platforms[it] }, "Descending PermSort by name")
} }
@Test @Test
fun testPermSortWith() { fun testPermSortWith() {
var permutations = platforms.permSortWith { p1, p2 -> p1.name.length.compareTo(p2.name.length) } var permutations = platforms.indicesSortedWith { p1, p2 -> p1.name.length.compareTo(p2.name.length) }
val expected = listOf(JS, JVM, WASM, NATIVE, ANDROID) val expected = listOf(JS, JVM, WASM, NATIVE, ANDROID)
assertContentEquals(expected, permutations.map { platforms[it] }, "PermSort using custom ascending comparator") assertContentEquals(expected, permutations.map { platforms[it] }, "PermSort using custom ascending comparator")
permutations = platforms.permSortWith(compareByDescending { it.name.length }) permutations = platforms.indicesSortedWith(compareByDescending { it.name.length })
assertContentEquals(expected.reversed(), permutations.map { platforms[it] }, "PermSort using custom descending comparator") assertContentEquals(expected.reversed(), permutations.map { platforms[it] }, "PermSort using custom descending comparator")
} }
@ -75,7 +74,7 @@ class PermSortTest {
println("Test randomization seed: $seed") println("Test randomization seed: $seed")
val buffer = Random(seed).buffer(bufferSize) val buffer = Random(seed).buffer(bufferSize)
val indices = buffer.permSort() val indices = buffer.indicesSorted()
assertEquals(bufferSize, indices.size) assertEquals(bufferSize, indices.size)
// Ensure no doublon is present in indices // Ensure no doublon is present in indices
@ -87,7 +86,7 @@ class PermSortTest {
assertTrue(current <= next, "Permutation indices not properly sorted") assertTrue(current <= next, "Permutation indices not properly sorted")
} }
val descIndices = buffer.permSortDescending() val descIndices = buffer.indicesSortedDescending()
assertEquals(bufferSize, descIndices.size) assertEquals(bufferSize, descIndices.size)
// Ensure no doublon is present in indices // Ensure no doublon is present in indices
assertEquals(descIndices.toSet().size, descIndices.size) assertEquals(descIndices.toSet().size, descIndices.size)

View File

@ -21,19 +21,22 @@ readme {
feature( feature(
id = "DoubleVector", id = "DoubleVector",
description = "Numpy-like operations for Buffers/Points",
ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleVector.kt" ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleVector.kt"
) ){
"Numpy-like operations for Buffers/Points"
}
feature( feature(
id = "DoubleMatrix", id = "DoubleMatrix",
description = "Numpy-like operations for 2d real structures",
ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleMatrix.kt" ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleMatrix.kt"
) ){
"Numpy-like operations for 2d real structures"
}
feature( feature(
id = "grids", id = "grids",
description = "Uniform grid generators",
ref = "src/commonMain/kotlin/space/kscience/kmath/structures/grids.kt" ref = "src/commonMain/kotlin/space/kscience/kmath/structures/grids.kt"
) ){
"Uniform grid generators"
}
} }

View File

@ -17,6 +17,8 @@ kotlin.sourceSets {
commonTest { commonTest {
dependencies { dependencies {
implementation(project(":kmath-for-real")) implementation(project(":kmath-for-real"))
implementation(projects.kmath.kmathStat)
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.6.0")
} }
} }
} }

View File

@ -16,11 +16,11 @@ import kotlin.math.floor
/** /**
* Multivariate histogram space for hyper-square real-field bins. * Multivariate histogram space for hyper-square real-field bins.
*/ */
public class DoubleHistogramSpace( public class DoubleHistogramGroup(
private val lower: Buffer<Double>, private val lower: Buffer<Double>,
private val upper: Buffer<Double>, private val upper: Buffer<Double>,
private val binNums: IntArray = IntArray(lower.size) { 20 }, private val binNums: IntArray = IntArray(lower.size) { 20 },
) : IndexedHistogramSpace<Double, Double> { ) : IndexedHistogramGroup<Double, Double> {
init { init {
// argument checks // argument checks
@ -105,7 +105,7 @@ public class DoubleHistogramSpace(
*/ */
public fun fromRanges( public fun fromRanges(
vararg ranges: ClosedFloatingPointRange<Double>, vararg ranges: ClosedFloatingPointRange<Double>,
): DoubleHistogramSpace = DoubleHistogramSpace( ): DoubleHistogramGroup = DoubleHistogramGroup(
ranges.map(ClosedFloatingPointRange<Double>::start).asBuffer(), ranges.map(ClosedFloatingPointRange<Double>::start).asBuffer(),
ranges.map(ClosedFloatingPointRange<Double>::endInclusive).asBuffer() ranges.map(ClosedFloatingPointRange<Double>::endInclusive).asBuffer()
) )
@ -121,7 +121,7 @@ public class DoubleHistogramSpace(
*/ */
public fun fromRanges( public fun fromRanges(
vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>, vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>,
): DoubleHistogramSpace = DoubleHistogramSpace( ): DoubleHistogramGroup = DoubleHistogramGroup(
ListBuffer( ListBuffer(
ranges ranges
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first) .map(Pair<ClosedFloatingPointRange<Double>, Int>::first)

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.histogram package space.kscience.kmath.histogram
import space.kscience.kmath.domains.Domain1D import space.kscience.kmath.domains.Domain1D
import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.asSequence import space.kscience.kmath.operations.asSequence
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
@ -18,7 +19,7 @@ import space.kscience.kmath.structures.Buffer
* @property standardDeviation Standard deviation of the bin value. Zero or negative if not applicable * @property standardDeviation Standard deviation of the bin value. Zero or negative if not applicable
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public class Bin1D<T : Comparable<T>, out V>( public data class Bin1D<T : Comparable<T>, out V>(
public val domain: Domain1D<T>, public val domain: Domain1D<T>,
override val binValue: V, override val binValue: V,
) : Bin<T, V>, ClosedRange<T> by domain.range { ) : Bin<T, V>, ClosedRange<T> by domain.range {
@ -35,12 +36,15 @@ public interface Histogram1D<T : Comparable<T>, V> : Histogram<T, V, Bin1D<T, V>
override operator fun get(point: Buffer<T>): Bin1D<T, V>? = get(point[0]) override operator fun get(point: Buffer<T>): Bin1D<T, V>? = get(point[0])
} }
@UnstableKMathAPI
public interface Histogram1DBuilder<in T : Any, V : Any> : HistogramBuilder<T, V> { public interface Histogram1DBuilder<in T : Any, V : Any> : HistogramBuilder<T, V> {
/** /**
* Thread safe put operation * Thread safe put operation
*/ */
public fun putValue(at: T, value: V = defaultValue) public fun putValue(at: T, value: V = defaultValue)
override fun putValue(point: Point<out T>, value: V) {
putValue(point[0], value)
}
} }
@UnstableKMathAPI @UnstableKMathAPI
@ -52,5 +56,5 @@ public fun Histogram1DBuilder<Double, *>.fill(array: DoubleArray): Unit =
array.forEach(this::putValue) array.forEach(this::putValue)
@UnstableKMathAPI @UnstableKMathAPI
public fun Histogram1DBuilder<Double, *>.fill(buffer: Buffer<Double>): Unit = public fun <T: Any> Histogram1DBuilder<T, *>.fill(buffer: Buffer<T>): Unit =
buffer.asSequence().forEach(this::putValue) buffer.asSequence().forEach(this::putValue)

View File

@ -28,7 +28,7 @@ public data class DomainBin<in T : Comparable<T>, out V>(
* @param V the type of bin value * @param V the type of bin value
*/ */
public class IndexedHistogram<T : Comparable<T>, V : Any>( public class IndexedHistogram<T : Comparable<T>, V : Any>(
public val histogramSpace: IndexedHistogramSpace<T, V>, public val histogramSpace: IndexedHistogramGroup<T, V>,
public val values: StructureND<V>, public val values: StructureND<V>,
) : Histogram<T, V, DomainBin<T, V>> { ) : Histogram<T, V, DomainBin<T, V>> {
@ -48,8 +48,8 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
/** /**
* A space for producing histograms with values in a NDStructure * A space for producing histograms with values in a NDStructure
*/ */
public interface IndexedHistogramSpace<T : Comparable<T>, V : Any> public interface IndexedHistogramGroup<T : Comparable<T>, V : Any> : Group<IndexedHistogram<T, V>>,
: Group<IndexedHistogram<T, V>>, ScaleOperations<IndexedHistogram<T, V>> { ScaleOperations<IndexedHistogram<T, V>> {
public val shape: Shape public val shape: Shape
public val histogramValueAlgebra: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape), public val histogramValueAlgebra: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape),

View File

@ -5,8 +5,92 @@
package space.kscience.kmath.histogram package space.kscience.kmath.histogram
//class UniformHistogram1D( import space.kscience.kmath.domains.DoubleDomain1D
// public val borders: Buffer<Double>, import space.kscience.kmath.misc.UnstableKMathAPI
// public val values: Buffer<Long>, import space.kscience.kmath.operations.Group
//) : Histogram1D<Double, Long> { import space.kscience.kmath.operations.Ring
//} import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer
import kotlin.math.floor
@OptIn(UnstableKMathAPI::class)
public class UniformHistogram1D<V : Any>(
public val group: UniformHistogram1DGroup<V, *>,
public val values: Map<Int, V>,
) : Histogram1D<Double, V> {
private val startPoint get() = group.startPoint
private val binSize get() = group.binSize
private fun produceBin(index: Int, value: V): Bin1D<Double, V> {
val domain = DoubleDomain1D((startPoint + index * binSize)..(startPoint + (index + 1) * binSize))
return Bin1D(domain, value)
}
override val bins: Iterable<Bin1D<Double, V>> get() = values.map { produceBin(it.key, it.value) }
override fun get(value: Double): Bin1D<Double, V>? {
val index: Int = group.getIndex(value)
val v = values[index]
return v?.let { produceBin(index, it) }
}
}
public class UniformHistogram1DGroup<V : Any, A>(
public val valueAlgebra: A,
public val binSize: Double,
public val startPoint: Double = 0.0,
) : Group<UniformHistogram1D<V>>, ScaleOperations<UniformHistogram1D<V>> where A : Ring<V>, A : ScaleOperations<V> {
override val zero: UniformHistogram1D<V> by lazy { UniformHistogram1D(this, emptyMap()) }
public fun getIndex(at: Double): Int = floor((at - startPoint) / binSize).toInt()
override fun add(left: UniformHistogram1D<V>, right: UniformHistogram1D<V>): UniformHistogram1D<V> = valueAlgebra {
require(left.group == this@UniformHistogram1DGroup)
require(right.group == this@UniformHistogram1DGroup)
val keys = left.values.keys + right.values.keys
UniformHistogram1D(
this@UniformHistogram1DGroup,
keys.associateWith { (left.values[it] ?: valueAlgebra.zero) + (right.values[it] ?: valueAlgebra.zero) }
)
}
override fun UniformHistogram1D<V>.unaryMinus(): UniformHistogram1D<V> = valueAlgebra {
UniformHistogram1D(this@UniformHistogram1DGroup, values.mapValues { -it.value })
}
override fun scale(
a: UniformHistogram1D<V>,
value: Double,
): UniformHistogram1D<V> = UniformHistogram1D(
this@UniformHistogram1DGroup,
a.values.mapValues { valueAlgebra.scale(it.value, value) }
)
public inline fun produce(block: Histogram1DBuilder<Double, V>.() -> Unit): UniformHistogram1D<V> {
val map = HashMap<Int, V>()
val builder = object : Histogram1DBuilder<Double, V> {
override val defaultValue: V get() = valueAlgebra.zero
override fun putValue(at: Double, value: V) {
val index = getIndex(at)
map[index] = with(valueAlgebra) { (map[index] ?: zero) + one }
}
}
builder.block()
return UniformHistogram1D(this, map)
}
}
public fun <V : Any, A> Histogram.Companion.uniform1D(
algebra: A,
binSize: Double,
startPoint: Double = 0.0,
): UniformHistogram1DGroup<V, A> where A : Ring<V>, A : ScaleOperations<V> =
UniformHistogram1DGroup(algebra, binSize, startPoint)
@UnstableKMathAPI
public fun <V : Any> UniformHistogram1DGroup<V, *>.produce(
buffer: Buffer<Double>,
): UniformHistogram1D<V> = produce { fill(buffer) }

View File

@ -14,7 +14,7 @@ import kotlin.test.*
internal class MultivariateHistogramTest { internal class MultivariateHistogramTest {
@Test @Test
fun testSinglePutHistogram() { fun testSinglePutHistogram() {
val hSpace = DoubleHistogramSpace.fromRanges( val hSpace = DoubleHistogramGroup.fromRanges(
(-1.0..1.0), (-1.0..1.0),
(-1.0..1.0) (-1.0..1.0)
) )
@ -29,7 +29,7 @@ internal class MultivariateHistogramTest {
@Test @Test
fun testSequentialPut() { fun testSequentialPut() {
val hSpace = DoubleHistogramSpace.fromRanges( val hSpace = DoubleHistogramGroup.fromRanges(
(-1.0..1.0), (-1.0..1.0),
(-1.0..1.0), (-1.0..1.0),
(-1.0..1.0) (-1.0..1.0)
@ -49,7 +49,7 @@ internal class MultivariateHistogramTest {
@Test @Test
fun testHistogramAlgebra() { fun testHistogramAlgebra() {
DoubleHistogramSpace.fromRanges( DoubleHistogramGroup.fromRanges(
(-1.0..1.0), (-1.0..1.0),
(-1.0..1.0), (-1.0..1.0),
(-1.0..1.0) (-1.0..1.0)

View File

@ -0,0 +1,34 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.histogram
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.runTest
import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.stat.RandomGenerator
import space.kscience.kmath.stat.nextBuffer
import kotlin.test.Test
import kotlin.test.assertEquals
@OptIn(ExperimentalCoroutinesApi::class, UnstableKMathAPI::class)
internal class UniformHistogram1DTest {
@Test
fun normal() = runTest {
val generator = RandomGenerator.default(123)
val distribution = NormalDistribution(0.0, 1.0)
with(Histogram.uniform1D(DoubleField, 0.1)) {
val h1 = produce(distribution.nextBuffer(generator, 10000))
val h2 = produce(distribution.nextBuffer(generator, 50000))
val h3 = h1 + h2
assertEquals(60000, h3.bins.sumOf { it.binValue }.toInt())
}
}
}

View File

@ -27,7 +27,7 @@ public interface Distribution<T : Any> : Sampler<T> {
public companion object public companion object
} }
public interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> { public interface Distribution1D<T : Comparable<T>> : Distribution<T> {
/** /**
* Cumulative distribution for ordered parameter (CDF) * Cumulative distribution for ordered parameter (CDF)
*/ */
@ -37,7 +37,7 @@ public interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
/** /**
* Compute probability integral in an interval * Compute probability integral in an interval
*/ */
public fun <T : Comparable<T>> UnivariateDistribution<T>.integral(from: T, to: T): Double { public fun <T : Comparable<T>> Distribution1D<T>.integral(from: T, to: T): Double {
require(to > from) require(to > from)
return cumulative(to) - cumulative(from) return cumulative(to) - cumulative(from)
} }

View File

@ -14,14 +14,9 @@ import space.kscience.kmath.stat.RandomGenerator
import kotlin.math.* import kotlin.math.*
/** /**
* Implements [UnivariateDistribution] for the normal (gaussian) distribution. * Implements [Distribution1D] for the normal (gaussian) distribution.
*/ */
public class NormalDistribution(public val sampler: GaussianSampler) : UnivariateDistribution<Double> { public class NormalDistribution(public val sampler: GaussianSampler) : Distribution1D<Double> {
public constructor(
mean: Double,
standardDeviation: Double,
normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler,
) : this(GaussianSampler(mean, standardDeviation, normalized))
override fun probability(arg: Double): Double { override fun probability(arg: Double): Double {
val x1 = (arg - sampler.mean) / sampler.standardDeviation val x1 = (arg - sampler.mean) / sampler.standardDeviation
@ -43,3 +38,9 @@ public class NormalDistribution(public val sampler: GaussianSampler) : Univariat
private val SQRT2 = sqrt(2.0) private val SQRT2 = sqrt(2.0)
} }
} }
public fun NormalDistribution(
mean: Double,
standardDeviation: Double,
normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler,
): NormalDistribution = NormalDistribution(GaussianSampler(mean, standardDeviation, normalized))

View File

@ -67,3 +67,12 @@ public fun Sampler<Double>.sampleBuffer(generator: RandomGenerator, size: Int):
@JvmName("sampleIntBuffer") @JvmName("sampleIntBuffer")
public fun Sampler<Int>.sampleBuffer(generator: RandomGenerator, size: Int): Chain<Buffer<Int>> = public fun Sampler<Int>.sampleBuffer(generator: RandomGenerator, size: Int): Chain<Buffer<Int>> =
sampleBuffer(generator, size, ::IntBuffer) sampleBuffer(generator, size, ::IntBuffer)
/**
* Samples a [Buffer] of values from this [Sampler].
*/
public suspend fun Sampler<Double>.nextBuffer(generator: RandomGenerator, size: Int): Buffer<Double> =
sampleBuffer(generator, size).first()
//TODO add `context(RandomGenerator) Sampler.nextBuffer

View File

@ -8,9 +8,9 @@ package space.kscience.kmath.stat
import space.kscience.kmath.chains.Chain import space.kscience.kmath.chains.Chain
import space.kscience.kmath.chains.SimpleChain import space.kscience.kmath.chains.SimpleChain
import space.kscience.kmath.distributions.Distribution import space.kscience.kmath.distributions.Distribution
import space.kscience.kmath.distributions.UnivariateDistribution import space.kscience.kmath.distributions.Distribution1D
public class UniformDistribution(public val range: ClosedFloatingPointRange<Double>) : UnivariateDistribution<Double> { public class UniformDistribution(public val range: ClosedFloatingPointRange<Double>) : Distribution1D<Double> {
private val length: Double = range.endInclusive - range.start private val length: Double = range.endInclusive - range.start
override fun probability(arg: Double): Double = if (arg in range) 1.0 / length else 0.0 override fun probability(arg: Double): Double = if (arg in range) 1.0 / length else 0.0