Some unifications for buffers and NDStructures

This commit is contained in:
Alexander Nozik 2019-02-14 20:31:08 +03:00
parent 566883c521
commit cdcba85ada
15 changed files with 190 additions and 219 deletions

View File

@ -8,7 +8,7 @@ fun main(args: Array<String>) {
val n = 6000 val n = 6000
val structure = ndStructure(intArrayOf(n, n), DoubleBufferFactory) { 1.0 } val structure = NDStructure.build(intArrayOf(n, n), DoubleBufferFactory) { 1.0 }
structure.mapToBuffer { it + 1 } // warm-up structure.mapToBuffer { it + 1 } // warm-up

View File

@ -2,6 +2,7 @@ package scientifik.kmath.expressions
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.ExtendedFieldOperations
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import kotlin.properties.ReadOnlyProperty import kotlin.properties.ReadOnlyProperty
import kotlin.reflect.KProperty import kotlin.reflect.KProperty

View File

@ -5,7 +5,7 @@ import kotlin.math.*
/** /**
* A field for complex numbers * A field for complex numbers
*/ */
object ComplexField : ExtendedField<Complex> { object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
override val zero: Complex = Complex(0.0, 0.0) override val zero: Complex = Complex(0.0, 0.0)
override val one: Complex = Complex(1.0, 0.0) override val one: Complex = Complex(1.0, 0.0)

View File

@ -5,12 +5,13 @@ import kotlin.math.pow
/** /**
* Advanced Number-like field that implements basic operations * Advanced Number-like field that implements basic operations
*/ */
interface ExtendedField<T : Any> : interface ExtendedFieldOperations<T> :
Field<T>, FieldOperations<T>,
TrigonometricOperations<T>, TrigonometricOperations<T>,
PowerOperations<T>, PowerOperations<T>,
ExponentialOperations<T> ExponentialOperations<T>
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>
/** /**
* Real field element wrapping double. * Real field element wrapping double.
@ -31,7 +32,7 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
* A field for double without boxing. Does not produce appropriate field element * A field for double without boxing. Does not produce appropriate field element
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER") @Suppress("EXTENSION_SHADOWED_BY_MEMBER")
object RealField : ExtendedField<Double>, Norm<Double, Double> { object RealField : Field<Double>, ExtendedFieldOperations<Double>, Norm<Double, Double> {
override val zero: Double = 0.0 override val zero: Double = 0.0
override fun add(a: Double, b: Double): Double = a + b override fun add(a: Double, b: Double): Double = a + b
override fun multiply(a: Double, b: Double): Double = a * b override fun multiply(a: Double, b: Double): Double = a * b

View File

@ -10,7 +10,7 @@ package scientifik.kmath.operations
* It also allows to override behavior for optional operations * It also allows to override behavior for optional operations
* *
*/ */
interface TrigonometricOperations<T> : Field<T> { interface TrigonometricOperations<T> : FieldOperations<T> {
fun sin(arg: T): T fun sin(arg: T): T
fun cos(arg: T): T fun cos(arg: T): T

View File

@ -109,6 +109,9 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
fun <T> List<T>.asBuffer() = ListBuffer(this) fun <T> List<T>.asBuffer() = ListBuffer(this)
@Suppress("FunctionName")
inline fun <T> ListBuffer(size: Int, init: (Int) -> T) = List(size, init).asBuffer()
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> { inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
override val size: Int override val size: Int
@ -154,9 +157,12 @@ inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer<Double> {
override fun iterator(): Iterator<Double> = array.iterator() override fun iterator(): Iterator<Double> = array.iterator()
override fun copy(): MutableBuffer<Double> = DoubleBuffer(array.copyOf()) override fun copy(): MutableBuffer<Double> = DoubleBuffer(array.copyOf())
} }
@Suppress("FunctionName")
inline fun DoubleBuffer(size: Int, init: (Int) -> Double) = DoubleBuffer(DoubleArray(size) { init(it) })
/** /**
* Transform buffer of doubles into array for high performance operations * Transform buffer of doubles into array for high performance operations
*/ */

View File

@ -1,16 +1,14 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import scientifik.kmath.operations.ExponentialOperations import scientifik.kmath.operations.*
import scientifik.kmath.operations.ExtendedField
import scientifik.kmath.operations.PowerOperations
import scientifik.kmath.operations.TrigonometricOperations
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> : interface ExtendedNDField<T : Any, F, N : NDStructure<T>> :
NDField<T, F, N>, NDField<T, F, N>,
TrigonometricOperations<N>, TrigonometricOperations<N>,
PowerOperations<N>, PowerOperations<N>,
ExponentialOperations<N> ExponentialOperations<N>
where F : ExtendedFieldOperations<T>, F : Field<T>
///** ///**

View File

@ -22,6 +22,33 @@ interface NDStructure<T> {
else -> st1.elements().all { (index, value) -> value == st2[index] } else -> st1.elements().all { (index, value) -> value == st2[index] }
} }
} }
/**
* Create a NDStructure with explicit buffer factory
*
* Strides should be reused if possible
*/
fun <T> build(
strides: Strides,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T
) =
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
/**
* Inline create NDStructure with non-boxing buffer implementation if it is possible
*/
inline fun <reified T : Any> auto(strides: Strides, crossinline initializer: (IntArray) -> T) =
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
fun <T> build(
shape: IntArray,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T
) = build(DefaultStrides(shape), bufferFactory, initializer)
inline fun <reified T : Any> auto(shape: IntArray, crossinline initializer: (IntArray) -> T) =
auto(DefaultStrides(shape), initializer)
} }
} }
@ -200,34 +227,6 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
} }
} }
/**
* Create a NDStructure with explicit buffer factory
*
* Strides should be reused if possible
*/
fun <T> ndStructure(
strides: Strides,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T
) =
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
/**
* Inline create NDStructure with non-boxing buffer implementation if it is possible
*/
inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
fun <T> ndStructure(
shape: IntArray,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
initializer: (IntArray) -> T
) =
ndStructure(DefaultStrides(shape), bufferFactory, initializer)
inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
inlineNDStructure(DefaultStrides(shape), initializer)
/** /**
* Mutable ND buffer based on linear [autoBuffer] * Mutable ND buffer based on linear [autoBuffer]
*/ */
@ -245,33 +244,10 @@ class MutableBufferNDStructure<T>(
override fun set(index: IntArray, value: T) = buffer.set(strides.offset(index), value) override fun set(index: IntArray, value: T) = buffer.set(strides.offset(index), value)
} }
/**
* The same as [inlineNDStructure], but mutable
*/
fun <T : Any> mutableNdStructure(
strides: Strides,
bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::boxing,
initializer: (IntArray) -> T
) =
MutableBufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
inline fun <reified T : Any> inlineMutableNdStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
MutableBufferNDStructure(strides, MutableBuffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
fun <T : Any> mutableNdStructure(
shape: IntArray,
bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::boxing,
initializer: (IntArray) -> T
) =
mutableNdStructure(DefaultStrides(shape), bufferFactory, initializer)
inline fun <reified T : Any> inlineMutableNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
inlineMutableNdStructure(DefaultStrides(shape), initializer)
inline fun <reified T : Any> NDStructure<T>.combine( inline fun <reified T : Any> NDStructure<T>.combine(
struct: NDStructure<T>, struct: NDStructure<T>,
crossinline block: (T, T) -> T crossinline block: (T, T) -> T
): NDStructure<T> { ): NDStructure<T> {
if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination") if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination")
return inlineNdStructure(shape) { block(this[it], struct[it]) } return NDStructure.auto(shape) { block(this[it], struct[it]) }
} }

