diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt index 0a34b536c..47d07c63e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/Expression.kt @@ -17,9 +17,8 @@ interface ExpressionContext { } internal class VariableExpression(val name: String, val default: T? = null) : Expression { - override fun invoke(arguments: Map): T { - return arguments[name] ?: default ?: error("The parameter not found: $name") - } + override fun invoke(arguments: Map): T = + arguments[name] ?: default ?: error("Parameter not found: $name") } internal class ConstantExpression(val value: T) : Expression { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/Counters.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/Counters.kt index f6cc1f822..9a470014a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/Counters.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/Counters.kt @@ -5,16 +5,16 @@ package scientifik.kmath.histogram */ -expect class LongCounter(){ - fun decrement() - fun increment() - fun reset() - fun sum(): Long - fun add(l:Long) +expect class LongCounter() { + fun decrement() + fun increment() + fun reset() + fun sum(): Long + fun add(l: Long) } -expect class DoubleCounter(){ - fun reset() - fun sum(): Double +expect class DoubleCounter() { + fun reset() + fun sum(): Double fun add(d: Double) } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/FastHistogram.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/FastHistogram.kt index 705a8e7ca..e3a3ee0d1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/FastHistogram.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/FastHistogram.kt @@ -31,7 +31,7 @@ class FastHistogram( // argument checks if (lower.size != upper.size) error("Dimension mismatch in histogram lower and upper limits.") if (lower.size != binNums.size) error("Dimension mismatch in bin count.") - if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one of axis is not strictly positive") + if ((upper - lower).asSequence().any { it <= 0 }) error("Range for one axis is not strictly positive") } @@ -41,23 +41,19 @@ class FastHistogram( /** * Get internal [NDStructure] bin index for given axis */ - private fun getIndex(axis: Int, value: Double): Int { - return when { - value >= upper[axis] -> binNums[axis] + 1 // overflow - value < lower[axis] -> 0 // underflow - else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1 - } - } + private fun getIndex(axis: Int, value: Double): Int = + when { + value >= upper[axis] -> binNums[axis] + 1 // overflow + value < lower[axis] -> 0 // underflow + else -> floor((value - lower[axis]) / binSize[axis]).toInt() + 1 + } - private fun getIndex(point: Buffer): IntArray = IntArray(dimension) { getIndex(it, point[it]) } + private fun getIndex(point: Buffer): IntArray = + IntArray(dimension) { getIndex(it, point[it]) } - private fun getValue(index: IntArray): Long { - return values[index].sum() - } + private fun getValue(index: IntArray): Long = values[index].sum() - fun getValue(point: Buffer): Long { - return getValue(getIndex(point)) - } + fun getValue(point: Buffer): Long = getValue(getIndex(point)) private fun getTemplate(index: IntArray): BinTemplate { val center = index.mapIndexed { axis, i -> @@ -70,9 +66,7 @@ class FastHistogram( return BinTemplate(center, binSize) } - fun getTemplate(point: Buffer): BinTemplate { - return getTemplate(getIndex(point)) - } + fun getTemplate(point: Buffer): BinTemplate = getTemplate(getIndex(point)) override fun get(point: Buffer): PhantomBin? { val index = getIndex(point) @@ -85,16 +79,16 @@ class FastHistogram( values[index].increment() } - override fun iterator(): Iterator> = values.elements().map { (index, value) -> - PhantomBin(getTemplate(index), value.sum()) - }.iterator() + override fun iterator(): Iterator> = + values.elements().map { (index, value) -> + PhantomBin(getTemplate(index), value.sum()) + }.iterator() /** * Convert this histogram into NDStructure containing bin values but not bin descriptions */ - fun asNDStructure(): NDStructure { - return inlineNdStructure(this.values.shape) { values[it].sum() } - } + fun asNDStructure(): NDStructure = + inlineNdStructure(this.values.shape) { values[it].sum() } /** * Create a phantom lightweight immutable copy of this histogram @@ -115,9 +109,8 @@ class FastHistogram( *) *``` */ - fun fromRanges(vararg ranges: ClosedFloatingPointRange): FastHistogram { - return FastHistogram(ranges.map { it.start }.toVector(), ranges.map { it.endInclusive }.toVector()) - } + fun fromRanges(vararg ranges: ClosedFloatingPointRange): FastHistogram = + FastHistogram(ranges.map { it.start }.toVector(), ranges.map { it.endInclusive }.toVector()) /** * Use it like @@ -128,13 +121,12 @@ class FastHistogram( *) *``` */ - fun fromRanges(vararg ranges: Pair, Int>): FastHistogram { - return FastHistogram( - ListBuffer(ranges.map { it.first.start }), - ListBuffer(ranges.map { it.first.endInclusive }), - ranges.map { it.second }.toIntArray() - ) - } + fun fromRanges(vararg ranges: Pair, Int>): FastHistogram = + FastHistogram( + ListBuffer(ranges.map { it.first.start }), + ListBuffer(ranges.map { it.first.endInclusive }), + ranges.map { it.second }.toIntArray() + ) } } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt index 08214142e..90251b9d1 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/Histogram.kt @@ -12,7 +12,7 @@ typealias RealPoint = Buffer * A simple geometric domain * TODO move to geometry module */ -interface Domain { +interface Domain { operator fun contains(vector: Point): Boolean val dimension: Int } @@ -20,7 +20,7 @@ interface Domain { /** * The bin in the histogram. The histogram is by definition always done in the real space */ -interface Bin : Domain { +interface Bin : Domain { /** * The value of this bin */ @@ -28,7 +28,7 @@ interface Bin : Domain { val center: Point } -interface Histogram> : Iterable { +interface Histogram> : Iterable { /** * Find existing bin, corresponding to given coordinates @@ -42,7 +42,7 @@ interface Histogram> : Iterable { } -interface MutableHistogram>: Histogram{ +interface MutableHistogram> : Histogram { /** * Increment appropriate bin @@ -50,14 +50,14 @@ interface MutableHistogram>: Histogram{ fun put(point: Point, weight: Double = 1.0) } -fun MutableHistogram.put(vararg point: T) = put(ArrayBuffer(point)) +fun MutableHistogram.put(vararg point: T) = put(ArrayBuffer(point)) -fun MutableHistogram.put(vararg point: Number) = put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray())) -fun MutableHistogram.put(vararg point: Double) = put(DoubleBuffer(point)) +fun MutableHistogram.put(vararg point: Number) = put(DoubleBuffer(point.map { it.toDouble() }.toDoubleArray())) +fun MutableHistogram.put(vararg point: Double) = put(DoubleBuffer(point)) -fun MutableHistogram.fill(sequence: Iterable>) = sequence.forEach { put(it) } +fun MutableHistogram.fill(sequence: Iterable>) = sequence.forEach { put(it) } /** * Pass a sequence builder into histogram */ -fun MutableHistogram.fill(buider: suspend SequenceScope>.() -> Unit) = fill(sequence(buider).asIterable()) \ No newline at end of file +fun MutableHistogram.fill(buider: suspend SequenceScope>.() -> Unit) = fill(sequence(buider).asIterable()) \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/PhantomHistogram.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/PhantomHistogram.kt index ffffb0d7d..f9ec68f73 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/PhantomHistogram.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/histogram/PhantomHistogram.kt @@ -8,8 +8,8 @@ import scientifik.kmath.structures.asSequence data class BinTemplate>(val center: Vector, val sizes: Point) { fun contains(vector: Point): Boolean { if (vector.size != center.size) error("Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}") - val upper = center.context.run { center + sizes / 2.0} - val lower = center.context.run {center - sizes / 2.0} + val upper = center.context.run { center + sizes / 2.0 } + val lower = center.context.run { center - sizes / 2.0 } return vector.asSequence().mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it } @@ -51,9 +51,8 @@ class PhantomHistogram>( override val dimension: Int get() = data.dimension - override fun iterator(): Iterator> { - return bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator() - } + override fun iterator(): Iterator> = + bins.asSequence().map { entry -> PhantomBin(entry.key, data[entry.value]) }.iterator() override fun get(point: Point): PhantomBin? { val template = bins.keys.find { it.contains(point) } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt index 3f14787f2..d94d4c7c9 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt @@ -86,7 +86,7 @@ abstract class LUDecomposition, F : Field>(val matrix: Matr } /** - * In-place transformation for [MutableNDArray], using given transformation for each element + * In-place transformation for [MutableNDStructure], using given transformation for each element */ operator fun MutableNDStructure.set(i: Int, j: Int, value: T) { this[intArrayOf(i, j)] = value @@ -174,9 +174,7 @@ abstract class LUDecomposition, F : Field>(val matrix: Matr * @return the pivot permutation vector * @see .getP */ - fun getPivot(): IntArray { - return pivot.copyOf() - } + fun getPivot(): IntArray = pivot.copyOf() } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt index c2e54bf3b..7108f4865 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LinearAlgrebra.kt @@ -28,8 +28,8 @@ fun List.toVector() = Vector.real(this.size) { this[it] } /** * Convert matrix to vector if it is possible */ -fun > Matrix.toVector(): Vector { - return if (this.numCols == 1) { +fun > Matrix.toVector(): Vector = + if (this.numCols == 1) { // if (this is ArrayMatrix) { // //Reuse existing underlying array // ArrayVector(ArrayVectorSpace(rows, context.field, context.ndFactory), array) @@ -37,9 +37,8 @@ fun > Matrix.toVector(): Vector { // //Generic vector // vector(rows, context.field) { get(it, 0) } // } - Vector.generic(numRows, context.ring) { get(it, 0) } - } else error("Can't convert matrix with more than one column to vector") -} + Vector.generic(numRows, context.ring) { get(it, 0) } + } else error("Can't convert matrix with more than one column to vector") fun > Vector.toMatrix(): Matrix { // val context = StructureMatrixContext(size, 1, context.space) @@ -56,9 +55,8 @@ fun > Vector.toMatrix(): Matrix { } object VectorL2Norm : Norm, Double> { - override fun norm(arg: Vector): Double { - return kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() }) - } + override fun norm(arg: Vector): Double = + kotlin.math.sqrt(arg.asSequence().sumByDouble { it.toDouble() }) } typealias RealVector = Vector diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt index 13a54cb5d..34ce820ec 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Matrix.kt @@ -33,9 +33,11 @@ interface MatrixSpace> : Space> { val one get() = produce { i, j -> if (i == j) ring.one else ring.zero } - override fun add(a: Matrix, b: Matrix): Matrix = produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } } + override fun add(a: Matrix, b: Matrix): Matrix = + produce(rowNum, colNum) { i, j -> ring.run { a[i, j] + b[i, j] } } - override fun multiply(a: Matrix, k: Double): Matrix = produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } } + override fun multiply(a: Matrix, k: Double): Matrix = + produce(rowNum, colNum) { i, j -> ring.run { a[i, j] * k } } companion object { /** @@ -120,21 +122,24 @@ data class StructureMatrixSpace>( private val strides = DefaultStrides(shape) - override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix { - return if (rows == rowNum && columns == colNum) { - val structure = NdStructure(strides, bufferFactory) { initializer(it[0], it[1]) } - StructureMatrix(this, structure) - } else { - val context = StructureMatrixSpace(rows, columns, ring, bufferFactory) - val structure = NdStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) } - StructureMatrix(context, structure) - } - } + override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix = + if (rows == rowNum && columns == colNum) { + val structure = NdStructure(strides, bufferFactory) { initializer(it[0], it[1]) } + StructureMatrix(this, structure) + } else { + val context = StructureMatrixSpace(rows, columns, ring, bufferFactory) + val structure = NdStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) } + StructureMatrix(context, structure) + } override fun point(size: Int, initializer: (Int) -> T): Point = bufferFactory(size, initializer) } -data class StructureMatrix>(override val context: StructureMatrixSpace, val structure: NDStructure) : Matrix { +data class StructureMatrix>( + override val context: StructureMatrixSpace, + val structure: NDStructure +) : Matrix { + init { if (structure.shape.size != 2 || structure.shape[0] != context.rowNum || structure.shape[1] != context.colNum) { error("Dimension mismatch for structure, (${context.rowNum}, ${context.colNum}) expected, but ${structure.shape} found") diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt index 756f85959..37de60d7b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/Vector.kt @@ -33,9 +33,8 @@ interface VectorSpace> : Space> { /** * Non-boxing double vector space */ - fun real(size: Int): BufferVectorSpace { - return realSpaceCache.getOrPut(size) { BufferVectorSpace(size, DoubleField, DoubleBufferFactory) } - } + fun real(size: Int): BufferVectorSpace = + realSpaceCache.getOrPut(size) { BufferVectorSpace(size, DoubleField, DoubleBufferFactory) } /** * A structured vector space with custom buffer @@ -69,16 +68,12 @@ interface Vector> : SpaceElement, VectorSpace Double) = VectorSpace.real(size).produce(initializer) - fun ofReal(vararg elements: Double) = VectorSpace.real(elements.size).produce{elements[it]} + fun ofReal(vararg elements: Double) = VectorSpace.real(elements.size).produce { elements[it] } } } -data class BufferVectorSpace>( - override val size: Int, - override val space: S, - val bufferFactory: BufferFactory -) : VectorSpace { +data class BufferVectorSpace>(override val size: Int, override val space: S, val bufferFactory: BufferFactory) : VectorSpace { override fun produce(initializer: (Int) -> T): Vector = BufferVector(this, bufferFactory(size, initializer)) } @@ -91,9 +86,7 @@ data class BufferVector>(override val context: VectorSpace } } - override fun get(index: Int): T { - return buffer[index] - } + override fun get(index: Int): T = buffer[index] override val self: BufferVector get() = this diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Cumulative.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Cumulative.kt index 4d4f8ced6..a3f5cfbbb 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Cumulative.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Cumulative.kt @@ -32,28 +32,28 @@ fun List.cumulative(initial: R, operation: (T, R) -> R): List = thi //Cumulative sum @JvmName("cumulativeSumOfDouble") -fun Iterable.cumulativeSum() = this.cumulative(0.0){ element, sum -> sum + element} +fun Iterable.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element } @JvmName("cumulativeSumOfInt") -fun Iterable.cumulativeSum() = this.cumulative(0){ element, sum -> sum + element} +fun Iterable.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element } @JvmName("cumulativeSumOfLong") -fun Iterable.cumulativeSum() = this.cumulative(0L){ element, sum -> sum + element} +fun Iterable.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element } @JvmName("cumulativeSumOfDouble") -fun Sequence.cumulativeSum() = this.cumulative(0.0){ element, sum -> sum + element} +fun Sequence.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element } @JvmName("cumulativeSumOfInt") -fun Sequence.cumulativeSum() = this.cumulative(0){ element, sum -> sum + element} +fun Sequence.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element } @JvmName("cumulativeSumOfLong") -fun Sequence.cumulativeSum() = this.cumulative(0L){ element, sum -> sum + element} +fun Sequence.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element } @JvmName("cumulativeSumOfDouble") -fun List.cumulativeSum() = this.cumulative(0.0){ element, sum -> sum + element} +fun List.cumulativeSum() = this.cumulative(0.0) { element, sum -> sum + element } @JvmName("cumulativeSumOfInt") -fun List.cumulativeSum() = this.cumulative(0){ element, sum -> sum + element} +fun List.cumulativeSum() = this.cumulative(0) { element, sum -> sum + element } @JvmName("cumulativeSumOfLong") -fun List.cumulativeSum() = this.cumulative(0L){ element, sum -> sum + element} \ No newline at end of file +fun List.cumulativeSum() = this.cumulative(0L) { element, sum -> sum + element } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt index c16b03608..90ce5da68 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/misc/Grids.kt @@ -8,30 +8,29 @@ package scientifik.kmath.misc * * If step is negative, the same goes from upper boundary downwards */ -fun ClosedFloatingPointRange.toSequence(step: Double): Sequence { - return when { - step == 0.0 -> error("Zero step in double progression") - step > 0 -> sequence { - var current = start - while (current <= endInclusive) { - yield(current) - current += step +fun ClosedFloatingPointRange.toSequence(step: Double): Sequence = + when { + step == 0.0 -> error("Zero step in double progression") + step > 0 -> sequence { + var current = start + while (current <= endInclusive) { + yield(current) + current += step + } + } + else -> sequence { + var current = endInclusive + while (current >= start) { + yield(current) + current += step + } } } - else -> sequence { - var current = endInclusive - while (current >= start) { - yield(current) - current += step - } - } - } -} /** * Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints] */ fun ClosedFloatingPointRange.toGrid(numPoints: Int): DoubleArray { - if (numPoints < 2) error("Can't generic grid with less than two points") + if (numPoints < 2) error("Can't create generic grid with less than two points") return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt index 51f3e75b3..4f5b93552 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Algebra.kt @@ -53,6 +53,7 @@ interface Space { //TODO move to external extensions when they are available fun Iterable.sum(): T = fold(zero) { left, right -> left + right } + fun Sequence.sum(): T = fold(zero) { left, right -> left + right } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt index eabae8ea4..ca6a48b12 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Fields.kt @@ -79,11 +79,11 @@ object DoubleField : ExtendedField, Norm { /** * A field for double without boxing. Does not produce appropriate field element */ -object IntField : Field{ +object IntField : Field { override val zero: Int = 0 override fun add(a: Int, b: Int): Int = a + b override fun multiply(a: Int, b: Int): Int = a * b - override fun multiply(a: Int, k: Double): Int = (k*a).toInt() + override fun multiply(a: Int, k: Double): Int = (k * a).toInt() override val one: Int = 1 override fun divide(a: Int, b: Int): Int = a / b } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferNDField.kt index ebe2a67cf..d0ae6a87b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/BufferNDField.kt @@ -2,12 +2,15 @@ package scientifik.kmath.structures import scientifik.kmath.operations.Field -open class BufferNDField>(final override val shape: IntArray, final override val field: F, val bufferFactory: BufferFactory) : NDField { +open class BufferNDField>( + final override val shape: IntArray, + final override val field: F, + val bufferFactory: BufferFactory +) : NDField { val strides = DefaultStrides(shape) - override fun produce(initializer: F.(IntArray) -> T): BufferNDElement { - return BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) }) - } + override fun produce(initializer: F.(IntArray) -> T): BufferNDElement = + BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) }) open fun produceBuffered(initializer: F.(Int) -> T) = BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(offset) }) @@ -31,7 +34,10 @@ open class BufferNDField>(final override val shape: IntArray, fi // } } -class BufferNDElement>(override val context: BufferNDField, val buffer: Buffer) : NDElement { +class BufferNDElement>( + override val context: BufferNDField, + val buffer: Buffer +) : NDElement { override val self: NDStructure get() = this override val shape: IntArray get() = context.shape @@ -60,7 +66,7 @@ operator fun > BufferNDElement.plus(arg: T) = /** * Subtraction operation between [BufferNDElement] and single element */ -operator fun > BufferNDElement.minus(arg: T) = +operator fun > BufferNDElement.minus(arg: T) = context.produceBuffered { i -> buffer[i] - arg } /* prod and div */ @@ -68,11 +74,11 @@ operator fun > BufferNDElement.minus(arg: T) = /** * Product operation for [BufferNDElement] and single element */ -operator fun > BufferNDElement.times(arg: T) = +operator fun > BufferNDElement.times(arg: T) = context.produceBuffered { i -> buffer[i] * arg } /** * Division operation between [BufferNDElement] and single element */ -operator fun > BufferNDElement.div(arg: T) = +operator fun > BufferNDElement.div(arg: T) = context.produceBuffered { i -> buffer[i] / arg } \ No newline at end of file diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt index cffc2bea0..74e7af8d9 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/Buffers.kt @@ -139,14 +139,13 @@ inline fun boxingBuffer(size: Int, initializer: (Int) -> T): Buffer = Lis * Create most appropriate immutable buffer for given type avoiding boxing wherever possible */ @Suppress("UNCHECKED_CAST") -inline fun inlineBuffer(size: Int, initializer: (Int) -> T): Buffer { - return when (T::class) { - Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer - Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer - Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer - else -> boxingBuffer(size, initializer) - } -} +inline fun inlineBuffer(size: Int, initializer: (Int) -> T): Buffer = + when (T::class) { + Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as Buffer + Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as Buffer + Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as Buffer + else -> boxingBuffer(size, initializer) + } /** * Create a boxing mutable buffer of given type @@ -157,14 +156,13 @@ inline fun boxingMutableBuffer(size: Int, initializer: (Int) -> T): Mu * Create most appropriate mutable buffer for given type avoiding boxing wherever possible */ @Suppress("UNCHECKED_CAST") -inline fun inlineMutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer { - return when (T::class) { - Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer - Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer - Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer - else -> boxingMutableBuffer(size, initializer) - } -} +inline fun inlineMutableBuffer(size: Int, initializer: (Int) -> T): MutableBuffer = + when (T::class) { + Double::class -> DoubleBuffer(DoubleArray(size) { initializer(it) as Double }) as MutableBuffer + Int::class -> IntBuffer(IntArray(size) { initializer(it) as Int }) as MutableBuffer + Long::class -> LongBuffer(LongArray(size) { initializer(it) as Long }) as MutableBuffer + else -> boxingMutableBuffer(size, initializer) + } typealias BufferFactory = (Int, (Int) -> T) -> Buffer typealias MutableBufferFactory = (Int, (Int) -> T) -> MutableBuffer diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt index d274f92bd..cf6a1ca0c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt @@ -23,25 +23,16 @@ inline class ExtendedNDFieldWrapper>(private val n override fun produce(initializer: F.(IntArray) -> T): NDElement = ndField.produce(initializer) - override fun power(arg: NDStructure, pow: Double): NDElement { - return produce { with(field) { power(arg[it], pow) } } - } + override fun power(arg: NDStructure, pow: Double): NDElement = + produce { with(field) { power(arg[it], pow) } } - override fun exp(arg: NDStructure): NDElement { - return produce { with(field) { exp(arg[it]) } } - } + override fun exp(arg: NDStructure): NDElement = produce { with(field) { exp(arg[it]) } } - override fun ln(arg: NDStructure): NDElement { - return produce { with(field) { ln(arg[it]) } } - } + override fun ln(arg: NDStructure): NDElement = produce { with(field) { ln(arg[it]) } } - override fun sin(arg: NDStructure): NDElement { - return produce { with(field) { sin(arg[it]) } } - } + override fun sin(arg: NDStructure): NDElement = produce { with(field) { sin(arg[it]) } } - override fun cos(arg: NDStructure): NDElement { - return produce { with(field) { cos(arg[it]) } } - } + override fun cos(arg: NDStructure): NDElement = produce { with(field) { cos(arg[it]) } } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt index 4fd6c3ee5..1fb0328b2 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt @@ -11,8 +11,8 @@ class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : Run /** * Field for n-dimensional arrays. - * @param shape - the list of dimensions of the array - * @param field - operations field defined on individual array element + * @property shape - the list of dimensions of the array + * @property field - operations field defined on individual array element * @param T the type of the element contained in NDArray */ interface NDField> : Field> { @@ -33,13 +33,8 @@ interface NDField> : Field> { /** * Check the shape of given NDArray and throw exception if it does not coincide with shape of the field */ - fun checkShape(vararg elements: NDStructure) { - elements.forEach { - if (!shape.contentEquals(it.shape)) { - throw ShapeMismatchException(shape, it.shape) - } - } - } + fun checkShape(vararg elements: NDStructure) = + elements.forEach { if (!shape.contentEquals(it.shape)) throw ShapeMismatchException(shape, it.shape) } /** * Element-by-element addition @@ -97,21 +92,17 @@ interface NDElement> : FieldElement, NDField Double = { 0.0 }): NDElement { - return NDField.real(shape).produce(initializer) - } + fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }): NDElement = + NDField.real(shape).produce(initializer) - fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement { - return real(intArrayOf(dim)) { initializer(it[0]) } - } + fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement = + real(intArrayOf(dim)) { initializer(it[0]) } - fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement { - return real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } - } + fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement = + real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) } - fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement { - return real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } - } + fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement = + real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } // inline fun real(shape: IntArray, block: ExtendedNDField.() -> NDStructure): NDElement { // val field = NDField.real(shape) @@ -121,13 +112,11 @@ interface NDElement> : FieldElement, NDField> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): NDElement { - return NDField.generic(shape, field).produce(initializer) - } + fun > generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): NDElement = + NDField.generic(shape, field).produce(initializer) - inline fun > inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): NDElement { - return NDField.inline(shape, field).produce(initializer) - } + inline fun > inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): NDElement = + NDField.inline(shape, field).produce(initializer) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt index 765d7148b..6c5d81713 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt @@ -19,11 +19,8 @@ interface MutableNDStructure : NDStructure { operator fun set(index: IntArray, value: T) } -fun MutableNDStructure.mapInPlace(action: (IntArray, T) -> T) { - elements().forEach { (index, oldValue) -> - this[index] = action(index, oldValue) - } -} +fun MutableNDStructure.mapInPlace(action: (IntArray, T) -> T) = + elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) } /** * A way to convert ND index to linear one and back @@ -56,11 +53,11 @@ interface Strides { /** * Iterate over ND indices in a natural order + * + * TODO: introduce a fast way to calculate index of the next element? */ - fun indices(): Sequence { - //TODO introduce a fast way to calculate index of the next element? - return (0 until linearSize).asSequence().map { index(it) } - } + fun indices(): Sequence = + (0 until linearSize).asSequence().map { index(it) } } class DefaultStrides private constructor(override val shape: IntArray) : Strides { @@ -128,25 +125,21 @@ abstract class GenericNDStructure> : NDStructure { /** * Boxing generic [NDStructure] */ -class BufferNDStructure( - override val strides: Strides, - override val buffer: Buffer -) : GenericNDStructure>() { - +class BufferNDStructure(override val strides: Strides, + override val buffer: Buffer) : GenericNDStructure>() { init { if (strides.linearSize != buffer.size) { error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}") } } - override fun equals(other: Any?): Boolean { - return when { - this === other -> true - other is BufferNDStructure<*> -> this.strides == other.strides && this.buffer.contentEquals(other.buffer) - other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] } - else -> false - } - } + override fun equals(other: Any?): Boolean = + when { + this === other -> true + other is BufferNDStructure<*> -> this.strides == other.strides && this.buffer.contentEquals(other.buffer) + other is NDStructure<*> -> elements().all { (index, value) -> value == other[index] } + else -> false + } override fun hashCode(): Int { var result = strides.hashCode() @@ -158,14 +151,13 @@ class BufferNDStructure( /** * Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure] */ -inline fun NDStructure.map(factory: BufferFactory = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure { - return if (this is BufferNDStructure) { - BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) - } else { - val strides = DefaultStrides(shape) - BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) - } -} +inline fun NDStructure.map(factory: BufferFactory = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure = + if (this is BufferNDStructure) { + BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) + } else { + val strides = DefaultStrides(shape) + BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) + } /** * Create a NDStructure with explicit buffer factory diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index 7705c710e..f30460f5a 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -9,14 +9,13 @@ class RealNDField(shape: IntArray) : BufferNDField(shape, D /** * Inline map an NDStructure to */ - private inline fun NDStructure.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement { - return if (this is BufferNDElement) { - val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) } - BufferNDElement(this@RealNDField, DoubleBuffer(array)) - } else { - produce { index -> DoubleField.operation(get(index)) } - } - } + private inline fun NDStructure.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement = + if (this is BufferNDElement) { + val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) } + BufferNDElement(this@RealNDField, DoubleBuffer(array)) + } else { + produce { index -> DoubleField.operation(get(index)) } + } @Suppress("OVERRIDE_BY_INLINE") @@ -58,16 +57,12 @@ inline fun BufferNDField.produceInline(crossinline initiali operator fun Function1.invoke(ndElement: RealNDElement) = ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) } -/* plus and minus */ - /** * Summation operation for [BufferNDElement] and single element */ -operator fun RealNDElement.plus(arg: Double) = - context.produceInline { i -> buffer[i] + arg } +operator fun RealNDElement.plus(arg: Double) = context.produceInline { i -> buffer[i] + arg } /** * Subtraction operation between [BufferNDElement] and single element */ -operator fun RealNDElement.minus(arg: Double) = - context.produceInline { i -> buffer[i] - arg } \ No newline at end of file +operator fun RealNDElement.minus(arg: Double) = context.produceInline { i -> buffer[i] - arg } \ No newline at end of file