[WIP] Another histogram refactor
This commit is contained in:
parent
ce82d2d076
commit
73f72f12bc
@ -48,6 +48,7 @@
|
||||
- Operations -> Ops
|
||||
- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
|
||||
- Tensor algebra takes read-only structures as input and inherits AlgebraND
|
||||
- `UnivariateDistribution` renamed to `Distribution1D`
|
||||
|
||||
### Deprecated
|
||||
- Specialized `DoubleBufferAlgebra`
|
||||
|
@ -35,6 +35,23 @@ public class DoubleDomain1D(
|
||||
}
|
||||
|
||||
override fun volume(): Double = range.endInclusive - range.start
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other == null || this::class != other::class) return false
|
||||
|
||||
other as DoubleDomain1D
|
||||
|
||||
if (doubleRange != other.doubleRange) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int = doubleRange.hashCode()
|
||||
|
||||
override fun toString(): String = doubleRange.toString()
|
||||
|
||||
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
|
@ -3,43 +3,71 @@
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
|
||||
package space.kscience.kmath.misc
|
||||
|
||||
import kotlin.comparisons.*
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.VirtualBuffer
|
||||
|
||||
/**
|
||||
* Return a new list filled with buffer indices. Indice order is defined by sorting associated buffer value.
|
||||
* This feature allows to sort buffer values without reordering its content.
|
||||
* Return a new array filled with buffer indices. Indices order is defined by sorting associated buffer value.
|
||||
* This feature allows sorting buffer values without reordering its content.
|
||||
*
|
||||
* @return List of buffer indices, sorted by associated value.
|
||||
* @return Buffer indices, sorted by associated value.
|
||||
*/
|
||||
@PerformancePitfall
|
||||
@UnstableKMathAPI
|
||||
public fun <V: Comparable<V>> Buffer<V>.permSort() : IntArray = _permSortWith(compareBy<Int> { get(it) })
|
||||
public fun <V : Comparable<V>> Buffer<V>.indicesSorted(): IntArray = permSortIndicesWith(compareBy { get(it) })
|
||||
|
||||
/**
|
||||
* Create a zero-copy virtual buffer that contains the same elements but in ascending order
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun <V : Comparable<V>> Buffer<V>.sorted(): Buffer<V> {
|
||||
val permutations = indicesSorted()
|
||||
return VirtualBuffer(size) { this[permutations[it]] }
|
||||
}
|
||||
|
||||
@PerformancePitfall
|
||||
@UnstableKMathAPI
|
||||
public fun <V: Comparable<V>> Buffer<V>.permSortDescending() : IntArray = _permSortWith(compareByDescending<Int> { get(it) })
|
||||
public fun <V : Comparable<V>> Buffer<V>.indicesSortedDescending(): IntArray =
|
||||
permSortIndicesWith(compareByDescending { get(it) })
|
||||
|
||||
/**
|
||||
* Create a zero-copy virtual buffer that contains the same elements but in descending order
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun <V : Comparable<V>> Buffer<V>.sortedDescending(): Buffer<V> {
|
||||
val permutations = indicesSortedDescending()
|
||||
return VirtualBuffer(size) { this[permutations[it]] }
|
||||
}
|
||||
|
||||
@PerformancePitfall
|
||||
@UnstableKMathAPI
|
||||
public fun <V, C: Comparable<C>> Buffer<V>.permSortBy(selector: (V) -> C) : IntArray = _permSortWith(compareBy<Int> { selector(get(it)) })
|
||||
public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedBy(selector: (V) -> C): IntArray =
|
||||
permSortIndicesWith(compareBy { selector(get(it)) })
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun <V, C : Comparable<C>> Buffer<V>.sortedBy(selector: (V) -> C): Buffer<V> {
|
||||
val permutations = indicesSortedBy(selector)
|
||||
return VirtualBuffer(size) { this[permutations[it]] }
|
||||
}
|
||||
|
||||
@PerformancePitfall
|
||||
@UnstableKMathAPI
|
||||
public fun <V, C: Comparable<C>> Buffer<V>.permSortByDescending(selector: (V) -> C) : IntArray = _permSortWith(compareByDescending<Int> { selector(get(it)) })
|
||||
public fun <V, C : Comparable<C>> Buffer<V>.indicesSortedByDescending(selector: (V) -> C): IntArray =
|
||||
permSortIndicesWith(compareByDescending { selector(get(it)) })
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun <V, C : Comparable<C>> Buffer<V>.sortedByDescending(selector: (V) -> C): Buffer<V> {
|
||||
val permutations = indicesSortedByDescending(selector)
|
||||
return VirtualBuffer(size) { this[permutations[it]] }
|
||||
}
|
||||
|
||||
@PerformancePitfall
|
||||
@UnstableKMathAPI
|
||||
public fun <V> Buffer<V>.permSortWith(comparator : Comparator<V>) : IntArray = _permSortWith { i1, i2 -> comparator.compare(get(i1), get(i2)) }
|
||||
public fun <V> Buffer<V>.indicesSortedWith(comparator: Comparator<V>): IntArray =
|
||||
permSortIndicesWith { i1, i2 -> comparator.compare(get(i1), get(i2)) }
|
||||
|
||||
@PerformancePitfall
|
||||
@UnstableKMathAPI
|
||||
private fun <V> Buffer<V>._permSortWith(comparator : Comparator<Int>) : IntArray {
|
||||
if (size < 2) return IntArray(size)
|
||||
private fun <V> Buffer<V>.permSortIndicesWith(comparator: Comparator<Int>): IntArray {
|
||||
if (size < 2) return IntArray(size) { 0 }
|
||||
|
||||
/* TODO: optimisation : keep a constant big array of indices (Ex: from 0 to 4096), then create indice
|
||||
/* TODO: optimisation : keep a constant big array of indices (Ex: from 0 to 4096), then create indices
|
||||
* arrays more efficiently by copying subpart of cached one. For bigger needs, we could copy entire
|
||||
* cached array, then fill remaining indices manually. Not done for now, because:
|
||||
* 1. doing it right would require some statistics about common used buffer sizes.
|
||||
@ -53,3 +81,12 @@ private fun <V> Buffer<V>._permSortWith(comparator : Comparator<Int>) : IntArray
|
||||
*/
|
||||
return packedIndices.sortedWith(comparator).toIntArray()
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks that the [Buffer] is sorted (ascending) and throws [IllegalArgumentException] if it is not.
|
||||
*/
|
||||
public fun <T : Comparable<T>> Buffer<T>.requireSorted() {
|
||||
for (i in 0..(size - 2)) {
|
||||
require(get(i + 1) >= get(i)) { "The buffer is not sorted at index $i" }
|
||||
}
|
||||
}
|
||||
|
@ -6,14 +6,13 @@
|
||||
package space.kscience.kmath.misc
|
||||
|
||||
import space.kscience.kmath.misc.PermSortTest.Platform.*
|
||||
import kotlin.random.Random
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
import space.kscience.kmath.structures.IntBuffer
|
||||
import space.kscience.kmath.structures.asBuffer
|
||||
import kotlin.random.Random
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertContentEquals
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class PermSortTest {
|
||||
|
||||
@ -29,9 +28,9 @@ class PermSortTest {
|
||||
@Test
|
||||
fun testOnEmptyBuffer() {
|
||||
val emptyBuffer = IntBuffer(0) {it}
|
||||
var permutations = emptyBuffer.permSort()
|
||||
var permutations = emptyBuffer.indicesSorted()
|
||||
assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result")
|
||||
permutations = emptyBuffer.permSortDescending()
|
||||
permutations = emptyBuffer.indicesSortedDescending()
|
||||
assertTrue(permutations.isEmpty(), "permutation on an empty buffer should return an empty result")
|
||||
}
|
||||
|
||||
@ -47,25 +46,25 @@ class PermSortTest {
|
||||
|
||||
@Test
|
||||
fun testPermSortBy() {
|
||||
val permutations = platforms.permSortBy { it.name }
|
||||
val permutations = platforms.indicesSortedBy { it.name }
|
||||
val expected = listOf(ANDROID, JS, JVM, NATIVE, WASM)
|
||||
assertContentEquals(expected, permutations.map { platforms[it] }, "Ascending PermSort by name")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPermSortByDescending() {
|
||||
val permutations = platforms.permSortByDescending { it.name }
|
||||
val permutations = platforms.indicesSortedByDescending { it.name }
|
||||
val expected = listOf(WASM, NATIVE, JVM, JS, ANDROID)
|
||||
assertContentEquals(expected, permutations.map { platforms[it] }, "Descending PermSort by name")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testPermSortWith() {
|
||||
var permutations = platforms.permSortWith { p1, p2 -> p1.name.length.compareTo(p2.name.length) }
|
||||
var permutations = platforms.indicesSortedWith { p1, p2 -> p1.name.length.compareTo(p2.name.length) }
|
||||
val expected = listOf(JS, JVM, WASM, NATIVE, ANDROID)
|
||||
assertContentEquals(expected, permutations.map { platforms[it] }, "PermSort using custom ascending comparator")
|
||||
|
||||
permutations = platforms.permSortWith(compareByDescending { it.name.length })
|
||||
permutations = platforms.indicesSortedWith(compareByDescending { it.name.length })
|
||||
assertContentEquals(expected.reversed(), permutations.map { platforms[it] }, "PermSort using custom descending comparator")
|
||||
}
|
||||
|
||||
@ -75,7 +74,7 @@ class PermSortTest {
|
||||
println("Test randomization seed: $seed")
|
||||
|
||||
val buffer = Random(seed).buffer(bufferSize)
|
||||
val indices = buffer.permSort()
|
||||
val indices = buffer.indicesSorted()
|
||||
|
||||
assertEquals(bufferSize, indices.size)
|
||||
// Ensure no doublon is present in indices
|
||||
@ -87,7 +86,7 @@ class PermSortTest {
|
||||
assertTrue(current <= next, "Permutation indices not properly sorted")
|
||||
}
|
||||
|
||||
val descIndices = buffer.permSortDescending()
|
||||
val descIndices = buffer.indicesSortedDescending()
|
||||
assertEquals(bufferSize, descIndices.size)
|
||||
// Ensure no doublon is present in indices
|
||||
assertEquals(descIndices.toSet().size, descIndices.size)
|
||||
|
@ -21,19 +21,22 @@ readme {
|
||||
|
||||
feature(
|
||||
id = "DoubleVector",
|
||||
description = "Numpy-like operations for Buffers/Points",
|
||||
ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleVector.kt"
|
||||
)
|
||||
){
|
||||
"Numpy-like operations for Buffers/Points"
|
||||
}
|
||||
|
||||
feature(
|
||||
id = "DoubleMatrix",
|
||||
description = "Numpy-like operations for 2d real structures",
|
||||
ref = "src/commonMain/kotlin/space/kscience/kmath/real/DoubleMatrix.kt"
|
||||
)
|
||||
){
|
||||
"Numpy-like operations for 2d real structures"
|
||||
}
|
||||
|
||||
feature(
|
||||
id = "grids",
|
||||
description = "Uniform grid generators",
|
||||
ref = "src/commonMain/kotlin/space/kscience/kmath/structures/grids.kt"
|
||||
)
|
||||
){
|
||||
"Uniform grid generators"
|
||||
}
|
||||
}
|
||||
|
@ -17,6 +17,8 @@ kotlin.sourceSets {
|
||||
commonTest {
|
||||
dependencies {
|
||||
implementation(project(":kmath-for-real"))
|
||||
implementation(projects.kmath.kmathStat)
|
||||
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.6.0")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -16,11 +16,11 @@ import kotlin.math.floor
|
||||
/**
|
||||
* Multivariate histogram space for hyper-square real-field bins.
|
||||
*/
|
||||
public class DoubleHistogramSpace(
|
||||
public class DoubleHistogramGroup(
|
||||
private val lower: Buffer<Double>,
|
||||
private val upper: Buffer<Double>,
|
||||
private val binNums: IntArray = IntArray(lower.size) { 20 },
|
||||
) : IndexedHistogramSpace<Double, Double> {
|
||||
) : IndexedHistogramGroup<Double, Double> {
|
||||
|
||||
init {
|
||||
// argument checks
|
||||
@ -105,7 +105,7 @@ public class DoubleHistogramSpace(
|
||||
*/
|
||||
public fun fromRanges(
|
||||
vararg ranges: ClosedFloatingPointRange<Double>,
|
||||
): DoubleHistogramSpace = DoubleHistogramSpace(
|
||||
): DoubleHistogramGroup = DoubleHistogramGroup(
|
||||
ranges.map(ClosedFloatingPointRange<Double>::start).asBuffer(),
|
||||
ranges.map(ClosedFloatingPointRange<Double>::endInclusive).asBuffer()
|
||||
)
|
||||
@ -121,7 +121,7 @@ public class DoubleHistogramSpace(
|
||||
*/
|
||||
public fun fromRanges(
|
||||
vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>,
|
||||
): DoubleHistogramSpace = DoubleHistogramSpace(
|
||||
): DoubleHistogramGroup = DoubleHistogramGroup(
|
||||
ListBuffer(
|
||||
ranges
|
||||
.map(Pair<ClosedFloatingPointRange<Double>, Int>::first)
|
@ -6,6 +6,7 @@
|
||||
package space.kscience.kmath.histogram
|
||||
|
||||
import space.kscience.kmath.domains.Domain1D
|
||||
import space.kscience.kmath.linear.Point
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.asSequence
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
@ -18,7 +19,7 @@ import space.kscience.kmath.structures.Buffer
|
||||
* @property standardDeviation Standard deviation of the bin value. Zero or negative if not applicable
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public class Bin1D<T : Comparable<T>, out V>(
|
||||
public data class Bin1D<T : Comparable<T>, out V>(
|
||||
public val domain: Domain1D<T>,
|
||||
override val binValue: V,
|
||||
) : Bin<T, V>, ClosedRange<T> by domain.range {
|
||||
@ -35,12 +36,15 @@ public interface Histogram1D<T : Comparable<T>, V> : Histogram<T, V, Bin1D<T, V>
|
||||
override operator fun get(point: Buffer<T>): Bin1D<T, V>? = get(point[0])
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
public interface Histogram1DBuilder<in T : Any, V : Any> : HistogramBuilder<T, V> {
|
||||
/**
|
||||
* Thread safe put operation
|
||||
*/
|
||||
public fun putValue(at: T, value: V = defaultValue)
|
||||
|
||||
override fun putValue(point: Point<out T>, value: V) {
|
||||
putValue(point[0], value)
|
||||
}
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
@ -52,5 +56,5 @@ public fun Histogram1DBuilder<Double, *>.fill(array: DoubleArray): Unit =
|
||||
array.forEach(this::putValue)
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun Histogram1DBuilder<Double, *>.fill(buffer: Buffer<Double>): Unit =
|
||||
public fun <T: Any> Histogram1DBuilder<T, *>.fill(buffer: Buffer<T>): Unit =
|
||||
buffer.asSequence().forEach(this::putValue)
|
@ -28,7 +28,7 @@ public data class DomainBin<in T : Comparable<T>, out V>(
|
||||
* @param V the type of bin value
|
||||
*/
|
||||
public class IndexedHistogram<T : Comparable<T>, V : Any>(
|
||||
public val histogramSpace: IndexedHistogramSpace<T, V>,
|
||||
public val histogramSpace: IndexedHistogramGroup<T, V>,
|
||||
public val values: StructureND<V>,
|
||||
) : Histogram<T, V, DomainBin<T, V>> {
|
||||
|
||||
@ -48,8 +48,8 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
|
||||
/**
|
||||
* A space for producing histograms with values in a NDStructure
|
||||
*/
|
||||
public interface IndexedHistogramSpace<T : Comparable<T>, V : Any>
|
||||
: Group<IndexedHistogram<T, V>>, ScaleOperations<IndexedHistogram<T, V>> {
|
||||
public interface IndexedHistogramGroup<T : Comparable<T>, V : Any> : Group<IndexedHistogram<T, V>>,
|
||||
ScaleOperations<IndexedHistogram<T, V>> {
|
||||
public val shape: Shape
|
||||
public val histogramValueAlgebra: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape),
|
||||
|
@ -5,8 +5,92 @@
|
||||
|
||||
package space.kscience.kmath.histogram
|
||||
|
||||
//class UniformHistogram1D(
|
||||
// public val borders: Buffer<Double>,
|
||||
// public val values: Buffer<Long>,
|
||||
//) : Histogram1D<Double, Long> {
|
||||
//}
|
||||
import space.kscience.kmath.domains.DoubleDomain1D
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.Group
|
||||
import space.kscience.kmath.operations.Ring
|
||||
import space.kscience.kmath.operations.ScaleOperations
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import kotlin.math.floor
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class UniformHistogram1D<V : Any>(
|
||||
public val group: UniformHistogram1DGroup<V, *>,
|
||||
public val values: Map<Int, V>,
|
||||
) : Histogram1D<Double, V> {
|
||||
|
||||
private val startPoint get() = group.startPoint
|
||||
private val binSize get() = group.binSize
|
||||
|
||||
private fun produceBin(index: Int, value: V): Bin1D<Double, V> {
|
||||
val domain = DoubleDomain1D((startPoint + index * binSize)..(startPoint + (index + 1) * binSize))
|
||||
return Bin1D(domain, value)
|
||||
}
|
||||
|
||||
override val bins: Iterable<Bin1D<Double, V>> get() = values.map { produceBin(it.key, it.value) }
|
||||
|
||||
override fun get(value: Double): Bin1D<Double, V>? {
|
||||
val index: Int = group.getIndex(value)
|
||||
val v = values[index]
|
||||
return v?.let { produceBin(index, it) }
|
||||
}
|
||||
}
|
||||
|
||||
public class UniformHistogram1DGroup<V : Any, A>(
|
||||
public val valueAlgebra: A,
|
||||
public val binSize: Double,
|
||||
public val startPoint: Double = 0.0,
|
||||
) : Group<UniformHistogram1D<V>>, ScaleOperations<UniformHistogram1D<V>> where A : Ring<V>, A : ScaleOperations<V> {
|
||||
override val zero: UniformHistogram1D<V> by lazy { UniformHistogram1D(this, emptyMap()) }
|
||||
|
||||
public fun getIndex(at: Double): Int = floor((at - startPoint) / binSize).toInt()
|
||||
|
||||
override fun add(left: UniformHistogram1D<V>, right: UniformHistogram1D<V>): UniformHistogram1D<V> = valueAlgebra {
|
||||
require(left.group == this@UniformHistogram1DGroup)
|
||||
require(right.group == this@UniformHistogram1DGroup)
|
||||
val keys = left.values.keys + right.values.keys
|
||||
UniformHistogram1D(
|
||||
this@UniformHistogram1DGroup,
|
||||
keys.associateWith { (left.values[it] ?: valueAlgebra.zero) + (right.values[it] ?: valueAlgebra.zero) }
|
||||
)
|
||||
}
|
||||
|
||||
override fun UniformHistogram1D<V>.unaryMinus(): UniformHistogram1D<V> = valueAlgebra {
|
||||
UniformHistogram1D(this@UniformHistogram1DGroup, values.mapValues { -it.value })
|
||||
}
|
||||
|
||||
override fun scale(
|
||||
a: UniformHistogram1D<V>,
|
||||
value: Double,
|
||||
): UniformHistogram1D<V> = UniformHistogram1D(
|
||||
this@UniformHistogram1DGroup,
|
||||
a.values.mapValues { valueAlgebra.scale(it.value, value) }
|
||||
)
|
||||
|
||||
public inline fun produce(block: Histogram1DBuilder<Double, V>.() -> Unit): UniformHistogram1D<V> {
|
||||
val map = HashMap<Int, V>()
|
||||
val builder = object : Histogram1DBuilder<Double, V> {
|
||||
override val defaultValue: V get() = valueAlgebra.zero
|
||||
|
||||
override fun putValue(at: Double, value: V) {
|
||||
val index = getIndex(at)
|
||||
map[index] = with(valueAlgebra) { (map[index] ?: zero) + one }
|
||||
}
|
||||
}
|
||||
builder.block()
|
||||
return UniformHistogram1D(this, map)
|
||||
}
|
||||
}
|
||||
|
||||
public fun <V : Any, A> Histogram.Companion.uniform1D(
|
||||
algebra: A,
|
||||
binSize: Double,
|
||||
startPoint: Double = 0.0,
|
||||
): UniformHistogram1DGroup<V, A> where A : Ring<V>, A : ScaleOperations<V> =
|
||||
UniformHistogram1DGroup(algebra, binSize, startPoint)
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <V : Any> UniformHistogram1DGroup<V, *>.produce(
|
||||
buffer: Buffer<Double>,
|
||||
): UniformHistogram1D<V> = produce { fill(buffer) }
|
@ -14,7 +14,7 @@ import kotlin.test.*
|
||||
internal class MultivariateHistogramTest {
|
||||
@Test
|
||||
fun testSinglePutHistogram() {
|
||||
val hSpace = DoubleHistogramSpace.fromRanges(
|
||||
val hSpace = DoubleHistogramGroup.fromRanges(
|
||||
(-1.0..1.0),
|
||||
(-1.0..1.0)
|
||||
)
|
||||
@ -29,7 +29,7 @@ internal class MultivariateHistogramTest {
|
||||
|
||||
@Test
|
||||
fun testSequentialPut() {
|
||||
val hSpace = DoubleHistogramSpace.fromRanges(
|
||||
val hSpace = DoubleHistogramGroup.fromRanges(
|
||||
(-1.0..1.0),
|
||||
(-1.0..1.0),
|
||||
(-1.0..1.0)
|
||||
@ -49,7 +49,7 @@ internal class MultivariateHistogramTest {
|
||||
|
||||
@Test
|
||||
fun testHistogramAlgebra() {
|
||||
DoubleHistogramSpace.fromRanges(
|
||||
DoubleHistogramGroup.fromRanges(
|
||||
(-1.0..1.0),
|
||||
(-1.0..1.0),
|
||||
(-1.0..1.0)
|
||||
|
@ -0,0 +1,34 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.histogram
|
||||
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.test.runTest
|
||||
import space.kscience.kmath.distributions.NormalDistribution
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.stat.RandomGenerator
|
||||
import space.kscience.kmath.stat.nextBuffer
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
@OptIn(ExperimentalCoroutinesApi::class, UnstableKMathAPI::class)
|
||||
internal class UniformHistogram1DTest {
|
||||
@Test
|
||||
fun normal() = runTest {
|
||||
val generator = RandomGenerator.default(123)
|
||||
val distribution = NormalDistribution(0.0, 1.0)
|
||||
with(Histogram.uniform1D(DoubleField, 0.1)) {
|
||||
val h1 = produce(distribution.nextBuffer(generator, 10000))
|
||||
|
||||
val h2 = produce(distribution.nextBuffer(generator, 50000))
|
||||
|
||||
val h3 = h1 + h2
|
||||
|
||||
assertEquals(60000, h3.bins.sumOf { it.binValue }.toInt())
|
||||
}
|
||||
}
|
||||
}
|
@ -27,7 +27,7 @@ public interface Distribution<T : Any> : Sampler<T> {
|
||||
public companion object
|
||||
}
|
||||
|
||||
public interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
|
||||
public interface Distribution1D<T : Comparable<T>> : Distribution<T> {
|
||||
/**
|
||||
* Cumulative distribution for ordered parameter (CDF)
|
||||
*/
|
||||
@ -37,7 +37,7 @@ public interface UnivariateDistribution<T : Comparable<T>> : Distribution<T> {
|
||||
/**
|
||||
* Compute probability integral in an interval
|
||||
*/
|
||||
public fun <T : Comparable<T>> UnivariateDistribution<T>.integral(from: T, to: T): Double {
|
||||
public fun <T : Comparable<T>> Distribution1D<T>.integral(from: T, to: T): Double {
|
||||
require(to > from)
|
||||
return cumulative(to) - cumulative(from)
|
||||
}
|
||||
|
@ -14,14 +14,9 @@ import space.kscience.kmath.stat.RandomGenerator
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
* Implements [UnivariateDistribution] for the normal (gaussian) distribution.
|
||||
* Implements [Distribution1D] for the normal (gaussian) distribution.
|
||||
*/
|
||||
public class NormalDistribution(public val sampler: GaussianSampler) : UnivariateDistribution<Double> {
|
||||
public constructor(
|
||||
mean: Double,
|
||||
standardDeviation: Double,
|
||||
normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler,
|
||||
) : this(GaussianSampler(mean, standardDeviation, normalized))
|
||||
public class NormalDistribution(public val sampler: GaussianSampler) : Distribution1D<Double> {
|
||||
|
||||
override fun probability(arg: Double): Double {
|
||||
val x1 = (arg - sampler.mean) / sampler.standardDeviation
|
||||
@ -43,3 +38,9 @@ public class NormalDistribution(public val sampler: GaussianSampler) : Univariat
|
||||
private val SQRT2 = sqrt(2.0)
|
||||
}
|
||||
}
|
||||
|
||||
public fun NormalDistribution(
|
||||
mean: Double,
|
||||
standardDeviation: Double,
|
||||
normalized: NormalizedGaussianSampler = ZigguratNormalizedGaussianSampler,
|
||||
): NormalDistribution = NormalDistribution(GaussianSampler(mean, standardDeviation, normalized))
|
||||
|
@ -67,3 +67,12 @@ public fun Sampler<Double>.sampleBuffer(generator: RandomGenerator, size: Int):
|
||||
@JvmName("sampleIntBuffer")
|
||||
public fun Sampler<Int>.sampleBuffer(generator: RandomGenerator, size: Int): Chain<Buffer<Int>> =
|
||||
sampleBuffer(generator, size, ::IntBuffer)
|
||||
|
||||
|
||||
/**
|
||||
* Samples a [Buffer] of values from this [Sampler].
|
||||
*/
|
||||
public suspend fun Sampler<Double>.nextBuffer(generator: RandomGenerator, size: Int): Buffer<Double> =
|
||||
sampleBuffer(generator, size).first()
|
||||
|
||||
//TODO add `context(RandomGenerator) Sampler.nextBuffer
|
@ -8,9 +8,9 @@ package space.kscience.kmath.stat
|
||||
import space.kscience.kmath.chains.Chain
|
||||
import space.kscience.kmath.chains.SimpleChain
|
||||
import space.kscience.kmath.distributions.Distribution
|
||||
import space.kscience.kmath.distributions.UnivariateDistribution
|
||||
import space.kscience.kmath.distributions.Distribution1D
|
||||
|
||||
public class UniformDistribution(public val range: ClosedFloatingPointRange<Double>) : UnivariateDistribution<Double> {
|
||||
public class UniformDistribution(public val range: ClosedFloatingPointRange<Double>) : Distribution1D<Double> {
|
||||
private val length: Double = range.endInclusive - range.start
|
||||
|
||||
override fun probability(arg: Double): Double = if (arg in range) 1.0 / length else 0.0
|
||||
|
Loading…
Reference in New Issue
Block a user