View File

@ -1,110 +1,153 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import scientifik.kmath.operations.ExtendedField import scientifik.kmath.operations.ExtendedFieldOperations
import scientifik.kmath.operations.Field
import kotlin.math.* import kotlin.math.*
/** /**
* A simple field over linear buffers of [Double] * A simple field over linear buffers of [Double]
*/ */
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> { object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
override val zero: Buffer<Double> = Buffer.DoubleBufferFactory(size) { 0.0 }
override val one: Buffer<Double> = Buffer.DoubleBufferFactory(size) { 1.0 }
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " } require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
require(b.size == size) { "The size of buffer is ${b.size} but context requires $size " }
return if (a is DoubleBuffer && b is DoubleBuffer) { return if (a is DoubleBuffer && b is DoubleBuffer) {
val aArray = a.array val aArray = a.array
val bArray = b.array val bArray = b.array
DoubleBuffer(DoubleArray(size) { aArray[it] + bArray[it] }) DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
} else { } else {
DoubleBuffer(DoubleArray(size) { a[it] + b[it] }) DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] })
} }
} }
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer { override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer {
require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " }
val kValue = k.toDouble() val kValue = k.toDouble()
return if (a is DoubleBuffer) { return if (a is DoubleBuffer) {
val aArray = a.array val aArray = a.array
DoubleBuffer(DoubleArray(size) { aArray[it] * kValue }) DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue })
} else { } else {
DoubleBuffer(DoubleArray(size) { a[it] * kValue }) DoubleBuffer(DoubleArray(a.size) { a[it] * kValue })
} }
} }
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " } require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
require(b.size == size) { "The size of buffer is ${b.size} but context requires $size " }
return if (a is DoubleBuffer && b is DoubleBuffer) { return if (a is DoubleBuffer && b is DoubleBuffer) {
val aArray = a.array val aArray = a.array
val bArray = b.array val bArray = b.array
DoubleBuffer(DoubleArray(size) { aArray[it] * bArray[it] }) DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] })
} else { } else {
DoubleBuffer(DoubleArray(size) { a[it] * b[it] }) DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] })
} }
} }
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " } require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
require(b.size == size) { "The size of buffer is ${b.size} but context requires $size " }
return if (a is DoubleBuffer && b is DoubleBuffer) { return if (a is DoubleBuffer && b is DoubleBuffer) {
val aArray = a.array val aArray = a.array
val bArray = b.array val bArray = b.array
DoubleBuffer(DoubleArray(size) { aArray[it] / bArray[it] }) DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
} else { } else {
DoubleBuffer(DoubleArray(size) { a[it] / b[it] }) DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] })
} }
} }
override fun sin(arg: Buffer<Double>): Buffer<Double> { override fun sin(arg: Buffer<Double>): DoubleBuffer {
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
return if (arg is DoubleBuffer) { return if (arg is DoubleBuffer) {
val array = arg.array val array = arg.array
DoubleBuffer(DoubleArray(size) { sin(array[it]) }) DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) })
} else { } else {
DoubleBuffer(DoubleArray(size) { sin(arg[it]) }) DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) })
} }
} }
override fun cos(arg: Buffer<Double>): Buffer<Double> { override fun cos(arg: Buffer<Double>): DoubleBuffer {
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
return if (arg is DoubleBuffer) { return if (arg is DoubleBuffer) {
val array = arg.array val array = arg.array
DoubleBuffer(DoubleArray(size) { cos(array[it]) }) DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) })
} else { } else {
DoubleBuffer(DoubleArray(size) { cos(arg[it]) }) DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) })
} }
} }
override fun power(arg: Buffer<Double>, pow: Number): Buffer<Double> { override fun power(arg: Buffer<Double>, pow: Number): DoubleBuffer {
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
return if (arg is DoubleBuffer) { return if (arg is DoubleBuffer) {
val array = arg.array val array = arg.array
DoubleBuffer(DoubleArray(size) { array[it].pow(pow.toDouble()) }) DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
} else { } else {
DoubleBuffer(DoubleArray(size) { arg[it].pow(pow.toDouble()) }) DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
} }
} }
override fun exp(arg: Buffer<Double>): Buffer<Double> { override fun exp(arg: Buffer<Double>): DoubleBuffer {
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
return if (arg is DoubleBuffer) { return if (arg is DoubleBuffer) {
val array = arg.array val array = arg.array
DoubleBuffer(DoubleArray(size) { exp(array[it]) }) DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) })
} else { } else {
DoubleBuffer(DoubleArray(size) { exp(arg[it]) }) DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) })
} }
} }
override fun ln(arg: Buffer<Double>): Buffer<Double> { override fun ln(arg: Buffer<Double>): DoubleBuffer {
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
return if (arg is DoubleBuffer) { return if (arg is DoubleBuffer) {
val array = arg.array val array = arg.array
DoubleBuffer(DoubleArray(size) { ln(array[it]) }) DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) })
} else { } else {
DoubleBuffer(DoubleArray(size) { ln(arg[it]) }) DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) })
} }
} }
}
class RealBufferField(val size: Int) : Field<Buffer<Double>>, ExtendedFieldOperations<Buffer<Double>> {
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } }
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
return RealBufferFieldOperations.add(a, b)
}
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer {
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
return RealBufferFieldOperations.multiply(a, k)
}
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
return RealBufferFieldOperations.multiply(a, b)
}
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
return RealBufferFieldOperations.divide(a, b)
}
override fun sin(arg: Buffer<Double>): DoubleBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.sin(arg)
}
override fun cos(arg: Buffer<Double>): DoubleBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.cos(arg)
}
override fun power(arg: Buffer<Double>, pow: Number): DoubleBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.power(arg, pow)
}
override fun exp(arg: Buffer<Double>): DoubleBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.exp(arg)
}
override fun ln(arg: Buffer<Double>): DoubleBuffer {
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
return RealBufferFieldOperations.ln(arg)
}
} }

