Use consistent code style #30

Merged
breandan merged 7 commits from dev into dev 2019-01-04 20:24:57 +03:00
19 changed files with 191 additions and 236 deletions
Showing only changes of commit 55ce9b4754 - Show all commits

View File

@ -17,9 +17,8 @@ interface ExpressionContext<T> {
} }
internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> { internal class VariableExpression<T>(val name: String, val default: T? = null) : Expression<T> {
override fun invoke(arguments: Map<String, T>): T { override fun invoke(arguments: Map<String, T>): T =
return arguments[name] ?: default ?: error("The parameter not found: $name") arguments[name] ?: default ?: error("Parameter not found: $name")
}
} }
internal class ConstantExpression<T>(val value: T) : Expression<T> { internal class ConstantExpression<T>(val value: T) : Expression<T> {

View File

@ -31,7 +31,7 @@ class FastHistogram(
// argument checks // argument checks
if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.") if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.")
if (lower.size != binNums.size) error("Dimension mismatch in bin count.") if (lower.size != binNums.size) error("Dimension mismatch in bin count.")
if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one of axis is not strictly positive") if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one axis is not strictly positive")
} }
@ -41,23 +41,19 @@ class FastHistogram(
/** /**
* Get internal [NDStructure] bin index for given axis * Get internal [NDStructure] bin index for given axis
*/ */
private fun getIndex(axis: Int, value: Double): Int { private fun getIndex(axis: Int, value: Double): Int =
return when { when {
value >= upper[axis] -> binNums[axis] + 1 // overflow value >= upper[axis] -> binNums[axis] + 1 // overflow
value < lower[axis] -> 0 // underflow value < lower[axis] -> 0 // underflow
else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1 else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1
} }
}
private fun getIndex(point: Buffer<out Double>): IntArray = IntArray(dimension) { getIndex(it, point[it]) } private fun getIndex(point: Buffer<out Double>): IntArray =
IntArray(dimension) { getIndex(it, point[it]) }
private fun getValue(index: IntArray): Long { private fun getValue(index: IntArray): Long = values[index].sum()
return values[index].sum()
}
fun getValue(point: Buffer<out Double>): Long { fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
return getValue(getIndex(point))
}
private fun getTemplate(index: IntArray): BinTemplate<Double> { private fun getTemplate(index: IntArray): BinTemplate<Double> {
val center = index.mapIndexed { axis, i -> val center = index.mapIndexed { axis, i ->
@ -70,9 +66,7 @@ class FastHistogram(
return BinTemplate(center, binSize) return BinTemplate(center, binSize)
} }
fun getTemplate(point: Buffer<out Double>): BinTemplate<Double> { fun getTemplate(point: Buffer<out Double>): BinTemplate<Double> = getTemplate(getIndex(point))
return getTemplate(getIndex(point))
}
override fun get(point: Buffer<out Double>): PhantomBin<Double>? { override fun get(point: Buffer<out Double>): PhantomBin<Double>? {
val index = getIndex(point) val index = getIndex(point)
@ -85,16 +79,16 @@ class FastHistogram(
values[index].increment() values[index].increment()
} }
override fun iterator(): Iterator<PhantomBin<Double>> = values.elements().map { (index, value) -> override fun iterator(): Iterator<PhantomBin<Double>> =
values.elements().map { (index, value) ->
PhantomBin(getTemplate(index), value.sum()) PhantomBin(getTemplate(index), value.sum())
}.iterator() }.iterator()
/** /**
* Convert this histogram into NDStructure containing bin values but not bin descriptions * Convert this histogram into NDStructure containing bin values but not bin descriptions
*/ */
fun asNDStructure(): NDStructure<Number> { fun asNDStructure(): NDStructure<Number> =
return inlineNdStructure(this.values.shape) { values[it].sum() } inlineNdStructure(this.values.shape) { values[it].sum() }
}
/** /**
* Create a phantom lightweight immutable copy of this histogram * Create a phantom lightweight immutable copy of this histogram
@ -115,9 +109,8 @@ class FastHistogram(
*) *)
*``` *```
*/ */
fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): FastHistogram { fun fromRanges(vararg ranges: ClosedFloatingPointRange<Double>): FastHistogram =
return FastHistogram(ranges.map { it.start }.toVector(), ranges.map { it.endInclusive }.toVector()) FastHistogram(ranges.map { it.start }.toVector(), ranges.map { it.endInclusive }.toVector())
}
/** /**
* Use it like * Use it like
@ -128,13 +121,12 @@ class FastHistogram(
*) *)
*``` *```
*/ */
fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram { fun fromRanges(vararg ranges: Pair<ClosedFloatingPointRange<Double>, Int>): FastHistogram =
return FastHistogram( FastHistogram(
ListBuffer(ranges.map { it.first.start }), ListBuffer(ranges.map { it.first.start }),
ListBuffer(ranges.map { it.first.endInclusive }), ListBuffer(ranges.map { it.first.endInclusive }),
ranges.map { it.second }.toIntArray() ranges.map { it.second }.toIntArray()
) )
} }
}
} }

