use consistent code style and simplify

This commit is contained in:
breandan 2019-01-03 10:43:12 -05:00
parent 622a6a7756
commit 55ce9b4754
19 changed files with 191 additions and 236 deletions

View File

@ -17,9 +17,8 @@ interface ExpressionContext<T> {
}
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T {
return arguments[name] ?: default ?: error("The parameter not found: $name")
}
override fun invoke(arguments: Map<String, T>): T =
arguments[name] ?: default ?: error("Parameter not found: $name")
}
internal class ConstantExpression<T>(val value: T) : Expression<T> {

View File

@ -5,16 +5,16 @@ package scientifik.kmath.histogram
*/
expect class LongCounter(){
fun decrement()
fun increment()
fun reset()
fun sum(): Long
fun add(l:Long)
expect class LongCounter() {
fun decrement()
fun increment()
fun reset()
fun sum(): Long
fun add(l: Long)
}
expect class DoubleCounter(){
fun reset()
fun sum(): Double
expect class DoubleCounter() {
fun reset()
fun sum(): Double
fun add(d: Double)
}

View File

@ -31,7 +31,7 @@ class FastHistogram(
// argument checks
if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.")
if (lower.size != binNums.size) error("Dimension mismatch in bin count.")
if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one of axis is not strictly positive")
if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one axis is not strictly positive")
}
@ -41,23 +41,19 @@ class FastHistogram(
/**
* Get internal [NDStructure] bin index for given axis
*/
private fun getIndex(axis: Int, value: Double): Int {
return when {
value >= upper[axis] -> binNums[axis] + 1 // overflow
value < lower[axis] -> 0 // underflow
else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1
}
}
private fun getIndex(axis: Int, value: Double): Int =
when {
value >= upper[axis] -> binNums[axis] + 1 // overflow
value < lower[axis] -> 0 // underflow
else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1
}
private fun getIndex(point: Buffer<out Double>): IntArray = IntArray(dimension) { getIndex(it, point[it]) }
private fun getIndex(point: Buffer<out Double>): IntArray =
IntArray(dimension) { getIndex(it, point[it]) }
private fun getValue(index: IntArray): Long {
return values[index].sum()
}
private fun getValue(index: IntArray): Long = values[index].sum()
fun getValue(point: Buffer<out Double>): Long {
return getValue(getIndex(point))
}
fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
private fun getTemplate(index: IntArray): BinTemplate<Double> {
val center = index.mapIndexed { axis, i ->
@ -70,9 +66,7 @@ class FastHistogram(
return BinTemplate(center, binSize)
}
fun getTemplate(point: Buffer<out Double>): BinTemplate<Double> {
return getTemplate(getIndex(point))
}
fun getTemplate(point: Buffer<out Double>): BinTemplate<Double> = getTemplate(getIndex(point))
override fun get(point: Buffer<out Double>): PhantomBin<Double>? {
val index = getIndex(point)
@ -85,16 +79,16 @@ class FastHistogram(
values[index].increment()
}
override fun iterator(): Iterator<PhantomBin<Double>> = values.elements().map { (index, value) ->
PhantomBin(getTemplate(index), value.sum())
}.iterator()
override fun iterator(): Iterator<PhantomBin<Double>> =
values.elements().map { (index, value) ->
PhantomBin(getTemplate(index), value.sum())
}.iterator()
/**
* Convert this histogram into NDStructure containing bin values but not bin descriptions
*/
fun asNDStructure(): NDStructure<Number> {
return inlineNdStructure(this.values.shape) { values[it].sum() }
}
fun asNDStructure(): NDStructure<Number> =
inlineNdStructure(this.values.shape) { values[it].sum() }
/**
* Create a phantom lightweight immutable copy of this histogram
@ -115,9 +109,8 @@ class FastHistogram(
*)
*```
*/
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): FastHistogram {
return FastHistogram(ranges.map { it.start }.toVector(), ranges.map { it.endInclusive }.toVector())
}
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): FastHistogram =
FastHistogram(ranges.map { it.start }.toVector(), ranges.map { it.endInclusive }.toVector())
/**
* Use it like
@ -128,13 +121,12 @@ class FastHistogram(
*)
*```
*/
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram {
return FastHistogram(
ListBuffer(ranges.map { it.first.start }),
ListBuffer(ranges.map { it.first.endInclusive }),
ranges.map { it.second }.toIntArray()
)
}
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram =
FastHistogram(
ListBuffer(ranges.map { it.first.start }),
ListBuffer(ranges.map { it.first.endInclusive }),
ranges.map { it.second }.toIntArray()
)
}
}

View File

@ -12,7 +12,7 @@ typealias RealPoint = Buffer<Double>
* A simple geometric domain
* TODO move to geometry module
*/
interface Domain<T: Any> {
interface Domain<T : Any> {
operator fun contains(vector: Point<out T>): Boolean
val dimension: Int
}
@ -20,7 +20,7 @@ interface Domain<T: Any> {
/**
* The bin in the histogram. The histogram is by definition always done in the real space
*/
interface Bin<T: Any> : Domain<T> {
interface Bin<T : Any> : Domain<T> {
/**
* The value of this bin
*/
@ -28,7 +28,7 @@ interface Bin<T: Any> : Domain<T> {
val center: Point<T>
}
interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> {
interface Histogram<T : Any, out B : Bin<T>> : Iterable<B> {
/**
* Find existing bin, corresponding to given coordinates
@ -42,7 +42,7 @@ interface Histogram<T: Any, out B : Bin<T>> : Iterable<B> {
}
interface MutableHistogram<T: Any, out B : Bin<T>>: Histogram<T,B>{
interface MutableHistogram<T : Any, out B : Bin<T>> : Histogram<T, B> {
/**
* Increment appropriate bin
@ -50,14 +50,14 @@ interface MutableHistogram<T: Any, out B : Bin<T>>: Histogram<T,B>{
fun put(point: Point<out T>, weight: Double = 1.0)
}
fun <T: Any> MutableHistogram<T,*>.put(vararg point: T) = put(ArrayBuffer(point))
fun <T : Any> MutableHistogram<T, *>.put(vararg point: T) = put(ArrayBuffer(point))
fun MutableHistogram<Double,*>.put(vararg point: Number) = put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray()))
fun MutableHistogram<Double,*>.put(vararg point: Double) = put(DoubleBuffer(point))
fun MutableHistogram<Double, *>.put(vararg point: Number) = put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray()))
fun MutableHistogram<Double, *>.put(vararg point: Double) = put(DoubleBuffer(point))
fun <T: Any> MutableHistogram<T,*>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) }
fun <T : Any> MutableHistogram<T, *>.fill(sequence: Iterable<Point<T>>) = sequence.forEach { put(it) }
/**
* Pass a sequence builder into histogram
*/
fun <T: Any> MutableHistogram<T, *>.fill(buider: suspend SequenceScope<Point<T>>.() -> Unit) = fill(sequence(buider).asIterable())
fun <T : Any> MutableHistogram<T, *>.fill(buider: suspend SequenceScope<Point<T>>.() -> Unit) = fill(sequence(buider).asIterable())

View File

@ -8,8 +8,8 @@ import scientifik.kmath.structures.asSequence
data class BinTemplate<T : Comparable<T>>(val center: Vector<T, *>, val sizes: Point<T>) {
fun contains(vector: Point<out T>): Boolean {
if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}")
val upper = center.context.run { center + sizes / 2.0}
val lower = center.context.run {center - sizes / 2.0}
val upper = center.context.run { center + sizes / 2.0 }
val lower = center.context.run { center - sizes / 2.0 }
return vector.asSequence().mapIndexed { i, value ->
value in lower[i]..upper[i]
}.all { it }
@ -51,9 +51,8 @@ class PhantomHistogram<T : Comparable<T>>(
override val dimension: Int
get() = data.dimension
override fun iterator(): Iterator<PhantomBin<T>> {
return bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator()
}
override fun iterator(): Iterator<PhantomBin<T>> =
bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator()
override fun get(point: Point<out T>): PhantomBin<T>? {
val template = bins.keys.find { it.contains(point) }

View File

@ -86,7 +86,7 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
}
/**
* In-place transformation for [MutableNDArray], using given transformation for each element
* In-place transformation for [MutableNDStructure], using given transformation for each element
*/
operator fun <T> MutableNDStructure<T>.set(i: Int, j: Int, value: T) {
this[intArrayOf(i, j)] = value
@ -174,9 +174,7 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
* @return the pivot permutation vector
* @see .getP
*/
fun getPivot(): IntArray {
return pivot.copyOf()
}
fun getPivot(): IntArray = pivot.copyOf()
}

View File

@ -28,8 +28,8 @@ fun List<Double>.toVector() = Vector.real(this.size) { this[it] }
/**
* Convert matrix to vector if it is possible
*/
fun <T : Any, F : Ring<T>> Matrix<T, F>.toVector(): Vector<T, F> {
return if (this.numCols == 1) {
fun <T : Any, F : Ring<T>> Matrix<T, F>.toVector(): Vector<T, F> =
if (this.numCols == 1) {
// if (this is ArrayMatrix) {
// //Reuse existing underlying array
// ArrayVector(ArrayVectorSpace(rows, context.field, context.ndFactory), array)
@ -37,9 +37,8 @@ fun <T : Any, F : Ring<T>> Matrix<T, F>.toVector(): Vector<T, F> {
// //Generic vector
// vector(rows, context.field) { get(it, 0) }
// }
Vector.generic(numRows, context.ring) { get(it, 0) }
} else error("Can't convert matrix with more than one column to vector")
}
Vector.generic(numRows, context.ring) { get(it, 0) }
} else error("Can't convert matrix with more than one column to vector")
fun <T : Any, R : Ring<T>> Vector<T, R>.toMatrix(): Matrix<T, R> {
// val context = StructureMatrixContext(size, 1, context.space)
@ -56,9 +55,8 @@ fun <T : Any, R : Ring<T>> Vector<T, R>.toMatrix(): Matrix<T, R> {
}
object VectorL2Norm : Norm<Vector<out Number, *>, Double> {
override fun norm(arg: Vector<out Number, *>): Double {
return kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() })
}
override fun norm(arg: Vector<out Number, *>): Double =
kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() })
}
typealias RealVector = Vector<Double, DoubleField>

View File

@ -33,9 +33,11 @@ interface MatrixSpace<T : Any, R : Ring<T>> : Space<Matrix<T, R>> {
val one get() = produce { i, j -> if (i == j) ring.one else ring.zero }
override fun add(a: Matrix<T, R>, b: Matrix<T, R>): Matrix<T, R> = produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } }
override fun add(a: Matrix<T, R>, b: Matrix<T, R>): Matrix<T, R> =
produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } }
override fun multiply(a: Matrix<T, R>, k: Double): Matrix<T, R> = produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } }
override fun multiply(a: Matrix<T, R>, k: Double): Matrix<T, R> =
produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } }
companion object {
/**
@ -120,21 +122,24 @@ data class StructureMatrixSpace<T : Any, R : Ring<T>>(
private val strides = DefaultStrides(shape)
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T, R> {
return if (rows == rowNum && columns == colNum) {
val structure = NdStructure(strides, bufferFactory) { initializer(it[0], it[1]) }
StructureMatrix(this, structure)
} else {
val context = StructureMatrixSpace(rows, columns, ring, bufferFactory)
val structure = NdStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) }
StructureMatrix(context, structure)
}
}
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T, R> =
if (rows == rowNum && columns == colNum) {
val structure = NdStructure(strides, bufferFactory) { initializer(it[0], it[1]) }
StructureMatrix(this, structure)
} else {
val context = StructureMatrixSpace(rows, columns, ring, bufferFactory)
val structure = NdStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) }
StructureMatrix(context, structure)
}
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
}
data class StructureMatrix<T : Any, R : Ring<T>>(override val context: StructureMatrixSpace<T, R>, val structure: NDStructure<T>) : Matrix<T, R> {
data class StructureMatrix<T : Any, R : Ring<T>>(
override val context: StructureMatrixSpace<T, R>,
val structure: NDStructure<T>
) : Matrix<T, R> {
init {
if (structure.shape.size != 2 || structure.shape[0] != context.rowNum || structure.shape[1] != context.colNum) {
error("Dimension mismatch for structure, (${context.rowNum}, ${context.colNum}) expected, but ${structure.shape} found")

View File

@ -33,9 +33,8 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
/**
* Non-boxing double vector space
*/
fun real(size: Int): BufferVectorSpace<Double, DoubleField> {
return realSpaceCache.getOrPut(size) { BufferVectorSpace(size, DoubleField, DoubleBufferFactory) }
}
fun real(size: Int): BufferVectorSpace<Double, DoubleField> =
realSpaceCache.getOrPut(size) { BufferVectorSpace(size, DoubleField, DoubleBufferFactory) }
/**
* A structured vector space with custom buffer
@ -69,16 +68,12 @@ interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, VectorSpace<T,
VectorSpace.buffered(size, field).produce(initializer)
fun real(size: Int, initializer: (Int) -> Double) = VectorSpace.real(size).produce(initializer)
fun ofReal(vararg elements: Double) = VectorSpace.real(elements.size).produce{elements[it]}
fun ofReal(vararg elements: Double) = VectorSpace.real(elements.size).produce { elements[it] }
}
}
data class BufferVectorSpace<T : Any, S : Space<T>>(
override val size: Int,
override val space: S,
val bufferFactory: BufferFactory<T>
) : VectorSpace<T, S> {
data class BufferVectorSpace<T : Any, S : Space<T>>(override val size: Int, override val space: S, val bufferFactory: BufferFactory<T>) : VectorSpace<T, S> {
override fun produce(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, bufferFactory(size, initializer))
}
@ -91,9 +86,7 @@ data class BufferVector<T : Any, S : Space<T>>(override val context: VectorSpace
}
}
override fun get(index: Int): T {
return buffer[index]
}
override fun get(index: Int): T = buffer[index]
override val self: BufferVector<T, S> get() = this

View File

@ -32,28 +32,28 @@ fun <T, R> List<T>.cumulative(initial: R, operation: (T, R) -> R): List<R> = thi
//Cumulative sum
@JvmName("cumulativeSumOfDouble")
fun Iterable<Double>.cumulativeSum() = this.cumulative(0.0){ element, sum -> sum + element}
fun Iterable<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt")
fun Iterable<Int>.cumulativeSum() = this.cumulative(0){ element, sum -> sum + element}
fun Iterable<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong")
fun Iterable<Long>.cumulativeSum() = this.cumulative(0L){ element, sum -> sum + element}
fun Iterable<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
@JvmName("cumulativeSumOfDouble")
fun Sequence<Double>.cumulativeSum() = this.cumulative(0.0){ element, sum -> sum + element}
fun Sequence<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt")
fun Sequence<Int>.cumulativeSum() = this.cumulative(0){ element, sum -> sum + element}
fun Sequence<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong")
fun Sequence<Long>.cumulativeSum() = this.cumulative(0L){ element, sum -> sum + element}
fun Sequence<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }
@JvmName("cumulativeSumOfDouble")
fun List<Double>.cumulativeSum() = this.cumulative(0.0){ element, sum -> sum + element}
fun List<Double>.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfInt")
fun List<Int>.cumulativeSum() = this.cumulative(0){ element, sum -> sum + element}
fun List<Int>.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element }
@JvmName("cumulativeSumOfLong")
fun List<Long>.cumulativeSum() = this.cumulative(0L){ element, sum -> sum + element}
fun List<Long>.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element }

View File

@ -8,30 +8,29 @@ package scientifik.kmath.misc
*
* If step is negative, the same goes from upper boundary downwards
*/
fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double> {
return when {
step == 0.0 -> error("Zero step in double progression")
step > 0 -> sequence {
var current = start
while (current <= endInclusive) {
yield(current)
current += step
fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double> =
when {
step == 0.0 -> error("Zero step in double progression")
step > 0 -> sequence {
var current = start
while (current <= endInclusive) {
yield(current)
current += step
}
}
else -> sequence {
var current = endInclusive
while (current >= start) {
yield(current)
current += step
}
}
}
else -> sequence {
var current = endInclusive
while (current >= start) {
yield(current)
current += step
}
}
}
}
/**
* Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints]
*/
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
if (numPoints < 2) error("Can't generic grid with less than two points")
if (numPoints < 2) error("Can't create generic grid with less than two points")
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
}

View File

@ -53,6 +53,7 @@ interface Space<T> {
//TODO move to external extensions when they are available
fun Iterable<T>.sum(): T = fold(zero) { left, right -> left + right }
fun Sequence<T>.sum(): T = fold(zero) { left, right -> left + right }
}

View File

@ -79,11 +79,11 @@ object DoubleField : ExtendedField<Double>, Norm<Double, Double> {
/**
* A field for double without boxing. Does not produce appropriate field element
*/
object IntField : Field<Int>{
object IntField : Field<Int> {
override val zero: Int = 0
override fun add(a: Int, b: Int): Int = a + b
override fun multiply(a: Int, b: Int): Int = a * b
override fun multiply(a: Int, k: Double): Int = (k*a).toInt()
override fun multiply(a: Int, k: Double): Int = (k * a).toInt()
override val one: Int = 1
override fun divide(a: Int, b: Int): Int = a / b
}

View File

@ -2,12 +2,15 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.Field
open class BufferNDField<T, F : Field<T>>(final override val shape: IntArray, final override val field: F, val bufferFactory: BufferFactory<T>) : NDField<T, F> {
open class BufferNDField<T, F : Field<T>>(
final override val shape: IntArray,
final override val field: F,
val bufferFactory: BufferFactory<T>
) : NDField<T, F> {
val strides = DefaultStrides(shape)
override fun produce(initializer: F.(IntArray) -> T): BufferNDElement<T, F> {
return BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) })
}
override fun produce(initializer: F.(IntArray) -> T): BufferNDElement<T, F> =
BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) })
open fun produceBuffered(initializer: F.(Int) -> T) =
BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(offset) })
@ -31,7 +34,10 @@ open class BufferNDField<T, F : Field<T>>(final override val shape: IntArray, fi
// }
}
class BufferNDElement<T, F : Field<T>>(override val context: BufferNDField<T, F>, val buffer: Buffer<T>) : NDElement<T, F> {
class BufferNDElement<T, F : Field<T>>(
override val context: BufferNDField<T, F>,
val buffer: Buffer<T>
) : NDElement<T, F> {
override val self: NDStructure<T> get() = this
override val shape: IntArray get() = context.shape
@ -60,7 +66,7 @@ operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.plus(arg: T) =
/**
* Subtraction operation between [BufferNDElement] and single element
*/
operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) =
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) =
context.produceBuffered { i -> buffer[i] - arg }
/* prod and div */
@ -68,11 +74,11 @@ operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) =
/**
* Product operation for [BufferNDElement] and single element
*/
operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.times(arg: T) =
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.times(arg: T) =
context.produceBuffered { i -> buffer[i] * arg }
/**
* Division operation between [BufferNDElement] and single element
*/
operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.div(arg: T) =
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.div(arg: T) =
context.produceBuffered { i -> buffer[i] / arg }