View File

@ -2,9 +2,9 @@ package scientifik.kmath.histogram
/* /*
* Common representation for atomic counters * Common representation for atomic counters
* TODO replace with atomics
*/ */
expect class LongCounter() { expect class LongCounter() {
fun decrement() fun decrement()
fun increment() fun increment()

View File

@ -2,12 +2,8 @@ package scientifik.kmath.histogram
import scientifik.kmath.linear.Point import scientifik.kmath.linear.Point
import scientifik.kmath.structures.ArrayBuffer import scientifik.kmath.structures.ArrayBuffer
import scientifik.kmath.structures.Buffer
import scientifik.kmath.structures.DoubleBuffer import scientifik.kmath.structures.DoubleBuffer
typealias RealPoint = Buffer<Double>
/** /**
* A simple geometric domain * A simple geometric domain
* TODO move to geometry module * TODO move to geometry module
@ -42,13 +38,14 @@ interface Histogram<T : Any, out B : Bin<T>> : Iterable<B> {
} }
interface MutableHistogram<T : Any, out B : Bin<T>> : interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
Histogram<T, B> {
/** /**
* Increment appropriate bin * Increment appropriate bin
*/ */
fun put(point: Point<out T>, weight: Double = 1.0) fun putWithWeight(point: Point<out T>, weight: Double)
fun put(point: Point<out T>) = putWithWeight(point, 1.0)
} }
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point)) fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))