View File

@ -51,9 +51,8 @@ class PhantomHistogram<T : Comparable<T>>(
override val dimension: Int override val dimension: Int
get() = data.dimension get() = data.dimension
override fun iterator(): Iterator<PhantomBin<T>> { override fun iterator(): Iterator<PhantomBin<T>> =
return bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator() bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator()
}
override fun get(point: Point<out T>): PhantomBin<T>? { override fun get(point: Point<out T>): PhantomBin<T>? {
val template = bins.keys.find { it.contains(point) } val template = bins.keys.find { it.contains(point) }

View File

@ -86,7 +86,7 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
} }
/** /**
* In-place transformation for [MutableNDArray], using given transformation for each element * In-place transformation for [MutableNDStructure], using given transformation for each element
*/ */
operator fun <T> MutableNDStructure<T>.set(i: Int, j: Int, value: T) { operator fun <T> MutableNDStructure<T>.set(i: Int, j: Int, value: T) {
this[intArrayOf(i, j)] = value this[intArrayOf(i, j)] = value
@ -174,9 +174,7 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
* @return the pivot permutation vector * @return the pivot permutation vector
* @see .getP * @see .getP
*/ */
fun getPivot(): IntArray { fun getPivot(): IntArray = pivot.copyOf()
return pivot.copyOf()
}
} }

View File

@ -28,8 +28,8 @@ fun List<Double>.toVector() = Vector.real(this.size) { this[it] }
/** /**
* Convert matrix to vector if it is possible * Convert matrix to vector if it is possible
*/ */
fun <T : Any, F : Ring<T>> Matrix<T, F>.toVector(): Vector<T, F> { fun <T : Any, F : Ring<T>> Matrix<T, F>.toVector(): Vector<T, F> =
return if (this.numCols == 1) { if (this.numCols == 1) {
// if (this is ArrayMatrix) { // if (this is ArrayMatrix) {
// //Reuse existing underlying array // //Reuse existing underlying array
// ArrayVector(ArrayVectorSpace(rows, context.field, context.ndFactory), array) // ArrayVector(ArrayVectorSpace(rows, context.field, context.ndFactory), array)
@ -39,7 +39,6 @@ fun <T : Any, F : Ring<T>> Matrix<T, F>.toVector(): Vector<T, F> {
// } // }
Vector.generic(numRows, context.ring) { get(it, 0) } Vector.generic(numRows, context.ring) { get(it, 0) }
} else error("Can't convert matrix with more than one column to vector") } else error("Can't convert matrix with more than one column to vector")
}
fun <T : Any, R : Ring<T>> Vector<T, R>.toMatrix(): Matrix<T, R> { fun <T : Any, R : Ring<T>> Vector<T, R>.toMatrix(): Matrix<T, R> {
// val context = StructureMatrixContext(size, 1, context.space) // val context = StructureMatrixContext(size, 1, context.space)
@ -56,9 +55,8 @@ fun <T : Any, R : Ring<T>> Vector<T, R>.toMatrix(): Matrix<T, R> {
} }
object VectorL2Norm : Norm<Vector<out Number, *>, Double> { object VectorL2Norm : Norm<Vector<out Number, *>, Double> {
override fun norm(arg: Vector<out Number, *>): Double { override fun norm(arg: Vector<out Number, *>): Double =
return kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() }) kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() })
}
} }
typealias RealVector = Vector<Double, DoubleField> typealias RealVector = Vector<Double, DoubleField>

