forked from kscience/kmath
use consistent code style and simplify
This commit is contained in:
parent
622a6a7756
commit
55ce9b4754
@ -17,9 +17,8 @@ interface ExpressionContext<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
|
||||||
override fun invoke(arguments: Map<String, T>): T {
|
override fun invoke(arguments: Map<String, T>): T =
|
||||||
return arguments[name] ?: default ?: error("The parameter not found: $name")
|
arguments[name] ?: default ?: error("Parameter not found: $name")
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
internal class ConstantExpression<T>(val value: T) : Expression<T> {
|
internal class ConstantExpression<T>(val value: T) : Expression<T> {
|
||||||
|
@ -5,16 +5,16 @@ package scientifik.kmath.histogram
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
expect class LongCounter(){
|
expect class LongCounter() {
|
||||||
fun decrement()
|
fun decrement()
|
||||||
fun increment()
|
fun increment()
|
||||||
fun reset()
|
fun reset()
|
||||||
fun sum(): Long
|
fun sum(): Long
|
||||||
fun add(l:Long)
|
fun add(l: Long)
|
||||||
}
|
}
|
||||||
|
|
||||||
expect class DoubleCounter(){
|
expect class DoubleCounter() {
|
||||||
fun reset()
|
fun reset()
|
||||||
fun sum(): Double
|
fun sum(): Double
|
||||||
fun add(d: Double)
|
fun add(d: Double)
|
||||||
}
|
}
|
@ -31,7 +31,7 @@ class FastHistogram(
|
|||||||
// argument checks
|
// argument checks
|
||||||
if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.")
|
if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.")
|
||||||
if (lower.size != binNums.size) error("Dimension mismatch in bin count.")
|
if (lower.size != binNums.size) error("Dimension mismatch in bin count.")
|
||||||
if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one of axis is not strictly positive")
|
if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one axis is not strictly positive")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -41,23 +41,19 @@ class FastHistogram(
|
|||||||
/**
|
/**
|
||||||
* Get internal [NDStructure] bin index for given axis
|
* Get internal [NDStructure] bin index for given axis
|
||||||
*/
|
*/
|
||||||
private fun getIndex(axis: Int, value: Double): Int {
|
private fun getIndex(axis: Int, value: Double): Int =
|
||||||
return when {
|
when {
|
||||||
value >= upper[axis] -> binNums[axis] + 1 // overflow
|
value >= upper[axis] -> binNums[axis] + 1 // overflow
|
||||||
value < lower[axis] -> 0 // underflow
|
value < lower[axis] -> 0 // underflow
|
||||||
else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1
|
else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
private fun getIndex(point: Buffer<out Double>): IntArray = IntArray(dimension) { getIndex(it, point[it]) }
|
private fun getIndex(point: Buffer<out Double>): IntArray =
|
||||||
|
IntArray(dimension) { getIndex(it, point[it]) }
|
||||||
|
|
||||||
private fun getValue(index: IntArray): Long {
|
private fun getValue(index: IntArray): Long = values[index].sum()
|
||||||
return values[index].sum()
|
|
||||||
}
|
|
||||||
|
|
||||||
fun getValue(point: Buffer<out Double>): Long {
|
fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
|
||||||
return getValue(getIndex(point))
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun getTemplate(index: IntArray): BinTemplate<Double> {
|
private fun getTemplate(index: IntArray): BinTemplate<Double> {
|
||||||
val center = index.mapIndexed { axis, i ->
|
val center = index.mapIndexed { axis, i ->
|
||||||
@ -70,9 +66,7 @@ class FastHistogram(
|
|||||||
return BinTemplate(center, binSize)
|
return BinTemplate(center, binSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun getTemplate(point: Buffer<out Double>): BinTemplate<Double> {
|
fun getTemplate(point: Buffer<out Double>): BinTemplate<Double> = getTemplate(getIndex(point))
|
||||||
return getTemplate(getIndex(point))
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun get(point: Buffer<out Double>): PhantomBin<Double>? {
|
override fun get(point: Buffer<out Double>): PhantomBin<Double>? {
|
||||||
val index = getIndex(point)
|
val index = getIndex(point)
|
||||||
@ -85,16 +79,16 @@ class FastHistogram(
|
|||||||
values[index].increment()
|
values[index].increment()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun iterator(): Iterator<PhantomBin<Double>> = values.elements().map { (index, value) ->
|
override fun iterator(): Iterator<PhantomBin<Double>> =
|
||||||
PhantomBin(getTemplate(index), value.sum())
|
values.elements().map { (index, value) ->
|
||||||
}.iterator()
|
PhantomBin(getTemplate(index), value.sum())
|
||||||
|
}.iterator()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert this histogram into NDStructure containing bin values but not bin descriptions
|
* Convert this histogram into NDStructure containing bin values but not bin descriptions
|
||||||
*/
|
*/
|
||||||
fun asNDStructure(): NDStructure<Number> {
|
fun asNDStructure(): NDStructure<Number> =
|
||||||
return inlineNdStructure(this.values.shape) { values[it].sum() }
|
inlineNdStructure(this.values.shape) { values[it].sum() }
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a phantom lightweight immutable copy of this histogram
|
* Create a phantom lightweight immutable copy of this histogram
|
||||||
@ -115,9 +109,8 @@ class FastHistogram(
|
|||||||
*)
|
*)
|
||||||
*```
|
*```
|
||||||
*/
|
*/
|
||||||
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): FastHistogram {
|
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): FastHistogram =
|
||||||
return FastHistogram(ranges.map { it.start }.toVector(), ranges.map { it.endInclusive }.toVector())
|
FastHistogram(ranges.map { it.start }.toVector(), ranges.map { it.endInclusive }.toVector())
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Use it like
|
* Use it like
|
||||||
@ -128,13 +121,12 @@ class FastHistogram(
|
|||||||
*)
|
*)
|
||||||
*```
|
*```
|
||||||
*/
|
*/
|
||||||
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram {
|
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram =
|
||||||
return FastHistogram(
|
FastHistogram(
|
||||||
ListBuffer(ranges.map { it.first.start }),
|
ListBuffer(ranges.map { it.first.start }),
|
||||||
ListBuffer(ranges.map { it.first.endInclusive }),
|
ListBuffer(ranges.map { it.first.endInclusive }),
|
||||||
ranges.map { it.second }.toIntArray()
|
ranges.map { it.second }.toIntArray()
|
||||||
)
|
)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
@ -12,7 +12,7 @@ typealias RealPoint = Buffer<Double>
|
|||||||
* A simple geometric domain
|
* A simple geometric domain
|
||||||
* TODO move to geometry module
|
* TODO move to geometry module
|
||||||
*/
|
*/
|
||||||
interface Domain<T: Any> {
|
interface Domain<T : Any> {
|
||||||
operator fun contains(vector: Point<out T>): Boolean
|
operator fun contains(vector: Point<out T>): Boolean
|
||||||
val dimension: Int
|
val dimension: Int
|
||||||
}
|
}
|
||||||
@ -20,7 +20,7 @@ interface Domain<T: Any> {
|
|||||||
/**
|
/**
|
||||||
* The bin in the histogram. The histogram is by definition always done in the real space
|
* The bin in the histogram. The histogram is by definition always done in the real space
|
||||||
*/
|
*/
|
||||||
interface Bin<T: Any> : Domain<T> {
|
interface Bin<T : Any> : Domain<T> {
|
||||||
/**
|
/**
|
||||||
* The value of this bin
|
* The value of this bin
|
||||||
*/
|
*/
|
||||||
@ -28,7 +28,7 @@ interface Bin<T: Any> : Domain<T> {
|
|||||||
val center: Point<T>
|
val center: Point<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> {
|
interface Histogram<T : Any, out B : Bin<T>> : Iterable<B> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Find existing bin, corresponding to given coordinates
|
* Find existing bin, corresponding to given coordinates
|
||||||
@ -42,7 +42,7 @@ interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
interface MutableHistogram<T: Any, out B : Bin<T>>: Histogram<T,B>{
|
interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Increment appropriate bin
|
* Increment appropriate bin
|
||||||
@ -50,14 +50,14 @@ interface MutableHistogram<T: Any, out B : Bin<T>>: Histogram<T,B>{
|
|||||||
fun put(point: Point<out T>, weight: Double = 1.0)
|
fun put(point: Point<out T>, weight: Double = 1.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T: Any> MutableHistogram<T,*>.put(vararg point: T) = put(ArrayBuffer(point))
|
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))
|
||||||
|
|
||||||
fun MutableHistogram<Double,*>.put(vararg point: Number) = put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray()))
|
fun MutableHistogram<Double, *>.put(vararg point: Number) = put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray()))
|
||||||
fun MutableHistogram<Double,*>.put(vararg point: Double) = put(DoubleBuffer(point))
|
fun MutableHistogram<Double, *>.put(vararg point: Double) = put(DoubleBuffer(point))
|
||||||
|
|
||||||
fun <T: Any> MutableHistogram<T,*>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) }
|
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Pass a sequence builder into histogram
|
* Pass a sequence builder into histogram
|
||||||
*/
|
*/
|
||||||
fun <T: Any> MutableHistogram<T, *>.fill(buider: suspend SequenceScope<Point<T>>.() -> Unit) = fill(sequence(buider).asIterable())
|
fun <T : Any> MutableHistogram<T, *>.fill(buider: suspend SequenceScope<Point<T>>.() -> Unit) = fill(sequence(buider).asIterable())
|
@ -8,8 +8,8 @@ import scientifik.kmath.structures.asSequence
|
|||||||
data class BinTemplate<T : Comparable<T>>(val center: Vector<T, *>, val sizes: Point<T>) {
|
data class BinTemplate<T : Comparable<T>>(val center: Vector<T, *>, val sizes: Point<T>) {
|
||||||
fun contains(vector: Point<out T>): Boolean {
|
fun contains(vector: Point<out T>): Boolean {
|
||||||
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
|
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
|
||||||
val upper = center.context.run { center + sizes / 2.0}
|
val upper = center.context.run { center + sizes / 2.0 }
|
||||||
val lower = center.context.run {center - sizes / 2.0}
|
val lower = center.context.run { center - sizes / 2.0 }
|
||||||
return vector.asSequence().mapIndexed { i, value ->
|
return vector.asSequence().mapIndexed { i, value ->
|
||||||
value in lower[i]..upper[i]
|
value in lower[i]..upper[i]
|
||||||
}.all { it }
|
}.all { it }
|
||||||
@ -51,9 +51,8 @@ class PhantomHistogram<T : Comparable<T>>(
|
|||||||
override val dimension: Int
|
override val dimension: Int
|
||||||
get() = data.dimension
|
get() = data.dimension
|
||||||
|
|
||||||
override fun iterator(): Iterator<PhantomBin<T>> {
|
override fun iterator(): Iterator<PhantomBin<T>> =
|
||||||
return bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator()
|
bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator()
|
||||||
}
|
|
||||||
|
|
||||||
override fun get(point: Point<out T>): PhantomBin<T>? {
|
override fun get(point: Point<out T>): PhantomBin<T>? {
|
||||||
val template = bins.keys.find { it.contains(point) }
|
val template = bins.keys.find { it.contains(point) }
|
||||||
|
@ -86,7 +86,7 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* In-place transformation for [MutableNDArray], using given transformation for each element
|
* In-place transformation for [MutableNDStructure], using given transformation for each element
|
||||||
*/
|
*/
|
||||||
operator fun <T> MutableNDStructure<T>.set(i: Int, j: Int, value: T) {
|
operator fun <T> MutableNDStructure<T>.set(i: Int, j: Int, value: T) {
|
||||||
this[intArrayOf(i, j)] = value
|
this[intArrayOf(i, j)] = value
|
||||||
@ -174,9 +174,7 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
|
|||||||
* @return the pivot permutation vector
|
* @return the pivot permutation vector
|
||||||
* @see .getP
|
* @see .getP
|
||||||
*/
|
*/
|
||||||
fun getPivot(): IntArray {
|
fun getPivot(): IntArray = pivot.copyOf()
|
||||||
return pivot.copyOf()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,8 +28,8 @@ fun List<Double>.toVector() = Vector.real(this.size) { this[it] }
|
|||||||
/**
|
/**
|
||||||
* Convert matrix to vector if it is possible
|
* Convert matrix to vector if it is possible
|
||||||
*/
|
*/
|
||||||
fun <T : Any, F : Ring<T>> Matrix<T, F>.toVector(): Vector<T, F> {
|
fun <T : Any, F : Ring<T>> Matrix<T, F>.toVector(): Vector<T, F> =
|
||||||
return if (this.numCols == 1) {
|
if (this.numCols == 1) {
|
||||||
// if (this is ArrayMatrix) {
|
// if (this is ArrayMatrix) {
|
||||||
// //Reuse existing underlying array
|
// //Reuse existing underlying array
|
||||||
// ArrayVector(ArrayVectorSpace(rows, context.field, context.ndFactory), array)
|
// ArrayVector(ArrayVectorSpace(rows, context.field, context.ndFactory), array)
|
||||||
@ -37,9 +37,8 @@ fun <T : Any, F : Ring<T>> Matrix<T, F>.toVector(): Vector<T, F> {
|
|||||||
// //Generic vector
|
// //Generic vector
|
||||||
// vector(rows, context.field) { get(it, 0) }
|
// vector(rows, context.field) { get(it, 0) }
|
||||||
// }
|
// }
|
||||||
Vector.generic(numRows, context.ring) { get(it, 0) }
|
Vector.generic(numRows, context.ring) { get(it, 0) }
|
||||||
} else error("Can't convert matrix with more than one column to vector")
|
} else error("Can't convert matrix with more than one column to vector")
|
||||||
}
|
|
||||||
|
|
||||||
fun <T : Any, R : Ring<T>> Vector<T, R>.toMatrix(): Matrix<T, R> {
|
fun <T : Any, R : Ring<T>> Vector<T, R>.toMatrix(): Matrix<T, R> {
|
||||||
// val context = StructureMatrixContext(size, 1, context.space)
|
// val context = StructureMatrixContext(size, 1, context.space)
|
||||||
@ -56,9 +55,8 @@ fun <T : Any, R : Ring<T>> Vector<T, R>.toMatrix(): Matrix<T, R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
object VectorL2Norm : Norm<Vector<out Number, *>, Double> {
|
object VectorL2Norm : Norm<Vector<out Number, *>, Double> {
|
||||||
override fun norm(arg: Vector<out Number, *>): Double {
|
override fun norm(arg: Vector<out Number, *>): Double =
|
||||||
return kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() })
|
kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() })
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
typealias RealVector = Vector<Double, DoubleField>
|
typealias RealVector = Vector<Double, DoubleField>
|
||||||
|
@ -33,9 +33,11 @@ interface MatrixSpace<T : Any, R : Ring<T>> : Space<Matrix<T, R>> {
|
|||||||
|
|
||||||
val one get() = produce { i, j -> if (i == j) ring.one else ring.zero }
|
val one get() = produce { i, j -> if (i == j) ring.one else ring.zero }
|
||||||
|
|
||||||
override fun add(a: Matrix<T, R>, b: Matrix<T, R>): Matrix<T, R> = produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } }
|
override fun add(a: Matrix<T, R>, b: Matrix<T, R>): Matrix<T, R> =
|
||||||
|
produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } }
|
||||||
|
|
||||||
override fun multiply(a: Matrix<T, R>, k: Double): Matrix<T, R> = produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } }
|
override fun multiply(a: Matrix<T, R>, k: Double): Matrix<T, R> =
|
||||||
|
produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } }
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
/**
|
/**
|
||||||
@ -120,21 +122,24 @@ data class StructureMatrixSpace<T : Any, R : Ring<T>>(
|
|||||||
|
|
||||||
private val strides = DefaultStrides(shape)
|
private val strides = DefaultStrides(shape)
|
||||||
|
|
||||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T, R> {
|
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T, R> =
|
||||||
return if (rows == rowNum && columns == colNum) {
|
if (rows == rowNum && columns == colNum) {
|
||||||
val structure = NdStructure(strides, bufferFactory) { initializer(it[0], it[1]) }
|
val structure = NdStructure(strides, bufferFactory) { initializer(it[0], it[1]) }
|
||||||
StructureMatrix(this, structure)
|
StructureMatrix(this, structure)
|
||||||
} else {
|
} else {
|
||||||
val context = StructureMatrixSpace(rows, columns, ring, bufferFactory)
|
val context = StructureMatrixSpace(rows, columns, ring, bufferFactory)
|
||||||
val structure = NdStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) }
|
val structure = NdStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) }
|
||||||
StructureMatrix(context, structure)
|
StructureMatrix(context, structure)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
|
||||||
}
|
}
|
||||||
|
|
||||||
data class StructureMatrix<T : Any, R : Ring<T>>(override val context: StructureMatrixSpace<T, R>, val structure: NDStructure<T>) : Matrix<T, R> {
|
data class StructureMatrix<T : Any, R : Ring<T>>(
|
||||||
|
override val context: StructureMatrixSpace<T, R>,
|
||||||
|
val structure: NDStructure<T>
|
||||||
|
) : Matrix<T, R> {
|
||||||
|
|
||||||
init {
|
init {
|
||||||
if (structure.shape.size != 2 || structure.shape[0] != context.rowNum || structure.shape[1] != context.colNum) {
|
if (structure.shape.size != 2 || structure.shape[0] != context.rowNum || structure.shape[1] != context.colNum) {
|
||||||
error("Dimension mismatch for structure, (${context.rowNum}, ${context.colNum}) expected, but ${structure.shape} found")
|
error("Dimension mismatch for structure, (${context.rowNum}, ${context.colNum}) expected, but ${structure.shape} found")
|
||||||
|
@ -33,9 +33,8 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
|||||||
/**
|
/**
|
||||||
* Non-boxing double vector space
|
* Non-boxing double vector space
|
||||||
*/
|
*/
|
||||||
fun real(size: Int): BufferVectorSpace<Double, DoubleField> {
|
fun real(size: Int): BufferVectorSpace<Double, DoubleField> =
|
||||||
return realSpaceCache.getOrPut(size) { BufferVectorSpace(size, DoubleField, DoubleBufferFactory) }
|
realSpaceCache.getOrPut(size) { BufferVectorSpace(size, DoubleField, DoubleBufferFactory) }
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A structured vector space with custom buffer
|
* A structured vector space with custom buffer
|
||||||
@ -69,16 +68,12 @@ interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, VectorSpace<T,
|
|||||||
VectorSpace.buffered(size, field).produce(initializer)
|
VectorSpace.buffered(size, field).produce(initializer)
|
||||||
|
|
||||||
fun real(size: Int, initializer: (Int) -> Double) = VectorSpace.real(size).produce(initializer)
|
fun real(size: Int, initializer: (Int) -> Double) = VectorSpace.real(size).produce(initializer)
|
||||||
fun ofReal(vararg elements: Double) = VectorSpace.real(elements.size).produce{elements[it]}
|
fun ofReal(vararg elements: Double) = VectorSpace.real(elements.size).produce { elements[it] }
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
data class BufferVectorSpace<T : Any, S : Space<T>>(
|
data class BufferVectorSpace<T : Any, S : Space<T>>(override val size: Int, override val space: S, val bufferFactory: BufferFactory<T>) : VectorSpace<T, S> {
|
||||||
override val size: Int,
|
|
||||||
override val space: S,
|
|
||||||
val bufferFactory: BufferFactory<T>
|
|
||||||
) : VectorSpace<T, S> {
|
|
||||||
override fun produce(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, bufferFactory(size, initializer))
|
override fun produce(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, bufferFactory(size, initializer))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,9 +86,7 @@ data class BufferVector<T : Any, S : Space<T>>(override val context: VectorSpace
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun get(index: Int): T {
|
override fun get(index: Int): T = buffer[index]
|
||||||
return buffer[index]
|
|
||||||
}
|
|
||||||
|
|
||||||
override val self: BufferVector<T, S> get() = this
|
override val self: BufferVector<T, S> get() = this
|
||||||
|
|
||||||
|
@ -32,28 +32,28 @@ fun <T, R> List<T>.cumulative(initial: R, operation: (T, R) -> R): List<R> = thi
|
|||||||
//Cumulative sum
|
//Cumulative sum
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfDouble")
|
@JvmName("cumulativeSumOfDouble")
|
||||||
fun Iterable<Double>.cumulativeSum() = this.cumulative(0.0){ element, sum -> sum + element}
|
fun Iterable<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfInt")
|
@JvmName("cumulativeSumOfInt")
|
||||||
fun Iterable<Int>.cumulativeSum() = this.cumulative(0){ element, sum -> sum + element}
|
fun Iterable<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfLong")
|
@JvmName("cumulativeSumOfLong")
|
||||||
fun Iterable<Long>.cumulativeSum() = this.cumulative(0L){ element, sum -> sum + element}
|
fun Iterable<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfDouble")
|
@JvmName("cumulativeSumOfDouble")
|
||||||
fun Sequence<Double>.cumulativeSum() = this.cumulative(0.0){ element, sum -> sum + element}
|
fun Sequence<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfInt")
|
@JvmName("cumulativeSumOfInt")
|
||||||
fun Sequence<Int>.cumulativeSum() = this.cumulative(0){ element, sum -> sum + element}
|
fun Sequence<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfLong")
|
@JvmName("cumulativeSumOfLong")
|
||||||
fun Sequence<Long>.cumulativeSum() = this.cumulative(0L){ element, sum -> sum + element}
|
fun Sequence<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfDouble")
|
@JvmName("cumulativeSumOfDouble")
|
||||||
fun List<Double>.cumulativeSum() = this.cumulative(0.0){ element, sum -> sum + element}
|
fun List<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfInt")
|
@JvmName("cumulativeSumOfInt")
|
||||||
fun List<Int>.cumulativeSum() = this.cumulative(0){ element, sum -> sum + element}
|
fun List<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
|
||||||
|
|
||||||
@JvmName("cumulativeSumOfLong")
|
@JvmName("cumulativeSumOfLong")
|
||||||
fun List<Long>.cumulativeSum() = this.cumulative(0L){ element, sum -> sum + element}
|
fun List<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
|
@ -8,30 +8,29 @@ package scientifik.kmath.misc
|
|||||||
*
|
*
|
||||||
* If step is negative, the same goes from upper boundary downwards
|
* If step is negative, the same goes from upper boundary downwards
|
||||||
*/
|
*/
|
||||||
fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double> {
|
fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double> =
|
||||||
return when {
|
when {
|
||||||
step == 0.0 -> error("Zero step in double progression")
|
step == 0.0 -> error("Zero step in double progression")
|
||||||
step > 0 -> sequence {
|
step > 0 -> sequence {
|
||||||
var current = start
|
var current = start
|
||||||
while (current <= endInclusive) {
|
while (current <= endInclusive) {
|
||||||
yield(current)
|
yield(current)
|
||||||
current += step
|
current += step
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else -> sequence {
|
||||||
|
var current = endInclusive
|
||||||
|
while (current >= start) {
|
||||||
|
yield(current)
|
||||||
|
current += step
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else -> sequence {
|
|
||||||
var current = endInclusive
|
|
||||||
while (current >= start) {
|
|
||||||
yield(current)
|
|
||||||
current += step
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints]
|
* Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints]
|
||||||
*/
|
*/
|
||||||
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
|
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
|
||||||
if (numPoints < 2) error("Can't generic grid with less than two points")
|
if (numPoints < 2) error("Can't create generic grid with less than two points")
|
||||||
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
|
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
|
||||||
}
|
}
|
@ -53,6 +53,7 @@ interface Space<T> {
|
|||||||
|
|
||||||
//TODO move to external extensions when they are available
|
//TODO move to external extensions when they are available
|
||||||
fun Iterable<T>.sum(): T = fold(zero) { left, right -> left + right }
|
fun Iterable<T>.sum(): T = fold(zero) { left, right -> left + right }
|
||||||
|
|
||||||
fun Sequence<T>.sum(): T = fold(zero) { left, right -> left + right }
|
fun Sequence<T>.sum(): T = fold(zero) { left, right -> left + right }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,11 +79,11 @@ object DoubleField : ExtendedField<Double>, Norm<Double, Double> {
|
|||||||
/**
|
/**
|
||||||
* A field for double without boxing. Does not produce appropriate field element
|
* A field for double without boxing. Does not produce appropriate field element
|
||||||
*/
|
*/
|
||||||
object IntField : Field<Int>{
|
object IntField : Field<Int> {
|
||||||
override val zero: Int = 0
|
override val zero: Int = 0
|
||||||
override fun add(a: Int, b: Int): Int = a + b
|
override fun add(a: Int, b: Int): Int = a + b
|
||||||
override fun multiply(a: Int, b: Int): Int = a * b
|
override fun multiply(a: Int, b: Int): Int = a * b
|
||||||
override fun multiply(a: Int, k: Double): Int = (k*a).toInt()
|
override fun multiply(a: Int, k: Double): Int = (k * a).toInt()
|
||||||
override val one: Int = 1
|
override val one: Int = 1
|
||||||
override fun divide(a: Int, b: Int): Int = a / b
|
override fun divide(a: Int, b: Int): Int = a / b
|
||||||
}
|
}
|
@ -2,12 +2,15 @@ package scientifik.kmath.structures
|
|||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
|
|
||||||
open class BufferNDField<T, F : Field<T>>(final override val shape: IntArray, final override val field: F, val bufferFactory: BufferFactory<T>) : NDField<T, F> {
|
open class BufferNDField<T, F : Field<T>>(
|
||||||
|
final override val shape: IntArray,
|
||||||
|
final override val field: F,
|
||||||
|
val bufferFactory: BufferFactory<T>
|
||||||
|
) : NDField<T, F> {
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
|
|
||||||
override fun produce(initializer: F.(IntArray) -> T): BufferNDElement<T, F> {
|
override fun produce(initializer: F.(IntArray) -> T): BufferNDElement<T, F> =
|
||||||
return BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) })
|
BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) })
|
||||||
}
|
|
||||||
|
|
||||||
open fun produceBuffered(initializer: F.(Int) -> T) =
|
open fun produceBuffered(initializer: F.(Int) -> T) =
|
||||||
BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(offset) })
|
BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(offset) })
|
||||||
@ -31,7 +34,10 @@ open class BufferNDField<T, F : Field<T>>(final override val shape: IntArray, fi
|
|||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
class BufferNDElement<T, F : Field<T>>(override val context: BufferNDField<T, F>, val buffer: Buffer<T>) : NDElement<T, F> {
|
class BufferNDElement<T, F : Field<T>>(
|
||||||
|
override val context: BufferNDField<T, F>,
|
||||||
|
val buffer: Buffer<T>
|
||||||
|
) : NDElement<T, F> {
|
||||||
|
|
||||||
override val self: NDStructure<T> get() = this
|
override val self: NDStructure<T> get() = this
|
||||||
override val shape: IntArray get() = context.shape
|
override val shape: IntArray get() = context.shape
|
||||||
@ -60,7 +66,7 @@ operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.plus(arg: T) =
|
|||||||
/**
|
/**
|
||||||
* Subtraction operation between [BufferNDElement] and single element
|
* Subtraction operation between [BufferNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) =
|
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) =
|
||||||
context.produceBuffered { i -> buffer[i] - arg }
|
context.produceBuffered { i -> buffer[i] - arg }
|
||||||
|
|
||||||
/* prod and div */
|
/* prod and div */
|
||||||
@ -68,11 +74,11 @@ operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) =
|
|||||||
/**
|
/**
|
||||||
* Product operation for [BufferNDElement] and single element
|
* Product operation for [BufferNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.times(arg: T) =
|
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.times(arg: T) =
|
||||||
context.produceBuffered { i -> buffer[i] * arg }
|
context.produceBuffered { i -> buffer[i] * arg }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Division operation between [BufferNDElement] and single element
|
* Division operation between [BufferNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.div(arg: T) =
|
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.div(arg: T) =
|
||||||
context.produceBuffered { i -> buffer[i] / arg }
|
context.produceBuffered { i -> buffer[i] / arg }
|
@ -139,14 +139,13 @@ inline fun <T> boxingBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = Lis
|
|||||||
* Create most appropriate immutable buffer for given type avoiding boxing wherever possible
|
* Create most appropriate immutable buffer for given type avoiding boxing wherever possible
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <reified T : Any> inlineBuffer(size: Int, initializer: (Int) -> T): Buffer<T> {
|
inline fun <reified T : Any> inlineBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||||
return when (T::class) {
|
when (T::class) {
|
||||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
|
||||||
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
|
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
|
||||||
else -> boxingBuffer(size, initializer)
|
else -> boxingBuffer(size, initializer)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a boxing mutable buffer of given type
|
* Create a boxing mutable buffer of given type
|
||||||
@ -157,14 +156,13 @@ inline fun <T : Any> boxingMutableBuffer(size: Int, initializer: (Int) -> T): Mu
|
|||||||
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
inline fun <reified T : Any> inlineMutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
inline fun <reified T : Any> inlineMutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
|
||||||
return when (T::class) {
|
when (T::class) {
|
||||||
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
|
||||||
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
|
||||||
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
|
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
|
||||||
else -> boxingMutableBuffer(size, initializer)
|
else -> boxingMutableBuffer(size, initializer)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
|
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
|
||||||
typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>
|
typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>
|
||||||
|
@ -23,25 +23,16 @@ inline class ExtendedNDFieldWrapper<T : Any, F : ExtendedField<T>>(private val n
|
|||||||
|
|
||||||
override fun produce(initializer: F.(IntArray) -> T): NDElement<T, F> = ndField.produce(initializer)
|
override fun produce(initializer: F.(IntArray) -> T): NDElement<T, F> = ndField.produce(initializer)
|
||||||
|
|
||||||
override fun power(arg: NDStructure<T>, pow: Double): NDElement<T, F> {
|
override fun power(arg: NDStructure<T>, pow: Double): NDElement<T, F> =
|
||||||
return produce { with(field) { power(arg[it], pow) } }
|
produce { with(field) { power(arg[it], pow) } }
|
||||||
}
|
|
||||||
|
|
||||||
override fun exp(arg: NDStructure<T>): NDElement<T, F> {
|
override fun exp(arg: NDStructure<T>): NDElement<T, F> = produce { with(field) { exp(arg[it]) } }
|
||||||
return produce { with(field) { exp(arg[it]) } }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun ln(arg: NDStructure<T>): NDElement<T, F> {
|
override fun ln(arg: NDStructure<T>): NDElement<T, F> = produce { with(field) { ln(arg[it]) } }
|
||||||
return produce { with(field) { ln(arg[it]) } }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun sin(arg: NDStructure<T>): NDElement<T, F> {
|
override fun sin(arg: NDStructure<T>): NDElement<T, F> = produce { with(field) { sin(arg[it]) } }
|
||||||
return produce { with(field) { sin(arg[it]) } }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun cos(arg: NDStructure<T>): NDElement<T, F> {
|
override fun cos(arg: NDStructure<T>): NDElement<T, F> = produce { with(field) { cos(arg[it]) } }
|
||||||
return produce { with(field) { cos(arg[it]) } }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,8 +11,8 @@ class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : Run
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Field for n-dimensional arrays.
|
* Field for n-dimensional arrays.
|
||||||
* @param shape - the list of dimensions of the array
|
* @property shape - the list of dimensions of the array
|
||||||
* @param field - operations field defined on individual array element
|
* @property field - operations field defined on individual array element
|
||||||
* @param T the type of the element contained in NDArray
|
* @param T the type of the element contained in NDArray
|
||||||
*/
|
*/
|
||||||
interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
|
interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
|
||||||
@ -33,13 +33,8 @@ interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
|
|||||||
/**
|
/**
|
||||||
* Check the shape of given NDArray and throw exception if it does not coincide with shape of the field
|
* Check the shape of given NDArray and throw exception if it does not coincide with shape of the field
|
||||||
*/
|
*/
|
||||||
fun checkShape(vararg elements: NDStructure<T>) {
|
fun checkShape(vararg elements: NDStructure<T>) =
|
||||||
elements.forEach {
|
elements.forEach { if (!shape.contentEquals(it.shape)) throw ShapeMismatchException(shape, it.shape) }
|
||||||
if (!shape.contentEquals(it.shape)) {
|
|
||||||
throw ShapeMismatchException(shape, it.shape)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Element-by-element addition
|
* Element-by-element addition
|
||||||
@ -97,21 +92,17 @@ interface NDElement<T, F : Field<T>> : FieldElement<NDStructure<T>, NDField<T, F
|
|||||||
/**
|
/**
|
||||||
* Create a platform-optimized NDArray of doubles
|
* Create a platform-optimized NDArray of doubles
|
||||||
*/
|
*/
|
||||||
fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }): NDElement<Double, DoubleField> {
|
fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }): NDElement<Double, DoubleField> =
|
||||||
return NDField.real(shape).produce(initializer)
|
NDField.real(shape).produce(initializer)
|
||||||
}
|
|
||||||
|
|
||||||
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement<Double, DoubleField> {
|
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement<Double, DoubleField> =
|
||||||
return real(intArrayOf(dim)) { initializer(it[0]) }
|
real(intArrayOf(dim)) { initializer(it[0]) }
|
||||||
}
|
|
||||||
|
|
||||||
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement<Double, DoubleField> {
|
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement<Double, DoubleField> =
|
||||||
return real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||||
}
|
|
||||||
|
|
||||||
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement<Double, DoubleField> {
|
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement<Double, DoubleField> =
|
||||||
return real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||||
}
|
|
||||||
|
|
||||||
// inline fun real(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDStructure<Double>): NDElement<Double, DoubleField> {
|
// inline fun real(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDStructure<Double>): NDElement<Double, DoubleField> {
|
||||||
// val field = NDField.real(shape)
|
// val field = NDField.real(shape)
|
||||||
@ -121,13 +112,11 @@ interface NDElement<T, F : Field<T>> : FieldElement<NDStructure<T>, NDField<T, F
|
|||||||
/**
|
/**
|
||||||
* Simple boxing NDArray
|
* Simple boxing NDArray
|
||||||
*/
|
*/
|
||||||
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): NDElement<T, F> {
|
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): NDElement<T, F> =
|
||||||
return NDField.generic(shape, field).produce(initializer)
|
NDField.generic(shape, field).produce(initializer)
|
||||||
}
|
|
||||||
|
|
||||||
inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): NDElement<T, F> {
|
inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): NDElement<T, F> =
|
||||||
return NDField.inline(shape, field).produce(initializer)
|
NDField.inline(shape, field).produce(initializer)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,11 +19,8 @@ interface MutableNDStructure<T> : NDStructure<T> {
|
|||||||
operator fun set(index: IntArray, value: T)
|
operator fun set(index: IntArray, value: T)
|
||||||
}
|
}
|
||||||
|
|
||||||
fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
|
fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) =
|
||||||
elements().forEach { (index, oldValue) ->
|
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
|
||||||
this[index] = action(index, oldValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A way to convert ND index to linear one and back
|
* A way to convert ND index to linear one and back
|
||||||
@ -56,11 +53,11 @@ interface Strides {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Iterate over ND indices in a natural order
|
* Iterate over ND indices in a natural order
|
||||||
|
*
|
||||||
|
* TODO: introduce a fast way to calculate index of the next element?
|
||||||
*/
|
*/
|
||||||
fun indices(): Sequence<IntArray> {
|
fun indices(): Sequence<IntArray> =
|
||||||
//TODO introduce a fast way to calculate index of the next element?
|
(0 until linearSize).asSequence().map { index(it) }
|
||||||
return (0 until linearSize).asSequence().map { index(it) }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class DefaultStrides private constructor(override val shape: IntArray) : Strides {
|
class DefaultStrides private constructor(override val shape: IntArray) : Strides {
|
||||||
@ -128,25 +125,21 @@ abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> {
|
|||||||
/**
|
/**
|
||||||
* Boxing generic [NDStructure]
|
* Boxing generic [NDStructure]
|
||||||
*/
|
*/
|
||||||
class BufferNDStructure<T>(
|
class BufferNDStructure<T>(override val strides: Strides,
|
||||||
override val strides: Strides,
|
override val buffer: Buffer<T>) : GenericNDStructure<T, Buffer<T>>() {
|
||||||
override val buffer: Buffer<T>
|
|
||||||
) : GenericNDStructure<T, Buffer<T>>() {
|
|
||||||
|
|
||||||
init {
|
init {
|
||||||
if (strides.linearSize != buffer.size) {
|
if (strides.linearSize != buffer.size) {
|
||||||
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
|
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean =
|
||||||
return when {
|
when {
|
||||||
this === other -> true
|
this === other -> true
|
||||||
other is BufferNDStructure<*> -> this.strides == other.strides && this.buffer.contentEquals(other.buffer)
|
other is BufferNDStructure<*> -> this.strides == other.strides && this.buffer.contentEquals(other.buffer)
|
||||||
other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] }
|
other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] }
|
||||||
else -> false
|
else -> false
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
override fun hashCode(): Int {
|
||||||
var result = strides.hashCode()
|
var result = strides.hashCode()
|
||||||
@ -158,14 +151,13 @@ class BufferNDStructure<T>(
|
|||||||
/**
|
/**
|
||||||
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure]
|
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure]
|
||||||
*/
|
*/
|
||||||
inline fun <T, reified R : Any> NDStructure<T>.map(factory: BufferFactory<R> = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure<R> {
|
inline fun <T, reified R : Any> NDStructure<T>.map(factory: BufferFactory<R> = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure<R> =
|
||||||
return if (this is BufferNDStructure<T>) {
|
if (this is BufferNDStructure<T>) {
|
||||||
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
|
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
|
||||||
} else {
|
} else {
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a NDStructure with explicit buffer factory
|
* Create a NDStructure with explicit buffer factory
|
||||||
|
@ -9,14 +9,13 @@ class RealNDField(shape: IntArray) : BufferNDField<Double, DoubleField>(shape, D
|
|||||||
/**
|
/**
|
||||||
* Inline map an NDStructure to
|
* Inline map an NDStructure to
|
||||||
*/
|
*/
|
||||||
private inline fun NDStructure<Double>.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement {
|
private inline fun NDStructure<Double>.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement =
|
||||||
return if (this is BufferNDElement<Double, *>) {
|
if (this is BufferNDElement<Double, *>) {
|
||||||
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) }
|
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) }
|
||||||
BufferNDElement(this@RealNDField, DoubleBuffer(array))
|
BufferNDElement(this@RealNDField, DoubleBuffer(array))
|
||||||
} else {
|
} else {
|
||||||
produce { index -> DoubleField.operation(get(index)) }
|
produce { index -> DoubleField.operation(get(index)) }
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
@ -58,16 +57,12 @@ inline fun BufferNDField<Double, DoubleField>.produceInline(crossinline initiali
|
|||||||
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
|
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
|
||||||
ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) }
|
ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) }
|
||||||
|
|
||||||
/* plus and minus */
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Summation operation for [BufferNDElement] and single element
|
* Summation operation for [BufferNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun RealNDElement.plus(arg: Double) =
|
operator fun RealNDElement.plus(arg: Double) = context.produceInline { i -> buffer[i] + arg }
|
||||||
context.produceInline { i -> buffer[i] + arg }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subtraction operation between [BufferNDElement] and single element
|
* Subtraction operation between [BufferNDElement] and single element
|
||||||
*/
|
*/
|
||||||
operator fun RealNDElement.minus(arg: Double) =
|
operator fun RealNDElement.minus(arg: Double) = context.produceInline { i -> buffer[i] - arg }
|
||||||
context.produceInline { i -> buffer[i] - arg }
|
|
Loading…
Reference in New Issue
Block a user