View File

@ -1,64 +0,0 @@
package scientifik.kmath.histogram
import scientifik.kmath.linear.Point
import scientifik.kmath.linear.Vector
import scientifik.kmath.operations.Space
import scientifik.kmath.structures.NDStructure
import scientifik.kmath.structures.asSequence
data class BinTemplate<T : Comparable<T>>(val center: Vector<T, *>, val sizes: Point<T>) {
fun contains(vector: Point<out T>): Boolean {
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
val upper = center.context.run { center + sizes / 2.0 }
val lower = center.context.run { center - sizes / 2.0 }
return vector.asSequence().mapIndexed { i, value ->
value in lower[i]..upper[i]
}.all { it }
}
}
/**
* A space to perform arithmetic operations on histograms
*/
interface HistogramSpace<T : Any, B : Bin<T>, H : Histogram<T, B>> : Space<H> {
/**
* Rules for performing operations on bins
*/
val binSpace: Space<Bin<T>>
}
class PhantomBin<T : Comparable<T>>(val template: BinTemplate<T>, override val value: Number) :
Bin<T> {
override fun contains(vector: Point<out T>): Boolean = template.contains(vector)
override val dimension: Int
get() = template.center.size
override val center: Point<T>
get() = template.center
}
/**
* Immutable histogram with explicit structure for content and additional external bin description.
* Bin search is slow, but full histogram algebra is supported.
* @param bins transform a template into structure index
*/
class PhantomHistogram<T : Comparable<T>>(
val bins: Map<BinTemplate<T>, IntArray>,
val data: NDStructure<Number>
) : Histogram<T, PhantomBin<T>> {
override val dimension: Int
get() = data.dimension
override fun iterator(): Iterator<PhantomBin<T>> =
bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator()
override fun get(point: Point<out T>): PhantomBin<T>? {
val template = bins.keys.find { it.contains(point) }
return template?.let { PhantomBin(it, data[bins[it]!!]) }
}
}