View File

@ -33,9 +33,11 @@ interface MatrixSpace<T : Any, R : Ring<T>> : Space<Matrix<T, R>> {
val one get() = produce { i, j -> if (i == j) ring.one else ring.zero } val one get() = produce { i, j -> if (i == j) ring.one else ring.zero }
override fun add(a: Matrix<T, R>, b: Matrix<T, R>): Matrix<T, R> = produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } } override fun add(a: Matrix<T, R>, b: Matrix<T, R>): Matrix<T, R> =
produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } }
override fun multiply(a: Matrix<T, R>, k: Double): Matrix<T, R> = produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } } override fun multiply(a: Matrix<T, R>, k: Double): Matrix<T, R> =
produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } }
companion object { companion object {
/** /**
@ -120,8 +122,8 @@ data class StructureMatrixSpace<T : Any, R : Ring<T>>(
private val strides = DefaultStrides(shape) private val strides = DefaultStrides(shape)
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T, R> { override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T, R> =
return if (rows == rowNum && columns == colNum) { if (rows == rowNum && columns == colNum) {
val structure = NdStructure(strides, bufferFactory) { initializer(it[0], it[1]) } val structure = NdStructure(strides, bufferFactory) { initializer(it[0], it[1]) }
StructureMatrix(this, structure) StructureMatrix(this, structure)
} else { } else {
@ -129,12 +131,15 @@ data class StructureMatrixSpace<T : Any, R : Ring<T>>(
val structure = NdStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) } val structure = NdStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) }
StructureMatrix(context, structure) StructureMatrix(context, structure)
} }
}
override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer) override fun point(size: Int, initializer: (Int) -> T): Point<T> = bufferFactory(size, initializer)
} }
data class StructureMatrix<T : Any, R : Ring<T>>(override val context: StructureMatrixSpace<T, R>, val structure: NDStructure<T>) : Matrix<T, R> { data class StructureMatrix<T : Any, R : Ring<T>>(
override val context: StructureMatrixSpace<T, R>,
val structure: NDStructure<T>
) : Matrix<T, R> {
init { init {
if (structure.shape.size != 2 || structure.shape[0] != context.rowNum || structure.shape[1] != context.colNum) { if (structure.shape.size != 2 || structure.shape[0] != context.rowNum || structure.shape[1] != context.colNum) {
error("Dimension mismatch for structure, (${context.rowNum}, ${context.colNum}) expected, but ${structure.shape} found") error("Dimension mismatch for structure, (${context.rowNum}, ${context.colNum}) expected, but ${structure.shape} found")

View File

@ -33,9 +33,8 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
/** /**
* Non-boxing double vector space * Non-boxing double vector space
*/ */
fun real(size: Int): BufferVectorSpace<Double, DoubleField> { fun real(size: Int): BufferVectorSpace<Double, DoubleField> =
return realSpaceCache.getOrPut(size) { BufferVectorSpace(size, DoubleField, DoubleBufferFactory) } realSpaceCache.getOrPut(size) { BufferVectorSpace(size, DoubleField, DoubleBufferFactory) }
}
/** /**
* A structured vector space with custom buffer * A structured vector space with custom buffer
@ -74,11 +73,7 @@ interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, VectorSpace<T,
} }
} }
data class BufferVectorSpace<T : Any, S : Space<T>>( data class BufferVectorSpace<T : Any, S : Space<T>>(override val size: Int, override val space: S, val bufferFactory: BufferFactory<T>) : VectorSpace<T, S> {
override val size: Int,
override val space: S,
val bufferFactory: BufferFactory<T>
) : VectorSpace<T, S> {
override fun produce(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, bufferFactory(size, initializer)) override fun produce(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, bufferFactory(size, initializer))
} }
@ -91,9 +86,7 @@ data class BufferVector<T : Any, S : Space<T>>(override val context: VectorSpace
} }
} }
override fun get(index: Int): T { override fun get(index: Int): T = buffer[index]
return buffer[index]
}
override val self: BufferVector<T, S> get() = this override val self: BufferVector<T, S> get() = this

View File

@ -8,8 +8,8 @@ package scientifik.kmath.misc
* *
* If step is negative, the same goes from upper boundary downwards * If step is negative, the same goes from upper boundary downwards
*/ */
fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double> { fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double> =
return when { when {
step == 0.0 -> error("Zero step in double progression") step == 0.0 -> error("Zero step in double progression")
step > 0 -> sequence { step > 0 -> sequence {
var current = start var current = start
@ -26,12 +26,11 @@ fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double>
} }
} }
} }
}
/** /**
* Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints] * Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints]
*/ */
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray { fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
if (numPoints < 2) error("Can't generic grid with less than two points") if (numPoints < 2) error("Can't create generic grid with less than two points")
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
} }

