Some unifications for buffers and NDStructures
This commit is contained in:
parent
566883c521
commit
cdcba85ada
@ -8,7 +8,7 @@ fun main(args: Array<String>) {
|
||||
|
||||
val n = 6000
|
||||
|
||||
val structure = ndStructure(intArrayOf(n, n), DoubleBufferFactory) { 1.0 }
|
||||
val structure = NDStructure.build(intArrayOf(n, n), DoubleBufferFactory) { 1.0 }
|
||||
|
||||
structure.mapToBuffer { it + 1 } // warm-up
|
||||
|
||||
|
@ -2,6 +2,7 @@ package scientifik.kmath.expressions
|
||||
|
||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
import scientifik.kmath.operations.ExtendedField
|
||||
import scientifik.kmath.operations.ExtendedFieldOperations
|
||||
import scientifik.kmath.operations.Field
|
||||
import kotlin.properties.ReadOnlyProperty
|
||||
import kotlin.reflect.KProperty
|
||||
|
@ -5,7 +5,7 @@ import kotlin.math.*
|
||||
/**
|
||||
* A field for complex numbers
|
||||
*/
|
||||
object ComplexField : ExtendedField<Complex> {
|
||||
object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
|
||||
override val zero: Complex = Complex(0.0, 0.0)
|
||||
|
||||
override val one: Complex = Complex(1.0, 0.0)
|
||||
|
@ -5,12 +5,13 @@ import kotlin.math.pow
|
||||
/**
|
||||
* Advanced Number-like field that implements basic operations
|
||||
*/
|
||||
interface ExtendedField<T : Any> :
|
||||
Field<T>,
|
||||
interface ExtendedFieldOperations<T> :
|
||||
FieldOperations<T>,
|
||||
TrigonometricOperations<T>,
|
||||
PowerOperations<T>,
|
||||
ExponentialOperations<T>
|
||||
|
||||
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>
|
||||
|
||||
/**
|
||||
* Real field element wrapping double.
|
||||
@ -31,7 +32,7 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
||||
* A field for double without boxing. Does not produce appropriate field element
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
||||
object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||
object RealField : Field<Double>, ExtendedFieldOperations<Double>, Norm<Double, Double> {
|
||||
override val zero: Double = 0.0
|
||||
override fun add(a: Double, b: Double): Double = a + b
|
||||
override fun multiply(a: Double, b: Double): Double = a * b
|
||||
|
@ -10,7 +10,7 @@ package scientifik.kmath.operations
|
||||
* It also allows to override behavior for optional operations
|
||||
*
|
||||
*/
|
||||
interface TrigonometricOperations<T> : Field<T> {
|
||||
interface TrigonometricOperations<T> : FieldOperations<T> {
|
||||
fun sin(arg: T): T
|
||||
fun cos(arg: T): T
|
||||
|
||||
|
@ -109,6 +109,9 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
||||
|
||||
fun <T> List<T>.asBuffer() = ListBuffer(this)
|
||||
|
||||
@Suppress("FunctionName")
|
||||
inline fun <T> ListBuffer(size: Int, init: (Int) -> T) = List(size, init).asBuffer()
|
||||
|
||||
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
|
||||
|
||||
override val size: Int
|
||||
@ -154,9 +157,12 @@ inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
||||
override fun iterator(): Iterator<Double> = array.iterator()
|
||||
|
||||
override fun copy(): MutableBuffer<Double> = DoubleBuffer(array.copyOf())
|
||||
|
||||
}
|
||||
|
||||
@Suppress("FunctionName")
|
||||
inline fun DoubleBuffer(size: Int, init: (Int) -> Double) = DoubleBuffer(DoubleArray(size) { init(it) })
|
||||
|
||||
|
||||
/**
|
||||
* Transform buffer of doubles into array for high performance operations
|
||||
*/
|
||||
|
@ -1,16 +1,14 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.ExponentialOperations
|
||||
import scientifik.kmath.operations.ExtendedField
|
||||
import scientifik.kmath.operations.PowerOperations
|
||||
import scientifik.kmath.operations.TrigonometricOperations
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
|
||||
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> :
|
||||
interface ExtendedNDField<T : Any, F, N : NDStructure<T>> :
|
||||
NDField<T, F, N>,
|
||||
TrigonometricOperations<N>,
|
||||
PowerOperations<N>,
|
||||
ExponentialOperations<N>
|
||||
where F : ExtendedFieldOperations<T>, F : Field<T>
|
||||
|
||||
|
||||
///**
|
||||
|
@ -22,6 +22,33 @@ interface NDStructure<T> {
|
||||
else -> st1.elements().all { (index, value) -> value == st2[index] }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a NDStructure with explicit buffer factory
|
||||
*
|
||||
* Strides should be reused if possible
|
||||
*/
|
||||
fun <T> build(
|
||||
strides: Strides,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
initializer: (IntArray) -> T
|
||||
) =
|
||||
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
/**
|
||||
* Inline create NDStructure with non-boxing buffer implementation if it is possible
|
||||
*/
|
||||
inline fun <reified T : Any> auto(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
fun <T> build(
|
||||
shape: IntArray,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
initializer: (IntArray) -> T
|
||||
) = build(DefaultStrides(shape), bufferFactory, initializer)
|
||||
|
||||
inline fun <reified T : Any> auto(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
||||
auto(DefaultStrides(shape), initializer)
|
||||
}
|
||||
}
|
||||
|
||||
@ -200,34 +227,6 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a NDStructure with explicit buffer factory
|
||||
*
|
||||
* Strides should be reused if possible
|
||||
*/
|
||||
fun <T> ndStructure(
|
||||
strides: Strides,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
initializer: (IntArray) -> T
|
||||
) =
|
||||
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
/**
|
||||
* Inline create NDStructure with non-boxing buffer implementation if it is possible
|
||||
*/
|
||||
inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
fun <T> ndStructure(
|
||||
shape: IntArray,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
initializer: (IntArray) -> T
|
||||
) =
|
||||
ndStructure(DefaultStrides(shape), bufferFactory, initializer)
|
||||
|
||||
inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
||||
inlineNDStructure(DefaultStrides(shape), initializer)
|
||||
|
||||
/**
|
||||
* Mutable ND buffer based on linear [autoBuffer]
|
||||
*/
|
||||
@ -245,33 +244,10 @@ class MutableBufferNDStructure<T>(
|
||||
override fun set(index: IntArray, value: T) = buffer.set(strides.offset(index), value)
|
||||
}
|
||||
|
||||
/**
|
||||
* The same as [inlineNDStructure], but mutable
|
||||
*/
|
||||
fun <T : Any> mutableNdStructure(
|
||||
strides: Strides,
|
||||
bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::boxing,
|
||||
initializer: (IntArray) -> T
|
||||
) =
|
||||
MutableBufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
inline fun <reified T : Any> inlineMutableNdStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||
MutableBufferNDStructure(strides, MutableBuffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
fun <T : Any> mutableNdStructure(
|
||||
shape: IntArray,
|
||||
bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::boxing,
|
||||
initializer: (IntArray) -> T
|
||||
) =
|
||||
mutableNdStructure(DefaultStrides(shape), bufferFactory, initializer)
|
||||
|
||||
inline fun <reified T : Any> inlineMutableNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
||||
inlineMutableNdStructure(DefaultStrides(shape), initializer)
|
||||
|
||||
inline fun <reified T : Any> NDStructure<T>.combine(
|
||||
struct: NDStructure<T>,
|
||||
crossinline block: (T, T) -> T
|
||||
): NDStructure<T> {
|
||||
if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination")
|
||||
return inlineNdStructure(shape) { block(this[it], struct[it]) }
|
||||
return NDStructure.auto(shape) { block(this[it], struct[it]) }
|
||||
}
|
@ -1,110 +1,153 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.ExtendedField
|
||||
import scientifik.kmath.operations.ExtendedFieldOperations
|
||||
import scientifik.kmath.operations.Field
|
||||
import kotlin.math.*
|
||||
|
||||
|
||||
/**
|
||||
* A simple field over linear buffers of [Double]
|
||||
*/
|
||||
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
|
||||
override val zero: Buffer<Double> = Buffer.DoubleBufferFactory(size) { 0.0 }
|
||||
|
||||
override val one: Buffer<Double> = Buffer.DoubleBufferFactory(size) { 1.0 }
|
||||
|
||||
object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " }
|
||||
require(b.size == size) { "The size of buffer is ${b.size} but context requires $size " }
|
||||
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
||||
val aArray = a.array
|
||||
val bArray = b.array
|
||||
DoubleBuffer(DoubleArray(size) { aArray[it] + bArray[it] })
|
||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(size) { a[it] + b[it] })
|
||||
DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] })
|
||||
}
|
||||
}
|
||||
|
||||
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer {
|
||||
require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " }
|
||||
val kValue = k.toDouble()
|
||||
return if (a is DoubleBuffer) {
|
||||
val aArray = a.array
|
||||
DoubleBuffer(DoubleArray(size) { aArray[it] * kValue })
|
||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(size) { a[it] * kValue })
|
||||
DoubleBuffer(DoubleArray(a.size) { a[it] * kValue })
|
||||
}
|
||||
}
|
||||
|
||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " }
|
||||
require(b.size == size) { "The size of buffer is ${b.size} but context requires $size " }
|
||||
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
||||
val aArray = a.array
|
||||
val bArray = b.array
|
||||
DoubleBuffer(DoubleArray(size) { aArray[it] * bArray[it] })
|
||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(size) { a[it] * b[it] })
|
||||
DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] })
|
||||
}
|
||||
}
|
||||
|
||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " }
|
||||
require(b.size == size) { "The size of buffer is ${b.size} but context requires $size " }
|
||||
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
||||
val aArray = a.array
|
||||
val bArray = b.array
|
||||
DoubleBuffer(DoubleArray(size) { aArray[it] / bArray[it] })
|
||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(size) { a[it] / b[it] })
|
||||
DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] })
|
||||
}
|
||||
}
|
||||
|
||||
override fun sin(arg: Buffer<Double>): Buffer<Double> {
|
||||
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
||||
override fun sin(arg: Buffer<Double>): DoubleBuffer {
|
||||
return if (arg is DoubleBuffer) {
|
||||
val array = arg.array
|
||||
DoubleBuffer(DoubleArray(size) { sin(array[it]) })
|
||||
DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(size) { sin(arg[it]) })
|
||||
DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) })
|
||||
}
|
||||
}
|
||||
|
||||
override fun cos(arg: Buffer<Double>): Buffer<Double> {
|
||||
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
||||
override fun cos(arg: Buffer<Double>): DoubleBuffer {
|
||||
return if (arg is DoubleBuffer) {
|
||||
val array = arg.array
|
||||
DoubleBuffer(DoubleArray(size) { cos(array[it]) })
|
||||
DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(size) { cos(arg[it]) })
|
||||
DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) })
|
||||
}
|
||||
}
|
||||
|
||||
override fun power(arg: Buffer<Double>, pow: Number): Buffer<Double> {
|
||||
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
||||
override fun power(arg: Buffer<Double>, pow: Number): DoubleBuffer {
|
||||
return if (arg is DoubleBuffer) {
|
||||
val array = arg.array
|
||||
DoubleBuffer(DoubleArray(size) { array[it].pow(pow.toDouble()) })
|
||||
DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(size) { arg[it].pow(pow.toDouble()) })
|
||||
DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
|
||||
}
|
||||
}
|
||||
|
||||
override fun exp(arg: Buffer<Double>): Buffer<Double> {
|
||||
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
||||
override fun exp(arg: Buffer<Double>): DoubleBuffer {
|
||||
return if (arg is DoubleBuffer) {
|
||||
val array = arg.array
|
||||
DoubleBuffer(DoubleArray(size) { exp(array[it]) })
|
||||
DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(size) { exp(arg[it]) })
|
||||
DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) })
|
||||
}
|
||||
}
|
||||
|
||||
override fun ln(arg: Buffer<Double>): Buffer<Double> {
|
||||
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
||||
override fun ln(arg: Buffer<Double>): DoubleBuffer {
|
||||
return if (arg is DoubleBuffer) {
|
||||
val array = arg.array
|
||||
DoubleBuffer(DoubleArray(size) { ln(array[it]) })
|
||||
DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) })
|
||||
} else {
|
||||
DoubleBuffer(DoubleArray(size) { ln(arg[it]) })
|
||||
DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class RealBufferField(val size: Int) : Field<Buffer<Double>>, ExtendedFieldOperations<Buffer<Double>> {
|
||||
|
||||
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
|
||||
|
||||
override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } }
|
||||
|
||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.add(a, b)
|
||||
}
|
||||
|
||||
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.multiply(a, k)
|
||||
}
|
||||
|
||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.multiply(a, b)
|
||||
}
|
||||
|
||||
|
||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.divide(a, b)
|
||||
}
|
||||
|
||||
override fun sin(arg: Buffer<Double>): DoubleBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.sin(arg)
|
||||
}
|
||||
|
||||
override fun cos(arg: Buffer<Double>): DoubleBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.cos(arg)
|
||||
}
|
||||
|
||||
override fun power(arg: Buffer<Double>, pow: Number): DoubleBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.power(arg, pow)
|
||||
}
|
||||
|
||||
override fun exp(arg: Buffer<Double>): DoubleBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.exp(arg)
|
||||
}
|
||||
|
||||
override fun ln(arg: Buffer<Double>): DoubleBuffer {
|
||||
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.ln(arg)
|
||||
}
|
||||
|
||||
}
|
@ -2,9 +2,9 @@ package scientifik.kmath.histogram
|
||||
|
||||
/*
|
||||
* Common representation for atomic counters
|
||||
* TODO replace with atomics
|
||||
*/
|
||||
|
||||
|
||||
expect class LongCounter() {
|
||||
fun decrement()
|
||||
fun increment()
|
||||
|
@ -2,12 +2,8 @@ package scientifik.kmath.histogram
|
||||
|
||||
import scientifik.kmath.linear.Point
|
||||
import scientifik.kmath.structures.ArrayBuffer
|
||||
import scientifik.kmath.structures.Buffer
|
||||
import scientifik.kmath.structures.DoubleBuffer
|
||||
|
||||
|
||||
typealias RealPoint = Buffer<Double>
|
||||
|
||||
/**
|
||||
* A simple geometric domain
|
||||
* TODO move to geometry module
|
||||
@ -42,13 +38,14 @@ interface Histogram<T : Any, out B : Bin<T>> : Iterable<B> {
|
||||
|
||||
}
|
||||
|
||||
interface MutableHistogram<T : Any, out B : Bin<T>> :
|
||||
Histogram<T, B> {
|
||||
interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
|
||||
|
||||
/**
|
||||
* Increment appropriate bin
|
||||
*/
|
||||
fun put(point: Point<out T>, weight: Double = 1.0)
|
||||
fun putWithWeight(point: Point<out T>, weight: Double)
|
||||
|
||||
fun put(point: Point<out T>) = putWithWeight(point, 1.0)
|
||||
}
|
||||
|
||||
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))
|
||||
|
@ -1,64 +0,0 @@
|
||||
package scientifik.kmath.histogram
|
||||
|
||||
import scientifik.kmath.linear.Point
|
||||
import scientifik.kmath.linear.Vector
|
||||
import scientifik.kmath.operations.Space
|
||||
import scientifik.kmath.structures.NDStructure
|
||||
import scientifik.kmath.structures.asSequence
|
||||
|
||||
data class BinTemplate<T : Comparable<T>>(val center: Vector<T, *>, val sizes: Point<T>) {
|
||||
fun contains(vector: Point<out T>): Boolean {
|
||||
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
|
||||
val upper = center.context.run { center + sizes / 2.0 }
|
||||
val lower = center.context.run { center - sizes / 2.0 }
|
||||
return vector.asSequence().mapIndexed { i, value ->
|
||||
value in lower[i]..upper[i]
|
||||
}.all { it }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A space to perform arithmetic operations on histograms
|
||||
*/
|
||||
interface HistogramSpace<T : Any, B : Bin<T>, H : Histogram<T, B>> : Space<H> {
|
||||
/**
|
||||
* Rules for performing operations on bins
|
||||
*/
|
||||
val binSpace: Space<Bin<T>>
|
||||
}
|
||||
|
||||
class PhantomBin<T : Comparable<T>>(val template: BinTemplate<T>, override val value: Number) :
|
||||
Bin<T> {
|
||||
|
||||
override fun contains(vector: Point<out T>): Boolean = template.contains(vector)
|
||||
|
||||
override val dimension: Int
|
||||
get() = template.center.size
|
||||
|
||||
override val center: Point<T>
|
||||
get() = template.center
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Immutable histogram with explicit structure for content and additional external bin description.
|
||||
* Bin search is slow, but full histogram algebra is supported.
|
||||
* @param bins transform a template into structure index
|
||||
*/
|
||||
class PhantomHistogram<T : Comparable<T>>(
|
||||
val bins: Map<BinTemplate<T>, IntArray>,
|
||||
val data: NDStructure<Number>
|
||||
) : Histogram<T, PhantomBin<T>> {
|
||||
|
||||
override val dimension: Int
|
||||
get() = data.dimension
|
||||
|
||||
override fun iterator(): Iterator<PhantomBin<T>> =
|
||||
bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator()
|
||||
|
||||
override fun get(point: Point<out T>): PhantomBin<T>? {
|
||||
val template = bins.keys.find { it.contains(point) }
|
||||
return template?.let { PhantomBin(it, data[bins[it]!!]) }
|
||||
}
|
||||
|
||||
}
|
@ -1,45 +1,66 @@
|
||||
package scientifik.kmath.histogram
|
||||
|
||||
import scientifik.kmath.linear.Point
|
||||
import scientifik.kmath.linear.toVector
|
||||
import scientifik.kmath.operations.SpaceOperations
|
||||
import scientifik.kmath.structures.*
|
||||
import kotlin.math.floor
|
||||
|
||||
private operator fun RealPoint.minus(other: RealPoint) = ListBuffer((0 until size).map { get(it) - other[it] })
|
||||
|
||||
private inline fun <T> Buffer<out Double>.mapIndexed(crossinline mapper: (Int, Double) -> T): Sequence<T> =
|
||||
(0 until size).asSequence().map { mapper(it, get(it)) }
|
||||
data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val center: Point<T>, val sizes: Point<T>) {
|
||||
fun contains(vector: Point<out T>): Boolean {
|
||||
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
|
||||
val upper = space.run { center + sizes / 2.0 }
|
||||
val lower = space.run { center - sizes / 2.0 }
|
||||
return vector.asSequence().mapIndexed { i, value ->
|
||||
value in lower[i]..upper[i]
|
||||
}.all { it }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> {
|
||||
|
||||
override fun contains(vector: Point<out T>): Boolean = def.contains(vector)
|
||||
|
||||
override val dimension: Int
|
||||
get() = def.center.size
|
||||
|
||||
override val center: Point<T>
|
||||
get() = def.center
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions.
|
||||
*/
|
||||
class FastHistogram(
|
||||
private val lower: RealPoint,
|
||||
private val upper: RealPoint,
|
||||
class RealHistogram(
|
||||
private val lower: Buffer<Double>,
|
||||
private val upper: Buffer<Double>,
|
||||
private val binNums: IntArray = IntArray(lower.size) { 20 }
|
||||
) : MutableHistogram<Double, PhantomBin<Double>> {
|
||||
) : MutableHistogram<Double, MultivariateBin<Double>> {
|
||||
|
||||
|
||||
private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 })
|
||||
|
||||
private val values: NDStructure<LongCounter> = inlineNDStructure(strides) { LongCounter() }
|
||||
private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
|
||||
|
||||
//private val weight: NDStructure<DoubleCounter?> = ndStructure(strides){null}
|
||||
|
||||
//TODO optimize binSize performance if needed
|
||||
private val binSize: RealPoint =
|
||||
ListBuffer((upper - lower).mapIndexed { index, value -> value / binNums[index] }.toList())
|
||||
|
||||
override val dimension: Int get() = lower.size
|
||||
|
||||
|
||||
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
|
||||
|
||||
init {
|
||||
// argument checks
|
||||
if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.")
|
||||
if (lower.size != binNums.size) error("Dimension mismatch in bin count.")
|
||||
if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one of axis is not strictly positive")
|
||||
if ((0 until dimension).any { upper[it] - lower[it] < 0 }) error("Range for one of axis is not strictly positive")
|
||||
}
|
||||
|
||||
|
||||
override val dimension: Int get() = lower.size
|
||||
|
||||
|
||||
/**
|
||||
* Get internal [NDStructure] bin index for given axis
|
||||
*/
|
||||
@ -61,49 +82,41 @@ class FastHistogram(
|
||||
return getValue(getIndex(point))
|
||||
}
|
||||
|
||||
private fun getTemplate(index: IntArray): BinTemplate<Double> {
|
||||
private fun getDef(index: IntArray): BinDef<Double> {
|
||||
val center = index.mapIndexed { axis, i ->
|
||||
when (i) {
|
||||
0 -> Double.NEGATIVE_INFINITY
|
||||
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY
|
||||
else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis]
|
||||
}
|
||||
}.toVector()
|
||||
return BinTemplate(center, binSize)
|
||||
}.asBuffer()
|
||||
return BinDef(RealBufferFieldOperations, center, binSize)
|
||||
}
|
||||
|
||||
fun getTemplate(point: Buffer<out Double>): BinTemplate<Double> {
|
||||
return getTemplate(getIndex(point))
|
||||
fun getDef(point: Buffer<out Double>): BinDef<Double> {
|
||||
return getDef(getIndex(point))
|
||||
}
|
||||
|
||||
override fun get(point: Buffer<out Double>): PhantomBin<Double>? {
|
||||
override fun get(point: Buffer<out Double>): MultivariateBin<Double>? {
|
||||
val index = getIndex(point)
|
||||
return PhantomBin(getTemplate(index), getValue(index))
|
||||
return MultivariateBin(getDef(index), getValue(index))
|
||||
}
|
||||
|
||||
override fun put(point: Buffer<out Double>, weight: Double) {
|
||||
override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
|
||||
if (weight != 1.0) TODO("Implement weighting")
|
||||
val index = getIndex(point)
|
||||
values[index].increment()
|
||||
}
|
||||
|
||||
override fun iterator(): Iterator<PhantomBin<Double>> = values.elements().map { (index, value) ->
|
||||
PhantomBin(getTemplate(index), value.sum())
|
||||
override fun iterator(): Iterator<MultivariateBin<Double>> = values.elements().map { (index, value) ->
|
||||
MultivariateBin(getDef(index), value.sum())
|
||||
}.iterator()
|
||||
|
||||
/**
|
||||
* Convert this histogram into NDStructure containing bin values but not bin descriptions
|
||||
*/
|
||||
fun asNDStructure(): NDStructure<Number> {
|
||||
return inlineNdStructure(this.values.shape) { values[it].sum() }
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a phantom lightweight immutable copy of this histogram
|
||||
*/
|
||||
fun asPhantomHistogram(): PhantomHistogram<Double> {
|
||||
val binTemplates = values.elements().associate { (index, _) -> getTemplate(index) to index }
|
||||
return PhantomHistogram(binTemplates, asNDStructure())
|
||||
return NDStructure.auto(this.values.shape) { values[it].sum() }
|
||||
}
|
||||
|
||||
companion object {
|
||||
@ -117,8 +130,8 @@ class FastHistogram(
|
||||
*)
|
||||
*```
|
||||
*/
|
||||
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): FastHistogram {
|
||||
return FastHistogram(
|
||||
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): RealHistogram {
|
||||
return RealHistogram(
|
||||
ranges.map { it.start }.toVector(),
|
||||
ranges.map { it.endInclusive }.toVector()
|
||||
)
|
||||
@ -133,8 +146,8 @@ class FastHistogram(
|
||||
*)
|
||||
*```
|
||||
*/
|
||||
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram {
|
||||
return FastHistogram(
|
||||
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): RealHistogram {
|
||||
return RealHistogram(
|
||||
ListBuffer(ranges.map { it.first.start }),
|
||||
ListBuffer(ranges.map { it.first.endInclusive }),
|
||||
ranges.map { it.second }.toIntArray()
|
@ -1,6 +1,6 @@
|
||||
package scietifik.kmath.histogram
|
||||
|
||||
import scientifik.kmath.histogram.FastHistogram
|
||||
import scientifik.kmath.histogram.RealHistogram
|
||||
import scientifik.kmath.histogram.fill
|
||||
import scientifik.kmath.histogram.put
|
||||
import scientifik.kmath.linear.Vector
|
||||
@ -13,7 +13,7 @@ import kotlin.test.assertTrue
|
||||
class MultivariateHistogramTest {
|
||||
@Test
|
||||
fun testSinglePutHistogram() {
|
||||
val histogram = FastHistogram.fromRanges(
|
||||
val histogram = RealHistogram.fromRanges(
|
||||
(-1.0..1.0),
|
||||
(-1.0..1.0)
|
||||
)
|
||||
@ -26,7 +26,7 @@ class MultivariateHistogramTest {
|
||||
|
||||
@Test
|
||||
fun testSequentialPut() {
|
||||
val histogram = FastHistogram.fromRanges(
|
||||
val histogram = RealHistogram.fromRanges(
|
||||
(-1.0..1.0),
|
||||
(-1.0..1.0),
|
||||
(-1.0..1.0)
|
||||
|
@ -59,7 +59,7 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
|
||||
(get(value) ?: createBin(value)).inc()
|
||||
}
|
||||
|
||||
override fun put(point: Buffer<out Double>, weight: Double) {
|
||||
override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
|
||||
if (weight != 1.0) TODO("Implement weighting")
|
||||
put(point[0])
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user