Some unifications for buffers and NDStructures
This commit is contained in:
parent
566883c521
commit
cdcba85ada
@ -8,7 +8,7 @@ fun main(args: Array<String>) {
|
|||||||
|
|
||||||
val n = 6000
|
val n = 6000
|
||||||
|
|
||||||
val structure = ndStructure(intArrayOf(n, n), DoubleBufferFactory) { 1.0 }
|
val structure = NDStructure.build(intArrayOf(n, n), DoubleBufferFactory) { 1.0 }
|
||||||
|
|
||||||
structure.mapToBuffer { it + 1 } // warm-up
|
structure.mapToBuffer { it + 1 } // warm-up
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package scientifik.kmath.expressions
|
|||||||
|
|
||||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||||
import scientifik.kmath.operations.ExtendedField
|
import scientifik.kmath.operations.ExtendedField
|
||||||
|
import scientifik.kmath.operations.ExtendedFieldOperations
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import kotlin.properties.ReadOnlyProperty
|
import kotlin.properties.ReadOnlyProperty
|
||||||
import kotlin.reflect.KProperty
|
import kotlin.reflect.KProperty
|
||||||
|
@ -5,7 +5,7 @@ import kotlin.math.*
|
|||||||
/**
|
/**
|
||||||
* A field for complex numbers
|
* A field for complex numbers
|
||||||
*/
|
*/
|
||||||
object ComplexField : ExtendedField<Complex> {
|
object ComplexField : ExtendedFieldOperations<Complex>, Field<Complex> {
|
||||||
override val zero: Complex = Complex(0.0, 0.0)
|
override val zero: Complex = Complex(0.0, 0.0)
|
||||||
|
|
||||||
override val one: Complex = Complex(1.0, 0.0)
|
override val one: Complex = Complex(1.0, 0.0)
|
||||||
|
@ -5,12 +5,13 @@ import kotlin.math.pow
|
|||||||
/**
|
/**
|
||||||
* Advanced Number-like field that implements basic operations
|
* Advanced Number-like field that implements basic operations
|
||||||
*/
|
*/
|
||||||
interface ExtendedField<T : Any> :
|
interface ExtendedFieldOperations<T> :
|
||||||
Field<T>,
|
FieldOperations<T>,
|
||||||
TrigonometricOperations<T>,
|
TrigonometricOperations<T>,
|
||||||
PowerOperations<T>,
|
PowerOperations<T>,
|
||||||
ExponentialOperations<T>
|
ExponentialOperations<T>
|
||||||
|
|
||||||
|
interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Real field element wrapping double.
|
* Real field element wrapping double.
|
||||||
@ -31,7 +32,7 @@ inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
|
|||||||
* A field for double without boxing. Does not produce appropriate field element
|
* A field for double without boxing. Does not produce appropriate field element
|
||||||
*/
|
*/
|
||||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
||||||
object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
object RealField : Field<Double>, ExtendedFieldOperations<Double>, Norm<Double, Double> {
|
||||||
override val zero: Double = 0.0
|
override val zero: Double = 0.0
|
||||||
override fun add(a: Double, b: Double): Double = a + b
|
override fun add(a: Double, b: Double): Double = a + b
|
||||||
override fun multiply(a: Double, b: Double): Double = a * b
|
override fun multiply(a: Double, b: Double): Double = a * b
|
||||||
|
@ -10,7 +10,7 @@ package scientifik.kmath.operations
|
|||||||
* It also allows to override behavior for optional operations
|
* It also allows to override behavior for optional operations
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
interface TrigonometricOperations<T> : Field<T> {
|
interface TrigonometricOperations<T> : FieldOperations<T> {
|
||||||
fun sin(arg: T): T
|
fun sin(arg: T): T
|
||||||
fun cos(arg: T): T
|
fun cos(arg: T): T
|
||||||
|
|
||||||
|
@ -109,6 +109,9 @@ inline class ListBuffer<T>(val list: List<T>) : Buffer<T> {
|
|||||||
|
|
||||||
fun <T> List<T>.asBuffer() = ListBuffer(this)
|
fun <T> List<T>.asBuffer() = ListBuffer(this)
|
||||||
|
|
||||||
|
@Suppress("FunctionName")
|
||||||
|
inline fun <T> ListBuffer(size: Int, init: (Int) -> T) = List(size, init).asBuffer()
|
||||||
|
|
||||||
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
|
inline class MutableListBuffer<T>(val list: MutableList<T>) : MutableBuffer<T> {
|
||||||
|
|
||||||
override val size: Int
|
override val size: Int
|
||||||
@ -154,9 +157,12 @@ inline class DoubleBuffer(val array: DoubleArray) : MutableBuffer<Double> {
|
|||||||
override fun iterator(): Iterator<Double> = array.iterator()
|
override fun iterator(): Iterator<Double> = array.iterator()
|
||||||
|
|
||||||
override fun copy(): MutableBuffer<Double> = DoubleBuffer(array.copyOf())
|
override fun copy(): MutableBuffer<Double> = DoubleBuffer(array.copyOf())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Suppress("FunctionName")
|
||||||
|
inline fun DoubleBuffer(size: Int, init: (Int) -> Double) = DoubleBuffer(DoubleArray(size) { init(it) })
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Transform buffer of doubles into array for high performance operations
|
* Transform buffer of doubles into array for high performance operations
|
||||||
*/
|
*/
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import scientifik.kmath.operations.ExponentialOperations
|
import scientifik.kmath.operations.*
|
||||||
import scientifik.kmath.operations.ExtendedField
|
|
||||||
import scientifik.kmath.operations.PowerOperations
|
|
||||||
import scientifik.kmath.operations.TrigonometricOperations
|
|
||||||
|
|
||||||
|
|
||||||
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> :
|
interface ExtendedNDField<T : Any, F, N : NDStructure<T>> :
|
||||||
NDField<T, F, N>,
|
NDField<T, F, N>,
|
||||||
TrigonometricOperations<N>,
|
TrigonometricOperations<N>,
|
||||||
PowerOperations<N>,
|
PowerOperations<N>,
|
||||||
ExponentialOperations<N>
|
ExponentialOperations<N>
|
||||||
|
where F : ExtendedFieldOperations<T>, F : Field<T>
|
||||||
|
|
||||||
|
|
||||||
///**
|
///**
|
||||||
|
@ -22,6 +22,33 @@ interface NDStructure<T> {
|
|||||||
else -> st1.elements().all { (index, value) -> value == st2[index] }
|
else -> st1.elements().all { (index, value) -> value == st2[index] }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a NDStructure with explicit buffer factory
|
||||||
|
*
|
||||||
|
* Strides should be reused if possible
|
||||||
|
*/
|
||||||
|
fun <T> build(
|
||||||
|
strides: Strides,
|
||||||
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
|
initializer: (IntArray) -> T
|
||||||
|
) =
|
||||||
|
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Inline create NDStructure with non-boxing buffer implementation if it is possible
|
||||||
|
*/
|
||||||
|
inline fun <reified T : Any> auto(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||||
|
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||||
|
|
||||||
|
fun <T> build(
|
||||||
|
shape: IntArray,
|
||||||
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
|
initializer: (IntArray) -> T
|
||||||
|
) = build(DefaultStrides(shape), bufferFactory, initializer)
|
||||||
|
|
||||||
|
inline fun <reified T : Any> auto(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
||||||
|
auto(DefaultStrides(shape), initializer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -200,34 +227,6 @@ inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a NDStructure with explicit buffer factory
|
|
||||||
*
|
|
||||||
* Strides should be reused if possible
|
|
||||||
*/
|
|
||||||
fun <T> ndStructure(
|
|
||||||
strides: Strides,
|
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
|
||||||
initializer: (IntArray) -> T
|
|
||||||
) =
|
|
||||||
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Inline create NDStructure with non-boxing buffer implementation if it is possible
|
|
||||||
*/
|
|
||||||
inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
|
||||||
BufferNDStructure(strides, Buffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
|
||||||
|
|
||||||
fun <T> ndStructure(
|
|
||||||
shape: IntArray,
|
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
|
||||||
initializer: (IntArray) -> T
|
|
||||||
) =
|
|
||||||
ndStructure(DefaultStrides(shape), bufferFactory, initializer)
|
|
||||||
|
|
||||||
inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
|
||||||
inlineNDStructure(DefaultStrides(shape), initializer)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mutable ND buffer based on linear [autoBuffer]
|
* Mutable ND buffer based on linear [autoBuffer]
|
||||||
*/
|
*/
|
||||||
@ -245,33 +244,10 @@ class MutableBufferNDStructure<T>(
|
|||||||
override fun set(index: IntArray, value: T) = buffer.set(strides.offset(index), value)
|
override fun set(index: IntArray, value: T) = buffer.set(strides.offset(index), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* The same as [inlineNDStructure], but mutable
|
|
||||||
*/
|
|
||||||
fun <T : Any> mutableNdStructure(
|
|
||||||
strides: Strides,
|
|
||||||
bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::boxing,
|
|
||||||
initializer: (IntArray) -> T
|
|
||||||
) =
|
|
||||||
MutableBufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
|
||||||
|
|
||||||
inline fun <reified T : Any> inlineMutableNdStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
|
||||||
MutableBufferNDStructure(strides, MutableBuffer.auto(strides.linearSize) { i -> initializer(strides.index(i)) })
|
|
||||||
|
|
||||||
fun <T : Any> mutableNdStructure(
|
|
||||||
shape: IntArray,
|
|
||||||
bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::boxing,
|
|
||||||
initializer: (IntArray) -> T
|
|
||||||
) =
|
|
||||||
mutableNdStructure(DefaultStrides(shape), bufferFactory, initializer)
|
|
||||||
|
|
||||||
inline fun <reified T : Any> inlineMutableNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
|
||||||
inlineMutableNdStructure(DefaultStrides(shape), initializer)
|
|
||||||
|
|
||||||
inline fun <reified T : Any> NDStructure<T>.combine(
|
inline fun <reified T : Any> NDStructure<T>.combine(
|
||||||
struct: NDStructure<T>,
|
struct: NDStructure<T>,
|
||||||
crossinline block: (T, T) -> T
|
crossinline block: (T, T) -> T
|
||||||
): NDStructure<T> {
|
): NDStructure<T> {
|
||||||
if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination")
|
if (!this.shape.contentEquals(struct.shape)) error("Shape mismatch in structure combination")
|
||||||
return inlineNdStructure(shape) { block(this[it], struct[it]) }
|
return NDStructure.auto(shape) { block(this[it], struct[it]) }
|
||||||
}
|
}
|
@ -1,110 +1,153 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import scientifik.kmath.operations.ExtendedField
|
import scientifik.kmath.operations.ExtendedFieldOperations
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A simple field over linear buffers of [Double]
|
* A simple field over linear buffers of [Double]
|
||||||
*/
|
*/
|
||||||
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
|
object RealBufferFieldOperations : ExtendedFieldOperations<Buffer<Double>> {
|
||||||
override val zero: Buffer<Double> = Buffer.DoubleBufferFactory(size) { 0.0 }
|
|
||||||
|
|
||||||
override val one: Buffer<Double> = Buffer.DoubleBufferFactory(size) { 1.0 }
|
|
||||||
|
|
||||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||||
require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " }
|
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
||||||
require(b.size == size) { "The size of buffer is ${b.size} but context requires $size " }
|
|
||||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
||||||
val aArray = a.array
|
val aArray = a.array
|
||||||
val bArray = b.array
|
val bArray = b.array
|
||||||
DoubleBuffer(DoubleArray(size) { aArray[it] + bArray[it] })
|
DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
|
||||||
} else {
|
} else {
|
||||||
DoubleBuffer(DoubleArray(size) { a[it] + b[it] })
|
DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer {
|
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer {
|
||||||
require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " }
|
|
||||||
val kValue = k.toDouble()
|
val kValue = k.toDouble()
|
||||||
return if (a is DoubleBuffer) {
|
return if (a is DoubleBuffer) {
|
||||||
val aArray = a.array
|
val aArray = a.array
|
||||||
DoubleBuffer(DoubleArray(size) { aArray[it] * kValue })
|
DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue })
|
||||||
} else {
|
} else {
|
||||||
DoubleBuffer(DoubleArray(size) { a[it] * kValue })
|
DoubleBuffer(DoubleArray(a.size) { a[it] * kValue })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||||
require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " }
|
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
||||||
require(b.size == size) { "The size of buffer is ${b.size} but context requires $size " }
|
|
||||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
||||||
val aArray = a.array
|
val aArray = a.array
|
||||||
val bArray = b.array
|
val bArray = b.array
|
||||||
DoubleBuffer(DoubleArray(size) { aArray[it] * bArray[it] })
|
DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] })
|
||||||
} else {
|
} else {
|
||||||
DoubleBuffer(DoubleArray(size) { a[it] * b[it] })
|
DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||||
require(a.size == size) { "The size of buffer is ${a.size} but context requires $size " }
|
require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " }
|
||||||
require(b.size == size) { "The size of buffer is ${b.size} but context requires $size " }
|
|
||||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
||||||
val aArray = a.array
|
val aArray = a.array
|
||||||
val bArray = b.array
|
val bArray = b.array
|
||||||
DoubleBuffer(DoubleArray(size) { aArray[it] / bArray[it] })
|
DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
|
||||||
} else {
|
} else {
|
||||||
DoubleBuffer(DoubleArray(size) { a[it] / b[it] })
|
DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun sin(arg: Buffer<Double>): Buffer<Double> {
|
override fun sin(arg: Buffer<Double>): DoubleBuffer {
|
||||||
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
|
||||||
return if (arg is DoubleBuffer) {
|
return if (arg is DoubleBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
DoubleBuffer(DoubleArray(size) { sin(array[it]) })
|
DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) })
|
||||||
} else {
|
} else {
|
||||||
DoubleBuffer(DoubleArray(size) { sin(arg[it]) })
|
DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun cos(arg: Buffer<Double>): Buffer<Double> {
|
override fun cos(arg: Buffer<Double>): DoubleBuffer {
|
||||||
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
|
||||||
return if (arg is DoubleBuffer) {
|
return if (arg is DoubleBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
DoubleBuffer(DoubleArray(size) { cos(array[it]) })
|
DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) })
|
||||||
} else {
|
} else {
|
||||||
DoubleBuffer(DoubleArray(size) { cos(arg[it]) })
|
DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun power(arg: Buffer<Double>, pow: Number): Buffer<Double> {
|
override fun power(arg: Buffer<Double>, pow: Number): DoubleBuffer {
|
||||||
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
|
||||||
return if (arg is DoubleBuffer) {
|
return if (arg is DoubleBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
DoubleBuffer(DoubleArray(size) { array[it].pow(pow.toDouble()) })
|
DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) })
|
||||||
} else {
|
} else {
|
||||||
DoubleBuffer(DoubleArray(size) { arg[it].pow(pow.toDouble()) })
|
DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun exp(arg: Buffer<Double>): Buffer<Double> {
|
override fun exp(arg: Buffer<Double>): DoubleBuffer {
|
||||||
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
|
||||||
return if (arg is DoubleBuffer) {
|
return if (arg is DoubleBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
DoubleBuffer(DoubleArray(size) { exp(array[it]) })
|
DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) })
|
||||||
} else {
|
} else {
|
||||||
DoubleBuffer(DoubleArray(size) { exp(arg[it]) })
|
DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun ln(arg: Buffer<Double>): Buffer<Double> {
|
override fun ln(arg: Buffer<Double>): DoubleBuffer {
|
||||||
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
|
||||||
return if (arg is DoubleBuffer) {
|
return if (arg is DoubleBuffer) {
|
||||||
val array = arg.array
|
val array = arg.array
|
||||||
DoubleBuffer(DoubleArray(size) { ln(array[it]) })
|
DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) })
|
||||||
} else {
|
} else {
|
||||||
DoubleBuffer(DoubleArray(size) { ln(arg[it]) })
|
DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class RealBufferField(val size: Int) : Field<Buffer<Double>>, ExtendedFieldOperations<Buffer<Double>> {
|
||||||
|
|
||||||
|
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
|
||||||
|
|
||||||
|
override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } }
|
||||||
|
|
||||||
|
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||||
|
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.add(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: Buffer<Double>, k: Number): DoubleBuffer {
|
||||||
|
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.multiply(a, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||||
|
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.multiply(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||||
|
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.divide(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun sin(arg: Buffer<Double>): DoubleBuffer {
|
||||||
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.sin(arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cos(arg: Buffer<Double>): DoubleBuffer {
|
||||||
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.cos(arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun power(arg: Buffer<Double>, pow: Number): DoubleBuffer {
|
||||||
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.power(arg, pow)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun exp(arg: Buffer<Double>): DoubleBuffer {
|
||||||
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.exp(arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun ln(arg: Buffer<Double>): DoubleBuffer {
|
||||||
|
require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" }
|
||||||
|
return RealBufferFieldOperations.ln(arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -2,9 +2,9 @@ package scientifik.kmath.histogram
|
|||||||
|
|
||||||
/*
|
/*
|
||||||
* Common representation for atomic counters
|
* Common representation for atomic counters
|
||||||
|
* TODO replace with atomics
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
expect class LongCounter() {
|
expect class LongCounter() {
|
||||||
fun decrement()
|
fun decrement()
|
||||||
fun increment()
|
fun increment()
|
||||||
|
@ -2,12 +2,8 @@ package scientifik.kmath.histogram
|
|||||||
|
|
||||||
import scientifik.kmath.linear.Point
|
import scientifik.kmath.linear.Point
|
||||||
import scientifik.kmath.structures.ArrayBuffer
|
import scientifik.kmath.structures.ArrayBuffer
|
||||||
import scientifik.kmath.structures.Buffer
|
|
||||||
import scientifik.kmath.structures.DoubleBuffer
|
import scientifik.kmath.structures.DoubleBuffer
|
||||||
|
|
||||||
|
|
||||||
typealias RealPoint = Buffer<Double>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A simple geometric domain
|
* A simple geometric domain
|
||||||
* TODO move to geometry module
|
* TODO move to geometry module
|
||||||
@ -42,13 +38,14 @@ interface Histogram<T : Any, out B : Bin<T>> : Iterable<B> {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
interface MutableHistogram<T : Any, out B : Bin<T>> :
|
interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
|
||||||
Histogram<T, B> {
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Increment appropriate bin
|
* Increment appropriate bin
|
||||||
*/
|
*/
|
||||||
fun put(point: Point<out T>, weight: Double = 1.0)
|
fun putWithWeight(point: Point<out T>, weight: Double)
|
||||||
|
|
||||||
|
fun put(point: Point<out T>) = putWithWeight(point, 1.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))
|
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))
|
||||||
|
@ -1,64 +0,0 @@
|
|||||||
package scientifik.kmath.histogram
|
|
||||||
|
|
||||||
import scientifik.kmath.linear.Point
|
|
||||||
import scientifik.kmath.linear.Vector
|
|
||||||
import scientifik.kmath.operations.Space
|
|
||||||
import scientifik.kmath.structures.NDStructure
|
|
||||||
import scientifik.kmath.structures.asSequence
|
|
||||||
|
|
||||||
data class BinTemplate<T : Comparable<T>>(val center: Vector<T, *>, val sizes: Point<T>) {
|
|
||||||
fun contains(vector: Point<out T>): Boolean {
|
|
||||||
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
|
|
||||||
val upper = center.context.run { center + sizes / 2.0 }
|
|
||||||
val lower = center.context.run { center - sizes / 2.0 }
|
|
||||||
return vector.asSequence().mapIndexed { i, value ->
|
|
||||||
value in lower[i]..upper[i]
|
|
||||||
}.all { it }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A space to perform arithmetic operations on histograms
|
|
||||||
*/
|
|
||||||
interface HistogramSpace<T : Any, B : Bin<T>, H : Histogram<T, B>> : Space<H> {
|
|
||||||
/**
|
|
||||||
* Rules for performing operations on bins
|
|
||||||
*/
|
|
||||||
val binSpace: Space<Bin<T>>
|
|
||||||
}
|
|
||||||
|
|
||||||
class PhantomBin<T : Comparable<T>>(val template: BinTemplate<T>, override val value: Number) :
|
|
||||||
Bin<T> {
|
|
||||||
|
|
||||||
override fun contains(vector: Point<out T>): Boolean = template.contains(vector)
|
|
||||||
|
|
||||||
override val dimension: Int
|
|
||||||
get() = template.center.size
|
|
||||||
|
|
||||||
override val center: Point<T>
|
|
||||||
get() = template.center
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Immutable histogram with explicit structure for content and additional external bin description.
|
|
||||||
* Bin search is slow, but full histogram algebra is supported.
|
|
||||||
* @param bins transform a template into structure index
|
|
||||||
*/
|
|
||||||
class PhantomHistogram<T : Comparable<T>>(
|
|
||||||
val bins: Map<BinTemplate<T>, IntArray>,
|
|
||||||
val data: NDStructure<Number>
|
|
||||||
) : Histogram<T, PhantomBin<T>> {
|
|
||||||
|
|
||||||
override val dimension: Int
|
|
||||||
get() = data.dimension
|
|
||||||
|
|
||||||
override fun iterator(): Iterator<PhantomBin<T>> =
|
|
||||||
bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator()
|
|
||||||
|
|
||||||
override fun get(point: Point<out T>): PhantomBin<T>? {
|
|
||||||
val template = bins.keys.find { it.contains(point) }
|
|
||||||
return template?.let { PhantomBin(it, data[bins[it]!!]) }
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,45 +1,66 @@
|
|||||||
package scientifik.kmath.histogram
|
package scientifik.kmath.histogram
|
||||||
|
|
||||||
|
import scientifik.kmath.linear.Point
|
||||||
import scientifik.kmath.linear.toVector
|
import scientifik.kmath.linear.toVector
|
||||||
|
import scientifik.kmath.operations.SpaceOperations
|
||||||
import scientifik.kmath.structures.*
|
import scientifik.kmath.structures.*
|
||||||
import kotlin.math.floor
|
import kotlin.math.floor
|
||||||
|
|
||||||
private operator fun RealPoint.minus(other: RealPoint) = ListBuffer((0 until size).map { get(it) - other[it] })
|
|
||||||
|
|
||||||
private inline fun <T> Buffer<out Double>.mapIndexed(crossinline mapper: (Int, Double) -> T): Sequence<T> =
|
data class BinDef<T : Comparable<T>>(val space: SpaceOperations<Point<T>>, val center: Point<T>, val sizes: Point<T>) {
|
||||||
(0 until size).asSequence().map { mapper(it, get(it)) }
|
fun contains(vector: Point<out T>): Boolean {
|
||||||
|
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
|
||||||
|
val upper = space.run { center + sizes / 2.0 }
|
||||||
|
val lower = space.run { center - sizes / 2.0 }
|
||||||
|
return vector.asSequence().mapIndexed { i, value ->
|
||||||
|
value in lower[i]..upper[i]
|
||||||
|
}.all { it }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MultivariateBin<T : Comparable<T>>(val def: BinDef<T>, override val value: Number) : Bin<T> {
|
||||||
|
|
||||||
|
override fun contains(vector: Point<out T>): Boolean = def.contains(vector)
|
||||||
|
|
||||||
|
override val dimension: Int
|
||||||
|
get() = def.center.size
|
||||||
|
|
||||||
|
override val center: Point<T>
|
||||||
|
get() = def.center
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions.
|
* Uniform multivariate histogram with fixed borders. Based on NDStructure implementation with complexity of m for bin search, where m is the number of dimensions.
|
||||||
*/
|
*/
|
||||||
class FastHistogram(
|
class RealHistogram(
|
||||||
private val lower: RealPoint,
|
private val lower: Buffer<Double>,
|
||||||
private val upper: RealPoint,
|
private val upper: Buffer<Double>,
|
||||||
private val binNums: IntArray = IntArray(lower.size) { 20 }
|
private val binNums: IntArray = IntArray(lower.size) { 20 }
|
||||||
) : MutableHistogram<Double, PhantomBin<Double>> {
|
) : MutableHistogram<Double, MultivariateBin<Double>> {
|
||||||
|
|
||||||
|
|
||||||
private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 })
|
private val strides = DefaultStrides(IntArray(binNums.size) { binNums[it] + 2 })
|
||||||
|
|
||||||
private val values: NDStructure<LongCounter> = inlineNDStructure(strides) { LongCounter() }
|
private val values: NDStructure<LongCounter> = NDStructure.auto(strides) { LongCounter() }
|
||||||
|
|
||||||
//private val weight: NDStructure<DoubleCounter?> = ndStructure(strides){null}
|
//private val weight: NDStructure<DoubleCounter?> = ndStructure(strides){null}
|
||||||
|
|
||||||
//TODO optimize binSize performance if needed
|
|
||||||
private val binSize: RealPoint =
|
override val dimension: Int get() = lower.size
|
||||||
ListBuffer((upper - lower).mapIndexed { index, value -> value / binNums[index] }.toList())
|
|
||||||
|
|
||||||
|
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
|
||||||
|
|
||||||
init {
|
init {
|
||||||
// argument checks
|
// argument checks
|
||||||
if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.")
|
if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.")
|
||||||
if (lower.size != binNums.size) error("Dimension mismatch in bin count.")
|
if (lower.size != binNums.size) error("Dimension mismatch in bin count.")
|
||||||
if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one of axis is not strictly positive")
|
if ((0 until dimension).any { upper[it] - lower[it] < 0 }) error("Range for one of axis is not strictly positive")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override val dimension: Int get() = lower.size
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get internal [NDStructure] bin index for given axis
|
* Get internal [NDStructure] bin index for given axis
|
||||||
*/
|
*/
|
||||||
@ -61,49 +82,41 @@ class FastHistogram(
|
|||||||
return getValue(getIndex(point))
|
return getValue(getIndex(point))
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun getTemplate(index: IntArray): BinTemplate<Double> {
|
private fun getDef(index: IntArray): BinDef<Double> {
|
||||||
val center = index.mapIndexed { axis, i ->
|
val center = index.mapIndexed { axis, i ->
|
||||||
when (i) {
|
when (i) {
|
||||||
0 -> Double.NEGATIVE_INFINITY
|
0 -> Double.NEGATIVE_INFINITY
|
||||||
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY
|
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY
|
||||||
else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis]
|
else -> lower[axis] + (i.toDouble() - 0.5) * binSize[axis]
|
||||||
}
|
}
|
||||||
}.toVector()
|
}.asBuffer()
|
||||||
return BinTemplate(center, binSize)
|
return BinDef(RealBufferFieldOperations, center, binSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun getTemplate(point: Buffer<out Double>): BinTemplate<Double> {
|
fun getDef(point: Buffer<out Double>): BinDef<Double> {
|
||||||
return getTemplate(getIndex(point))
|
return getDef(getIndex(point))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun get(point: Buffer<out Double>): PhantomBin<Double>? {
|
override fun get(point: Buffer<out Double>): MultivariateBin<Double>? {
|
||||||
val index = getIndex(point)
|
val index = getIndex(point)
|
||||||
return PhantomBin(getTemplate(index), getValue(index))
|
return MultivariateBin(getDef(index), getValue(index))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun put(point: Buffer<out Double>, weight: Double) {
|
override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
|
||||||
if (weight != 1.0) TODO("Implement weighting")
|
if (weight != 1.0) TODO("Implement weighting")
|
||||||
val index = getIndex(point)
|
val index = getIndex(point)
|
||||||
values[index].increment()
|
values[index].increment()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<PhantomBin<Double>> = values.elements().map { (index, value) ->
|
override fun iterator(): Iterator<MultivariateBin<Double>> = values.elements().map { (index, value) ->
|
||||||
PhantomBin(getTemplate(index), value.sum())
|
MultivariateBin(getDef(index), value.sum())
|
||||||
}.iterator()
|
}.iterator()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert this histogram into NDStructure containing bin values but not bin descriptions
|
* Convert this histogram into NDStructure containing bin values but not bin descriptions
|
||||||
*/
|
*/
|
||||||
fun asNDStructure(): NDStructure<Number> {
|
fun asNDStructure(): NDStructure<Number> {
|
||||||
return inlineNdStructure(this.values.shape) { values[it].sum() }
|
return NDStructure.auto(this.values.shape) { values[it].sum() }
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a phantom lightweight immutable copy of this histogram
|
|
||||||
*/
|
|
||||||
fun asPhantomHistogram(): PhantomHistogram<Double> {
|
|
||||||
val binTemplates = values.elements().associate { (index, _) -> getTemplate(index) to index }
|
|
||||||
return PhantomHistogram(binTemplates, asNDStructure())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
@ -117,8 +130,8 @@ class FastHistogram(
|
|||||||
*)
|
*)
|
||||||
*```
|
*```
|
||||||
*/
|
*/
|
||||||
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): FastHistogram {
|
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): RealHistogram {
|
||||||
return FastHistogram(
|
return RealHistogram(
|
||||||
ranges.map { it.start }.toVector(),
|
ranges.map { it.start }.toVector(),
|
||||||
ranges.map { it.endInclusive }.toVector()
|
ranges.map { it.endInclusive }.toVector()
|
||||||
)
|
)
|
||||||
@ -133,8 +146,8 @@ class FastHistogram(
|
|||||||
*)
|
*)
|
||||||
*```
|
*```
|
||||||
*/
|
*/
|
||||||
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram {
|
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): RealHistogram {
|
||||||
return FastHistogram(
|
return RealHistogram(
|
||||||
ListBuffer(ranges.map { it.first.start }),
|
ListBuffer(ranges.map { it.first.start }),
|
||||||
ListBuffer(ranges.map { it.first.endInclusive }),
|
ListBuffer(ranges.map { it.first.endInclusive }),
|
||||||
ranges.map { it.second }.toIntArray()
|
ranges.map { it.second }.toIntArray()
|
@ -1,6 +1,6 @@
|
|||||||
package scietifik.kmath.histogram
|
package scietifik.kmath.histogram
|
||||||
|
|
||||||
import scientifik.kmath.histogram.FastHistogram
|
import scientifik.kmath.histogram.RealHistogram
|
||||||
import scientifik.kmath.histogram.fill
|
import scientifik.kmath.histogram.fill
|
||||||
import scientifik.kmath.histogram.put
|
import scientifik.kmath.histogram.put
|
||||||
import scientifik.kmath.linear.Vector
|
import scientifik.kmath.linear.Vector
|
||||||
@ -13,7 +13,7 @@ import kotlin.test.assertTrue
|
|||||||
class MultivariateHistogramTest {
|
class MultivariateHistogramTest {
|
||||||
@Test
|
@Test
|
||||||
fun testSinglePutHistogram() {
|
fun testSinglePutHistogram() {
|
||||||
val histogram = FastHistogram.fromRanges(
|
val histogram = RealHistogram.fromRanges(
|
||||||
(-1.0..1.0),
|
(-1.0..1.0),
|
||||||
(-1.0..1.0)
|
(-1.0..1.0)
|
||||||
)
|
)
|
||||||
@ -26,7 +26,7 @@ class MultivariateHistogramTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSequentialPut() {
|
fun testSequentialPut() {
|
||||||
val histogram = FastHistogram.fromRanges(
|
val histogram = RealHistogram.fromRanges(
|
||||||
(-1.0..1.0),
|
(-1.0..1.0),
|
||||||
(-1.0..1.0),
|
(-1.0..1.0),
|
||||||
(-1.0..1.0)
|
(-1.0..1.0)
|
||||||
|
@ -59,7 +59,7 @@ class UnivariateHistogram private constructor(private val factory: (Double) -> U
|
|||||||
(get(value) ?: createBin(value)).inc()
|
(get(value) ?: createBin(value)).inc()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun put(point: Buffer<out Double>, weight: Double) {
|
override fun putWithWeight(point: Buffer<out Double>, weight: Double) {
|
||||||
if (weight != 1.0) TODO("Implement weighting")
|
if (weight != 1.0) TODO("Implement weighting")
|
||||||
put(point[0])
|
put(point[0])
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user