View File

@ -139,14 +139,13 @@ inline fun <T> boxingBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = Lis
* Create most appropriate immutable buffer for given type avoiding boxing wherever possible
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T : Any> inlineBuffer(size: Int, initializer: (Int) -> T): Buffer<T> {
return when (T::class) {
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
else -> boxingBuffer(size, initializer)
}
}
inline fun <reified T : Any> inlineBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
when (T::class) {
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
else -> boxingBuffer(size, initializer)
}
/**
* Create a boxing mutable buffer of given type
@ -157,14 +156,13 @@ inline fun <T : Any> boxingMutableBuffer(size: Int, initializer: (Int) -> T): Mu
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T : Any> inlineMutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer<T> {
return when (T::class) {
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
else -> boxingMutableBuffer(size, initializer)
}
}
inline fun <reified T : Any> inlineMutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
when (T::class) {
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
else -> boxingMutableBuffer(size, initializer)
}
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>

View File

@ -23,25 +23,16 @@ inline class ExtendedNDFieldWrapper<T : Any, F : ExtendedField<T>>(private val n
override fun produce(initializer: F.(IntArray) -> T): NDElement<T, F> = ndField.produce(initializer)
override fun power(arg: NDStructure<T>, pow: Double): NDElement<T, F> {
return produce { with(field) { power(arg[it], pow) } }
}
override fun power(arg: NDStructure<T>, pow: Double): NDElement<T, F> =
produce { with(field) { power(arg[it], pow) } }
override fun exp(arg: NDStructure<T>): NDElement<T, F> {
return produce { with(field) { exp(arg[it]) } }
}
override fun exp(arg: NDStructure<T>): NDElement<T, F> = produce { with(field) { exp(arg[it]) } }
override fun ln(arg: NDStructure<T>): NDElement<T, F> {
return produce { with(field) { ln(arg[it]) } }
}
override fun ln(arg: NDStructure<T>): NDElement<T, F> = produce { with(field) { ln(arg[it]) } }
override fun sin(arg: NDStructure<T>): NDElement<T, F> {
return produce { with(field) { sin(arg[it]) } }
}
override fun sin(arg: NDStructure<T>): NDElement<T, F> = produce { with(field) { sin(arg[it]) } }
override fun cos(arg: NDStructure<T>): NDElement<T, F> {
return produce { with(field) { cos(arg[it]) } }
}
override fun cos(arg: NDStructure<T>): NDElement<T, F> = produce { with(field) { cos(arg[it]) } }
}

View File

@ -11,8 +11,8 @@ class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : Run
/**
* Field for n-dimensional arrays.
* @param shape - the list of dimensions of the array
* @param field - operations field defined on individual array element
* @property shape - the list of dimensions of the array
* @property field - operations field defined on individual array element
* @param T the type of the element contained in NDArray
*/
interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
@ -33,13 +33,8 @@ interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
/**
* Check the shape of given NDArray and throw exception if it does not coincide with shape of the field
*/
fun checkShape(vararg elements: NDStructure<T>) {
elements.forEach {
if (!shape.contentEquals(it.shape)) {
throw ShapeMismatchException(shape, it.shape)
}
}
}
fun checkShape(vararg elements: NDStructure<T>) =
elements.forEach { if (!shape.contentEquals(it.shape)) throw ShapeMismatchException(shape, it.shape) }
/**
* Element-by-element addition
@ -97,21 +92,17 @@ interface NDElement<T, F : Field<T>> : FieldElement<NDStructure<T>, NDField<T, F
/**
* Create a platform-optimized NDArray of doubles
*/
fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }): NDElement<Double, DoubleField> {
return NDField.real(shape).produce(initializer)
}
fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }): NDElement<Double, DoubleField> =
NDField.real(shape).produce(initializer)
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement<Double, DoubleField> {
return real(intArrayOf(dim)) { initializer(it[0]) }
}
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement<Double, DoubleField> =
real(intArrayOf(dim)) { initializer(it[0]) }
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement<Double, DoubleField> {
return real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
}
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement<Double, DoubleField> =
real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement<Double, DoubleField> {
return real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
}
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement<Double, DoubleField> =
real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
// inline fun real(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDStructure<Double>): NDElement<Double, DoubleField> {
// val field = NDField.real(shape)
@ -121,13 +112,11 @@ interface NDElement<T, F : Field<T>> : FieldElement<NDStructure<T>, NDField<T, F
/**
* Simple boxing NDArray
*/
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): NDElement<T, F> {
return NDField.generic(shape, field).produce(initializer)
}
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): NDElement<T, F> =
NDField.generic(shape, field).produce(initializer)
inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): NDElement<T, F> {
return NDField.inline(shape, field).produce(initializer)
}
inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): NDElement<T, F> =
NDField.inline(shape, field).produce(initializer)
}
}

