diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt index 6a94042b7..3a697cadd 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/linear/LUDecomposition.kt @@ -4,8 +4,8 @@ import scientifik.kmath.operations.DoubleField import scientifik.kmath.operations.Field import scientifik.kmath.structures.MutableNDStructure import scientifik.kmath.structures.NDStructure -import scientifik.kmath.structures.genericNdStructure import scientifik.kmath.structures.get +import scientifik.kmath.structures.mutableNdStructure import kotlin.math.absoluteValue /** @@ -107,7 +107,7 @@ abstract class LUDecomposition, F : Field>(val matrix: Matr val m = matrix.columns val pivot = IntArray(matrix.rows) //TODO fix performance - val lu: MutableNDStructure = genericNdStructure(intArrayOf(matrix.rows, matrix.columns)) { index -> matrix[index[0], index[1]] } + val lu: MutableNDStructure = mutableNdStructure(intArrayOf(matrix.rows, matrix.columns)) { index -> matrix[index[0], index[1]] } with(matrix.context.field) { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt index 93ffad5a3..ccdd35e5b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt @@ -15,7 +15,7 @@ class ExtendedNDField>(shape: IntArray, field: F) ExponentialOperations> { override fun produceStructure(initializer: F.(IntArray) -> N): NDStructure { - return genericNdStructure(shape) { field.initializer(it) } + return ndStructure(shape) { field.initializer(it) } } override fun power(arg: NDElement, pow: Double): NDElement { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt index cddac5801..1757e7d71 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDField.kt @@ -167,7 +167,7 @@ operator fun > NDElement.div(arg: T): NDElement = tr } class GenericNDField>(shape: IntArray, field: F) : NDField(shape, field) { - override fun produceStructure(initializer: F.(IntArray) -> T): NDStructure = genericNdStructure(shape) { field.initializer(it) } + override fun produceStructure(initializer: F.(IntArray) -> T): NDStructure = ndStructure(shape) { field.initializer(it) } } //typealias NDFieldFactory = (IntArray)->NDField @@ -195,11 +195,6 @@ object NDArrays { inline fun produceReal(shape: IntArray, block: ExtendedNDField.() -> NDElement) = ExtendedNDField(shape, DoubleField).run(block) -// /** -// * Simple boxing NDField -// */ -// fun fieldFactory(field: Field): NDFieldFactory = { shape -> GenericNDField(shape, field) } - /** * Simple boxing NDArray */ diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt index 6e159c9dd..f722cbfaa 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDStructure.kt @@ -63,7 +63,7 @@ interface Strides { } } -class DefaultStrides(override val shape: IntArray) : Strides { +class DefaultStrides private constructor(override val shape: IntArray) : Strides { /** * Strides for memory access */ @@ -101,6 +101,15 @@ class DefaultStrides(override val shape: IntArray) : Strides { override val linearSize: Int get() = strides[shape.size] + + companion object { + private val defaultStridesCache = HashMap() + + /** + * Cached builder for default strides + */ + operator fun invoke(shape: IntArray): Strides = defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) } + } } abstract class GenericNDStructure> : NDStructure { @@ -112,7 +121,7 @@ abstract class GenericNDStructure> : NDStructure { override val shape: IntArray get() = strides.shape - override fun elements()= + override fun elements() = strides.indices().map { it to this[it] } } @@ -131,13 +140,20 @@ class BufferNDStructure( } } +/** + * Create a most suitable nd-structure avoiding boxing if possible using given strides. + * + * Strides should be reused if possible + */ inline fun ndStructure(strides: Strides, noinline initializer: (IntArray) -> T) = BufferNDStructure(strides, buffer(strides.linearSize) { i -> initializer(strides.index(i)) }) +/** + * Create a most suitable nd-structure avoiding boxing if possible using default strides with given shape + */ inline fun ndStructure(shape: IntArray, noinline initializer: (IntArray) -> T) = ndStructure(DefaultStrides(shape), initializer) - /** * Mutable ND buffer based on linear [Buffer] */ @@ -156,7 +172,7 @@ class MutableBufferNDStructure( } /** - * Create optimized mutable structure for given type + * The same as [ndStructure], but mutable */ inline fun mutableNdStructure(strides: Strides, noinline initializer: (IntArray) -> T) = MutableBufferNDStructure(strides, mutableBuffer(strides.linearSize) { i -> initializer(strides.index(i)) }) @@ -164,16 +180,16 @@ inline fun mutableNdStructure(strides: Strides, noinline initi inline fun mutableNdStructure(shape: IntArray, noinline initializer: (IntArray) -> T) = mutableNdStructure(DefaultStrides(shape), initializer) -/** - * Create universal mutable structure - */ -fun genericNdStructure(shape: IntArray, initializer: (IntArray) -> T): MutableBufferNDStructure { - val strides = DefaultStrides(shape) - val sequence = sequence { - strides.indices().forEach { - yield(initializer(it)) - } - } - val buffer = MutableListBuffer(sequence.toMutableList()) - return MutableBufferNDStructure(strides, buffer) -} +///** +// * Create universal mutable structure +// */ +//fun genericNdStructure(shape: IntArray, initializer: (IntArray) -> T): MutableBufferNDStructure { +// val strides = DefaultStrides(shape) +// val sequence = sequence { +// strides.indices().forEach { +// yield(initializer(it)) +// } +// } +// val buffer = MutableListBuffer(sequence.toMutableList()) +// return MutableBufferNDStructure(strides, buffer) +//}