Broaden scope of VectorSpace to include Points

This commit is contained in:
Alexander Nozik 2018-11-23 14:08:28 +03:00
parent 5cd6301f45
commit 8f5276bdcb
5 changed files with 151 additions and 123 deletions

View File

@ -1,6 +1,5 @@
package scientifik.kmath.histogram
import scientifik.kmath.operations.Space
import scientifik.kmath.structures.ArrayBuffer
import scientifik.kmath.structures.Buffer
@ -55,14 +54,4 @@ fun <T: Any> MutableHistogram<T,*>.fill(sequence: Iterable<Point<T>>) = sequence
/**
* Pass a sequence builder into histogram
*/
fun <T: Any> MutableHistogram<T, *>.fill(buider: suspend SequenceScope<Point<T>>.() -> Unit) = fill(sequence(buider).asIterable())
/**
* A space to perform arithmetic operations on histograms
*/
interface HistogramSpace<T: Any, B : Bin<T>, H : Histogram<T,B>> : Space<H> {
/**
* Rules for performing operations on bins
*/
val binSpace: Space<Bin<T>>
}
fun <T: Any> MutableHistogram<T, *>.fill(buider: suspend SequenceScope<Point<T>>.() -> Unit) = fill(sequence(buider).asIterable())

View File

@ -1,23 +1,38 @@
package scientifik.kmath.histogram
import scientifik.kmath.linear.RealVector
import scientifik.kmath.linear.Vector
import scientifik.kmath.operations.Space
import scientifik.kmath.structures.NDStructure
class BinTemplate(val center: RealVector, val sizes: RealVector) {
fun contains(vector: Point<out Double>): Boolean {
class BinTemplate<T : Comparable<T>>(val center: Vector<T>, val sizes: Vector<T>) {
fun contains(vector: Point<out T>): Boolean {
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
return vector.asSequence().mapIndexed { i, value -> value in (center[i] - sizes[i] / 2)..(center[i] + sizes[i] / 2) }.all { it }
val upper = center + sizes/2.0
val lower = center - sizes/2.0
return vector.asSequence().mapIndexed { i, value ->
value in lower[i]..upper[i]
}.all { it }
}
}
class PhantomBin(val template: BinTemplate, override val value: Number) : Bin<Double> {
/**
* A space to perform arithmetic operations on histograms
*/
interface HistogramSpace<T : Any, B : Bin<T>, H : Histogram<T, B>> : Space<H> {
/**
* Rules for performing operations on bins
*/
val binSpace: Space<Bin<T>>
}
override fun contains(vector: Point<out Double>): Boolean = template.contains(vector)
class PhantomBin<T : Comparable<T>>(val template: BinTemplate<T>, override val value: Number) : Bin<T> {
override fun contains(vector: Point<out T>): Boolean = template.contains(vector)
override val dimension: Int
get() = template.center.size
override val center: Point<Double>
override val center: Point<T>
get() = template.center
}
@ -26,19 +41,19 @@ class PhantomBin(val template: BinTemplate, override val value: Number) : Bin<Do
* Immutable histogram with explicit structure for content and additional external bin description.
* Bin search is slow, but full histogram algebra is supported.
*/
class PhantomHistogram(
val bins: Map<BinTemplate, IntArray>,
class PhantomHistogram<T : Comparable<T>>(
val bins: Map<BinTemplate<T>, IntArray>,
val data: NDStructure<Double>
) : Histogram<Double, PhantomBin> {
) : Histogram<T, PhantomBin<T>> {
override val dimension: Int
get() = data.dimension
override fun iterator(): Iterator<PhantomBin> {
return bins.asSequence().map {entry-> PhantomBin(entry.key,data[entry.value]) }.iterator()
override fun iterator(): Iterator<PhantomBin<T>> {
return bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator()
}
override fun get(point: Point<out Double>): PhantomBin? {
override fun get(point: Point<out T>): PhantomBin<T>? {
val template = bins.keys.find { it.contains(point) }
return template?.let { PhantomBin(it, data[bins[it]!!]) }
}

View File

@ -1,7 +1,9 @@
package scientifik.kmath.linear
import scientifik.kmath.operations.*
import scientifik.kmath.structures.*
import scientifik.kmath.structures.ExtendedNDField
import scientifik.kmath.structures.GenericNDField
import scientifik.kmath.structures.NDField
/**
* The space for linear elements. Supports scalar product alongside with standard linear operations.
@ -139,49 +141,7 @@ interface Matrix<T : Any> : SpaceElement<Matrix<T>, MatrixSpace<T>> {
}
/**
* A linear space for vectors
*/
abstract class VectorSpace<T : Any>(val size: Int, val field: Field<T>) : Space<Vector<T>> {
abstract fun produce(initializer: (Int) -> T): Vector<T>
override val zero: Vector<T> by lazy { produce { field.zero } }
override fun add(a: Vector<T>, b: Vector<T>): Vector<T> = produce { with(field) { a[it] + b[it] } }
override fun multiply(a: Vector<T>, k: Double): Vector<T> = produce { with(field) { a[it] * k } }
}
interface Vector<T : Any> : SpaceElement<Vector<T>, VectorSpace<T>>, Buffer<T>, Iterable<T> {
override val size: Int get() = context.size
companion object {
/**
* Create vector with custom field
*/
fun <T : Any> of(size: Int, field: Field<T>, initializer: (Int) -> T) =
ArrayVector(ArrayVectorSpace(size, field), initializer)
/**
* Create vector of [Double]
*/
fun ofReal(size: Int, initializer: (Int) -> Double) =
ArrayVector(ArrayVectorSpace(size, DoubleField, realNDFieldFactory), initializer)
fun ofReal(vararg point: Double) = point.toVector()
fun equals(v1: Vector<*>, v2: Vector<*>): Boolean {
if (v1 === v2) return true
if (v1.context != v2.context) return false
for (i in 0 until v2.size) {
if (v1[i] != v2[i]) return false
}
return true
}
}
}
typealias NDFieldFactory<T> = (IntArray) -> NDField<T>
@ -210,61 +170,7 @@ class ArrayMatrixSpace<T : Any>(
}
}
class ArrayVectorSpace<T : Any>(
size: Int,
field: Field<T>,
val ndFactory: NDFieldFactory<T> = genericNDFieldFactory(field)
) : VectorSpace<T>(size, field) {
val ndField by lazy {
ndFactory(intArrayOf(size))
}
override fun produce(initializer: (Int) -> T): Vector<T> = ArrayVector(this, initializer)
}
/**
* Member of [ArrayMatrixSpace] which wraps 2-D array
*/
class ArrayMatrix<T : Any> internal constructor(override val context: ArrayMatrixSpace<T>, val element: NDElement<T>) : Matrix<T> {
constructor(context: ArrayMatrixSpace<T>, initializer: (Int, Int) -> T) : this(context, context.ndField.produce { list -> initializer(list[0], list[1]) })
override val rows: Int get() = context.rows
override val columns: Int get() = context.columns
override fun get(i: Int, j: Int): T {
return element[i, j]
}
override val self: ArrayMatrix<T> get() = this
}
class ArrayVector<T : Any> internal constructor(override val context: ArrayVectorSpace<T>, val element: NDElement<T>) : Vector<T> {
constructor(context: ArrayVectorSpace<T>, initializer: (Int) -> T) : this(context, context.ndField.produce { list -> initializer(list[0]) })
init {
if (context.size != element.shape[0]) {
error("Array dimension mismatch")
}
}
override fun get(index: Int): T {
return element[index]
}
override val self: ArrayVector<T> get() = this
override fun iterator(): Iterator<T> = (0 until size).map { element[it] }.iterator()
override fun copy(): ArrayVector<T> = ArrayVector(context, element)
override fun toString(): String = this.joinToString(prefix = "[", postfix = "]", separator = ", ") { it.toString() }
}
typealias RealVector = Vector<Double>
/**
* A group of methods to resolve equation A dot X = B, where A and B are matrices or vectors

View File

@ -0,0 +1,118 @@
package scientifik.kmath.linear
import scientifik.kmath.histogram.Point
import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Space
import scientifik.kmath.operations.SpaceElement
import scientifik.kmath.structures.NDElement
import scientifik.kmath.structures.get
/**
* A linear space for vectors.
* Could be used on any point-like structure
*/
abstract class VectorSpace<T : Any>(val size: Int, val field: Field<T>) : Space<Point<T>> {
abstract fun produce(initializer: (Int) -> T): Vector<T>
override val zero: Vector<T> by lazy { produce { field.zero } }
override fun add(a: Point<T>, b: Point<T>): Vector<T> = produce { with(field) { a[it] + b[it] } }
override fun multiply(a: Point<T>, k: Double): Vector<T> = produce { with(field) { a[it] * k } }
}
/**
* A point coupled to the linear space
*/
interface Vector<T : Any> : SpaceElement<Point<T>, VectorSpace<T>>, Point<T>, Iterable<T> {
override val size: Int get() = context.size
override operator fun plus(b: Point<T>): Vector<T> = context.add(self, b)
override operator fun minus(b: Point<T>): Vector<T> = context.add(self, context.multiply(b, -1.0))
override operator fun times(k: Number): Vector<T> = context.multiply(self, k.toDouble())
override operator fun div(k: Number): Vector<T> = context.multiply(self, 1.0 / k.toDouble())
companion object {
/**
* Create vector with custom field
*/
fun <T : Any> of(size: Int, field: Field<T>, initializer: (Int) -> T) =
ArrayVector(ArrayVectorSpace(size, field), initializer)
/**
* Create vector of [Double]
*/
fun ofReal(size: Int, initializer: (Int) -> Double) =
ArrayVector(ArrayVectorSpace(size, DoubleField, realNDFieldFactory), initializer)
fun ofReal(vararg point: Double) = point.toVector()
fun equals(v1: Vector<*>, v2: Vector<*>): Boolean {
if (v1 === v2) return true
if (v1.context != v2.context) return false
for (i in 0 until v2.size) {
if (v1[i] != v2[i]) return false
}
return true
}
}
}
class ArrayVectorSpace<T : Any>(
size: Int,
field: Field<T>,
val ndFactory: NDFieldFactory<T> = genericNDFieldFactory(field)
) : VectorSpace<T>(size, field) {
val ndField by lazy {
ndFactory(intArrayOf(size))
}
override fun produce(initializer: (Int) -> T): Vector<T> = ArrayVector(this, initializer)
}
/**
* Member of [ArrayMatrixSpace] which wraps 2-D array
*/
class ArrayMatrix<T : Any> internal constructor(override val context: ArrayMatrixSpace<T>, val element: NDElement<T>) : Matrix<T> {
constructor(context: ArrayMatrixSpace<T>, initializer: (Int, Int) -> T) : this(context, context.ndField.produce { list -> initializer(list[0], list[1]) })
override val rows: Int get() = context.rows
override val columns: Int get() = context.columns
override fun get(i: Int, j: Int): T {
return element[i, j]
}
override val self: ArrayMatrix<T> get() = this
}
class ArrayVector<T : Any> internal constructor(override val context: VectorSpace<T>, val element: NDElement<T>) : Vector<T> {
constructor(context: ArrayVectorSpace<T>, initializer: (Int) -> T) : this(context, context.ndField.produce { list -> initializer(list[0]) })
init {
if (context.size != element.shape[0]) {
error("Array dimension mismatch")
}
}
override fun get(index: Int): T {
return element[index]
}
override val self: ArrayVector<T> get() = this
override fun iterator(): Iterator<T> = (0 until size).map { element[it] }.iterator()
override fun copy(): ArrayVector<T> = ArrayVector(context, element)
override fun toString(): String = this.joinToString(prefix = "[", postfix = "]", separator = ", ") { it.toString() }
}
typealias RealVector = Vector<Double>

View File

@ -2,7 +2,7 @@ package scientifik.kmath.structures
/**
* A generic linear buffer for both primitives and objects
* A generic random access structure for both primitives and objects
*/
interface Buffer<T> : Iterable<T> {