View File

@ -1,45 +1,66 @@
package scientifik.kmath.histogram package scientifik.kmath.histogram
import scientifik.kmath.linear.Point
import scientifik.kmath.linear.toVector import scientifik.kmath.linear.toVector
import scientifik.kmath.operations.SpaceOperations
import scientifik.kmath.structures.* import scientifik.kmath.structures.*
import kotlin.math.floor import kotlin.math.floor
private operator fun RealPoint.minus(other: RealPoint) = ListBuffer((0 until size).map { get(it) - other[it] })
private inline fun <T> Buffer<out Double>.mapIndexed(crossinline mapper: (Int, Double) -> T): Sequence<T> = data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val center: Point<T>, val sizes: Point<T>) {
(0 until size).asSequence().map { mapper(it, get(it)) } fun contains(vector: Point<out T>): Boolean {
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
val upper = space.run { center + sizes / 2.0 }
val lower = space.run { center - sizes / 2.0 }
return vector.asSequence().mapIndexed { i, value ->
value in lower[i]..upper[i]
}.all { it }
}
}
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> {
override fun contains(vector: Point<out T>): Boolean = def.contains(vector)
override val dimension: Int
get() = def.center.size
override val center: Point<T>
get() = def.center
}
/** /**
* Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions. * Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions.
*/ */
class FastHistogram( class RealHistogram(
private val lower: RealPoint, private val lower: Buffer<Double>,
private val upper: RealPoint, private val upper: Buffer<Double>,
private val binNums: IntArray = IntArray(lower.size) { 20 } private val binNums: IntArray = IntArray(lower.size) { 20 }
) : MutableHistogram<Double, PhantomBin<Double>> { ) : MutableHistogram<Double, MultivariateBin<Double>> {
private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 }) private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 })
private val values: NDStructure<LongCounter> = inlineNDStructure(strides) { LongCounter() } private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
//private val weight: NDStructure<DoubleCounter?> = ndStructure(strides){null} //private val weight: NDStructure<DoubleCounter?> = ndStructure(strides){null}
//TODO optimize binSize performance if needed
private val binSize: RealPoint = override val dimension: Int get() = lower.size
ListBuffer((upper - lower).mapIndexed { index, value -> value / binNums[index] }.toList())
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
init { init {
// argument checks // argument checks
if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.") if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.")
if (lower.size != binNums.size) error("Dimension mismatch in bin count.") if (lower.size != binNums.size) error("Dimension mismatch in bin count.")
if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one of axis is not strictly positive") if ((0 until dimension).any { upper[it] - lower[it] < 0 }) error("Range for one of axis is not strictly positive")
} }
override val dimension: Int get() = lower.size
/** /**
* Get internal [NDStructure] bin index for given axis * Get internal [NDStructure] bin index for given axis
*/ */
@ -61,49 +82,41 @@ class FastHistogram(
return getValue(getIndex(point)) return getValue(getIndex(point))
} }
private fun getTemplate(index: IntArray): BinTemplate<Double> { private fun getDef(index: IntArray): BinDef<Double> {
val center = index.mapIndexed { axis, i -> val center = index.mapIndexed { axis, i ->
when (i) { when (i) {
0 -> Double.NEGATIVE_INFINITY 0 -> Double.NEGATIVE_INFINITY
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY
else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis] else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis]
} }
}.toVector() }.asBuffer()
return BinTemplate(center, binSize) return BinDef(RealBufferFieldOperations, center, binSize)
} }
fun getTemplate(point: Buffer<out Double>): BinTemplate<Double> { fun getDef(point: Buffer<out Double>): BinDef<Double> {
return getTemplate(getIndex(point)) return getDef(getIndex(point))
} }
override fun get(point: Buffer<out Double>): PhantomBin<Double>? { override fun get(point: Buffer<out Double>): MultivariateBin<Double>? {
val index = getIndex(point) val index = getIndex(point)
return PhantomBin(getTemplate(index), getValue(index)) return MultivariateBin(getDef(index), getValue(index))
} }
override fun put(point: Buffer<out Double>, weight: Double) { override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
if (weight != 1.0) TODO("Implement weighting") if (weight != 1.0) TODO("Implement weighting")
val index = getIndex(point) val index = getIndex(point)
values[index].increment() values[index].increment()
} }
override fun iterator(): Iterator<PhantomBin<Double>> = values.elements().map { (index, value) -> override fun iterator(): Iterator<MultivariateBin<Double>> = values.elements().map { (index, value) ->
PhantomBin(getTemplate(index), value.sum()) MultivariateBin(getDef(index), value.sum())
}.iterator() }.iterator()
/** /**
* Convert this histogram into NDStructure containing bin values but not bin descriptions * Convert this histogram into NDStructure containing bin values but not bin descriptions
*/ */
fun asNDStructure(): NDStructure<Number> { fun asNDStructure(): NDStructure<Number> {
return inlineNdStructure(this.values.shape) { values[it].sum() } return NDStructure.auto(this.values.shape) { values[it].sum() }
}
/**
* Create a phantom lightweight immutable copy of this histogram
*/
fun asPhantomHistogram(): PhantomHistogram<Double> {
val binTemplates = values.elements().associate { (index, _) -> getTemplate(index) to index }
return PhantomHistogram(binTemplates, asNDStructure())
} }
companion object { companion object {
@ -117,8 +130,8 @@ class FastHistogram(
*) *)
*``` *```
*/ */
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): FastHistogram { fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): RealHistogram {
return FastHistogram( return RealHistogram(
ranges.map { it.start }.toVector(), ranges.map { it.start }.toVector(),
ranges.map { it.endInclusive }.toVector() ranges.map { it.endInclusive }.toVector()
) )
@ -133,8 +146,8 @@ class FastHistogram(
*) *)
*``` *```
*/ */
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram { fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): RealHistogram {
return FastHistogram( return RealHistogram(
ListBuffer(ranges.map { it.first.start }), ListBuffer(ranges.map { it.first.start }),
ListBuffer(ranges.map { it.first.endInclusive }), ListBuffer(ranges.map { it.first.endInclusive }),
ranges.map { it.second }.toIntArray() ranges.map { it.second }.toIntArray()

View File

@ -1,6 +1,6 @@
package scietifik.kmath.histogram package scietifik.kmath.histogram
import scientifik.kmath.histogram.FastHistogram import scientifik.kmath.histogram.RealHistogram
import scientifik.kmath.histogram.fill import scientifik.kmath.histogram.fill
import scientifik.kmath.histogram.put import scientifik.kmath.histogram.put
import scientifik.kmath.linear.Vector import scientifik.kmath.linear.Vector
@ -13,7 +13,7 @@ import kotlin.test.assertTrue
class MultivariateHistogramTest { class MultivariateHistogramTest {
@Test @Test
fun testSinglePutHistogram() { fun testSinglePutHistogram() {
val histogram = FastHistogram.fromRanges( val histogram = RealHistogram.fromRanges(
(-1.0..1.0), (-1.0..1.0),
(-1.0..1.0) (-1.0..1.0)
) )
@ -26,7 +26,7 @@ class MultivariateHistogramTest {
@Test @Test
fun testSequentialPut() { fun testSequentialPut() {
val histogram = FastHistogram.fromRanges( val histogram = RealHistogram.fromRanges(
(-1.0..1.0), (-1.0..1.0),
(-1.0..1.0), (-1.0..1.0),
(-1.0..1.0) (-1.0..1.0)

View File

@ -59,7 +59,7 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
(get(value) ?: createBin(value)).inc() (get(value) ?: createBin(value)).inc()
} }
override fun put(point: Buffer<out Double>, weight: Double) { override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
if (weight != 1.0) TODO("Implement weighting") if (weight != 1.0) TODO("Implement weighting")
put(point[0]) put(point[0])
} }