View File

@ -19,11 +19,8 @@ interface MutableNDStructure<T> : NDStructure<T> {
operator fun set(index: IntArray, value: T)
}
fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
elements().forEach { (index, oldValue) ->
this[index] = action(index, oldValue)
}
}
fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) =
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
/**
* A way to convert ND index to linear one and back
@ -56,11 +53,11 @@ interface Strides {
/**
* Iterate over ND indices in a natural order
*
* TODO: introduce a fast way to calculate index of the next element?
*/
fun indices(): Sequence<IntArray> {
//TODO introduce a fast way to calculate index of the next element?
return (0 until linearSize).asSequence().map { index(it) }
}
fun indices(): Sequence<IntArray> =
(0 until linearSize).asSequence().map { index(it) }
}
class DefaultStrides private constructor(override val shape: IntArray) : Strides {
@ -128,25 +125,21 @@ abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> {
/**
* Boxing generic [NDStructure]
*/
class BufferNDStructure<T>(
override val strides: Strides,
override val buffer: Buffer<T>
) : GenericNDStructure<T, Buffer<T>>() {
class BufferNDStructure<T>(override val strides: Strides,
override val buffer: Buffer<T>) : GenericNDStructure<T, Buffer<T>>() {
init {
if (strides.linearSize != buffer.size) {
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
}
}
override fun equals(other: Any?): Boolean {
return when {
this === other -> true
other is BufferNDStructure<*> -> this.strides == other.strides && this.buffer.contentEquals(other.buffer)
other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] }
else -> false
}
}
override fun equals(other: Any?): Boolean =
when {
this === other -> true
other is BufferNDStructure<*> -> this.strides == other.strides && this.buffer.contentEquals(other.buffer)
other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] }
else -> false
}
override fun hashCode(): Int {
var result = strides.hashCode()
@ -158,14 +151,13 @@ class BufferNDStructure<T>(
/**
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure]
*/
inline fun <T, reified R : Any> NDStructure<T>.map(factory: BufferFactory<R> = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure<R> {
return if (this is BufferNDStructure<T>) {
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
} else {
val strides = DefaultStrides(shape)
BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
}
}
inline fun <T, reified R : Any> NDStructure<T>.map(factory: BufferFactory<R> = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure<R> =
if (this is BufferNDStructure<T>) {
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
} else {
val strides = DefaultStrides(shape)
BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
}
/**
* Create a NDStructure with explicit buffer factory

View File

@ -9,14 +9,13 @@ class RealNDField(shape: IntArray) : BufferNDField<Double, DoubleField>(shape, D
/**
* Inline map an NDStructure to
*/
private inline fun NDStructure<Double>.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement {
return if (this is BufferNDElement<Double, *>) {
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) }
BufferNDElement(this@RealNDField, DoubleBuffer(array))
} else {
produce { index -> DoubleField.operation(get(index)) }
}
}
private inline fun NDStructure<Double>.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement =
if (this is BufferNDElement<Double, *>) {
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) }
BufferNDElement(this@RealNDField, DoubleBuffer(array))
} else {
produce { index -> DoubleField.operation(get(index)) }
}
@Suppress("OVERRIDE_BY_INLINE")
@ -58,16 +57,12 @@ inline fun BufferNDField<Double, DoubleField>.produceInline(crossinline initiali
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) }
/* plus and minus */
/**
* Summation operation for [BufferNDElement] and single element
*/
operator fun RealNDElement.plus(arg: Double) =
context.produceInline { i -> buffer[i] + arg }
operator fun RealNDElement.plus(arg: Double) = context.produceInline { i -> buffer[i] + arg }
/**
* Subtraction operation between [BufferNDElement] and single element
*/
operator fun RealNDElement.minus(arg: Double) =
context.produceInline { i -> buffer[i] - arg }
operator fun RealNDElement.minus(arg: Double) = context.produceInline { i -> buffer[i] - arg }