View File

@ -53,6 +53,7 @@ interface Space<T> {
//TODO move to external extensions when they are available //TODO move to external extensions when they are available
fun Iterable<T>.sum(): T = fold(zero) { left, right -> left + right } fun Iterable<T>.sum(): T = fold(zero) { left, right -> left + right }
fun Sequence<T>.sum(): T = fold(zero) { left, right -> left + right } fun Sequence<T>.sum(): T = fold(zero) { left, right -> left + right }
} }

View File

@ -2,12 +2,15 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
open class BufferNDField<T, F : Field<T>>(final override val shape: IntArray, final override val field: F, val bufferFactory: BufferFactory<T>) : NDField<T, F> { open class BufferNDField<T, F : Field<T>>(
final override val shape: IntArray,
final override val field: F,
val bufferFactory: BufferFactory<T>
) : NDField<T, F> {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
override fun produce(initializer: F.(IntArray) -> T): BufferNDElement<T, F> { override fun produce(initializer: F.(IntArray) -> T): BufferNDElement<T, F> =
return BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) }) BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) })
}
open fun produceBuffered(initializer: F.(Int) -> T) = open fun produceBuffered(initializer: F.(Int) -> T) =
BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(offset) }) BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(offset) })
@ -31,7 +34,10 @@ open class BufferNDField<T, F : Field<T>>(final override val shape: IntArray, fi
// } // }
} }
class BufferNDElement<T, F : Field<T>>(override val context: BufferNDField<T, F>, val buffer: Buffer<T>) : NDElement<T, F> { class BufferNDElement<T, F : Field<T>>(
override val context: BufferNDField<T, F>,
val buffer: Buffer<T>
) : NDElement<T, F> {
override val self: NDStructure<T> get() = this override val self: NDStructure<T> get() = this
override val shape: IntArray get() = context.shape override val shape: IntArray get() = context.shape

View File

@ -139,14 +139,13 @@ inline fun <T> boxingBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = Lis
* Create most appropriate immutable buffer for given type avoiding boxing wherever possible * Create most appropriate immutable buffer for given type avoiding boxing wherever possible
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <reified T : Any> inlineBuffer(size: Int, initializer: (Int) -> T): Buffer<T> { inline fun <reified T : Any> inlineBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
return when (T::class) { when (T::class) {
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T> Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer<T>
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T> Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer<T>
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T> Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer<T>
else -> boxingBuffer(size, initializer) else -> boxingBuffer(size, initializer)
} }
}
/** /**
* Create a boxing mutable buffer of given type * Create a boxing mutable buffer of given type
@ -157,14 +156,13 @@ inline fun <T : Any> boxingMutableBuffer(size: Int, initializer: (Int) -> T): Mu
* Create most appropriate mutable buffer for given type avoiding boxing wherever possible * Create most appropriate mutable buffer for given type avoiding boxing wherever possible
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
inline fun <reified T : Any> inlineMutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer<T> { inline fun <reified T : Any> inlineMutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer<T> =
return when (T::class) { when (T::class) {
Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T> Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer<T>
Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T> Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer<T>
Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T> Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer<T>
else -> boxingMutableBuffer(size, initializer) else -> boxingMutableBuffer(size, initializer)
} }
}
typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T> typealias BufferFactory<T> = (Int, (Int) -> T) -> Buffer<T>
typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T> typealias MutableBufferFactory<T> = (Int, (Int) -> T) -> MutableBuffer<T>

View File

@ -23,25 +23,16 @@ inline class ExtendedNDFieldWrapper<T : Any, F : ExtendedField<T>>(private val n
override fun produce(initializer: F.(IntArray) -> T): NDElement<T, F> = ndField.produce(initializer) override fun produce(initializer: F.(IntArray) -> T): NDElement<T, F> = ndField.produce(initializer)
override fun power(arg: NDStructure<T>, pow: Double): NDElement<T, F> { override fun power(arg: NDStructure<T>, pow: Double): NDElement<T, F> =
return produce { with(field) { power(arg[it], pow) } } produce { with(field) { power(arg[it], pow) } }
}
override fun exp(arg: NDStructure<T>): NDElement<T, F> { override fun exp(arg: NDStructure<T>): NDElement<T, F> = produce { with(field) { exp(arg[it]) } }
return produce { with(field) { exp(arg[it]) } }
}
override fun ln(arg: NDStructure<T>): NDElement<T, F> { override fun ln(arg: NDStructure<T>): NDElement<T, F> = produce { with(field) { ln(arg[it]) } }
return produce { with(field) { ln(arg[it]) } }
}
override fun sin(arg: NDStructure<T>): NDElement<T, F> { override fun sin(arg: NDStructure<T>): NDElement<T, F> = produce { with(field) { sin(arg[it]) } }
return produce { with(field) { sin(arg[it]) } }
}
override fun cos(arg: NDStructure<T>): NDElement<T, F> { override fun cos(arg: NDStructure<T>): NDElement<T, F> = produce { with(field) { cos(arg[it]) } }
return produce { with(field) { cos(arg[it]) } }
}
} }

View File

@ -11,8 +11,8 @@ class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : Run
/** /**
* Field for n-dimensional arrays. * Field for n-dimensional arrays.
* @param shape - the list of dimensions of the array * @property shape - the list of dimensions of the array
* @param field - operations field defined on individual array element * @property field - operations field defined on individual array element
* @param T the type of the element contained in NDArray * @param T the type of the element contained in NDArray
*/ */
interface NDField<T, F : Field<T>> : Field<NDStructure<T>> { interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
@ -33,13 +33,8 @@ interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
/** /**
* Check the shape of given NDArray and throw exception if it does not coincide with shape of the field * Check the shape of given NDArray and throw exception if it does not coincide with shape of the field
*/ */
fun checkShape(vararg elements: NDStructure<T>) { fun checkShape(vararg elements: NDStructure<T>) =
elements.forEach { elements.forEach { if (!shape.contentEquals(it.shape)) throw ShapeMismatchException(shape, it.shape) }
if (!shape.contentEquals(it.shape)) {
throw ShapeMismatchException(shape, it.shape)
}
}
}
/** /**
* Element-by-element addition * Element-by-element addition
@ -97,21 +92,17 @@ interface NDElement<T, F : Field<T>> : FieldElement<NDStructure<T>, NDField<T, F
/** /**
* Create a platform-optimized NDArray of doubles * Create a platform-optimized NDArray of doubles
*/ */
fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }): NDElement<Double, DoubleField> { fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }): NDElement<Double, DoubleField> =
return NDField.real(shape).produce(initializer) NDField.real(shape).produce(initializer)
}
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement<Double, DoubleField> { fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement<Double, DoubleField> =
return real(intArrayOf(dim)) { initializer(it[0]) } real(intArrayOf(dim)) { initializer(it[0]) }
}
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement<Double, DoubleField> { fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement<Double, DoubleField> =
return real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
}
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement<Double, DoubleField> { fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement<Double, DoubleField> =
return real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
}
// inline fun real(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDStructure<Double>): NDElement<Double, DoubleField> { // inline fun real(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDStructure<Double>): NDElement<Double, DoubleField> {
// val field = NDField.real(shape) // val field = NDField.real(shape)
@ -121,13 +112,11 @@ interface NDElement<T, F : Field<T>> : FieldElement<NDStructure<T>, NDField<T, F
/** /**
* Simple boxing NDArray * Simple boxing NDArray
*/ */
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): NDElement<T, F> { fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): NDElement<T, F> =
return NDField.generic(shape, field).produce(initializer) NDField.generic(shape, field).produce(initializer)
}
inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): NDElement<T, F> { inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): NDElement<T, F> =
return NDField.inline(shape, field).produce(initializer) NDField.inline(shape, field).produce(initializer)
}
} }
} }

View File

@ -19,11 +19,8 @@ interface MutableNDStructure<T> : NDStructure<T> {
operator fun set(index: IntArray, value: T) operator fun set(index: IntArray, value: T)
} }
fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) { fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) =
elements().forEach { (index, oldValue) -> elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
this[index] = action(index, oldValue)
}
}
/** /**
* A way to convert ND index to linear one and back * A way to convert ND index to linear one and back
@ -56,11 +53,11 @@ interface Strides {
/** /**
* Iterate over ND indices in a natural order * Iterate over ND indices in a natural order
*
* TODO: introduce a fast way to calculate index of the next element?
*/ */
fun indices(): Sequence<IntArray> { fun indices(): Sequence<IntArray> =
//TODO introduce a fast way to calculate index of the next element? (0 until linearSize).asSequence().map { index(it) }
return (0 until linearSize).asSequence().map { index(it) }
}
} }
class DefaultStrides private constructor(override val shape: IntArray) : Strides { class DefaultStrides private constructor(override val shape: IntArray) : Strides {
@ -128,25 +125,21 @@ abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> {
/** /**
* Boxing generic [NDStructure] * Boxing generic [NDStructure]
*/ */
class BufferNDStructure<T>( class BufferNDStructure<T>(override val strides: Strides,
override val strides: Strides, override val buffer: Buffer<T>) : GenericNDStructure<T, Buffer<T>>() {
override val buffer: Buffer<T>
) : GenericNDStructure<T, Buffer<T>>() {
init { init {
if (strides.linearSize != buffer.size) { if (strides.linearSize != buffer.size) {
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}") error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
} }
} }
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean =
return when { when {
this === other -> true this === other -> true
other is BufferNDStructure<*> -> this.strides == other.strides && this.buffer.contentEquals(other.buffer) other is BufferNDStructure<*> -> this.strides == other.strides && this.buffer.contentEquals(other.buffer)
other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] } other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] }
else -> false else -> false
} }
}
override fun hashCode(): Int { override fun hashCode(): Int {
var result = strides.hashCode() var result = strides.hashCode()
@ -158,14 +151,13 @@ class BufferNDStructure<T>(
/** /**
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure] * Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure]
*/ */
inline fun <T, reified R : Any> NDStructure<T>.map(factory: BufferFactory<R> = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure<R> { inline fun <T, reified R : Any> NDStructure<T>.map(factory: BufferFactory<R> = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure<R> =
return if (this is BufferNDStructure<T>) { if (this is BufferNDStructure<T>) {
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
} else { } else {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
} }
}
/** /**
* Create a NDStructure with explicit buffer factory * Create a NDStructure with explicit buffer factory

View File

@ -9,14 +9,13 @@ class RealNDField(shape: IntArray) : BufferNDField<Double, DoubleField>(shape, D
/** /**
* Inline map an NDStructure to * Inline map an NDStructure to
*/ */
private inline fun NDStructure<Double>.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement { private inline fun NDStructure<Double>.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement =
return if (this is BufferNDElement<Double, *>) { if (this is BufferNDElement<Double, *>) {
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) } val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) }
BufferNDElement(this@RealNDField, DoubleBuffer(array)) BufferNDElement(this@RealNDField, DoubleBuffer(array))
} else { } else {
produce { index -> DoubleField.operation(get(index)) } produce { index -> DoubleField.operation(get(index)) }
} }
}
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
@ -58,16 +57,12 @@ inline fun BufferNDField<Double, DoubleField>.produceInline(crossinline initiali
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) = operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) } ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) }
/* plus and minus */
/** /**
* Summation operation for [BufferNDElement] and single element * Summation operation for [BufferNDElement] and single element
*/ */
operator fun RealNDElement.plus(arg: Double) = operator fun RealNDElement.plus(arg: Double) = context.produceInline { i -> buffer[i] + arg }
context.produceInline { i -> buffer[i] + arg }
/** /**
* Subtraction operation between [BufferNDElement] and single element * Subtraction operation between [BufferNDElement] and single element
*/ */
operator fun RealNDElement.minus(arg: Double) = operator fun RealNDElement.minus(arg: Double) = context.produceInline { i -> buffer[i] - arg }
context.produceInline { i -> buffer[i] - arg }