Cached strides creation
This commit is contained in:
parent
bb64012992
commit
280bc9e507
@ -4,8 +4,8 @@ import scientifik.kmath.operations.DoubleField
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.structures.MutableNDStructure
|
||||
import scientifik.kmath.structures.NDStructure
|
||||
import scientifik.kmath.structures.genericNdStructure
|
||||
import scientifik.kmath.structures.get
|
||||
import scientifik.kmath.structures.mutableNdStructure
|
||||
import kotlin.math.absoluteValue
|
||||
|
||||
/**
|
||||
@ -107,7 +107,7 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
|
||||
val m = matrix.columns
|
||||
val pivot = IntArray(matrix.rows)
|
||||
//TODO fix performance
|
||||
val lu: MutableNDStructure<T> = genericNdStructure(intArrayOf(matrix.rows, matrix.columns)) { index -> matrix[index[0], index[1]] }
|
||||
val lu: MutableNDStructure<T> = mutableNdStructure(intArrayOf(matrix.rows, matrix.columns)) { index -> matrix[index[0], index[1]] }
|
||||
|
||||
|
||||
with(matrix.context.field) {
|
||||
|
@ -15,7 +15,7 @@ class ExtendedNDField<N : Any, F : ExtendedField<N>>(shape: IntArray, field: F)
|
||||
ExponentialOperations<NDElement<N, F>> {
|
||||
|
||||
override fun produceStructure(initializer: F.(IntArray) -> N): NDStructure<N> {
|
||||
return genericNdStructure(shape) { field.initializer(it) }
|
||||
return ndStructure(shape) { field.initializer(it) }
|
||||
}
|
||||
|
||||
override fun power(arg: NDElement<N, F>, pow: Double): NDElement<N, F> {
|
||||
|
@ -167,7 +167,7 @@ operator fun <T, F : Field<T>> NDElement<T, F>.div(arg: T): NDElement<T, F> = tr
|
||||
}
|
||||
|
||||
class GenericNDField<T : Any, F : Field<T>>(shape: IntArray, field: F) : NDField<T, F>(shape, field) {
|
||||
override fun produceStructure(initializer: F.(IntArray) -> T): NDStructure<T> = genericNdStructure(shape) { field.initializer(it) }
|
||||
override fun produceStructure(initializer: F.(IntArray) -> T): NDStructure<T> = ndStructure(shape) { field.initializer(it) }
|
||||
}
|
||||
|
||||
//typealias NDFieldFactory<T> = (IntArray)->NDField<T>
|
||||
@ -195,11 +195,6 @@ object NDArrays {
|
||||
inline fun produceReal(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDElement<Double, DoubleField>) =
|
||||
ExtendedNDField(shape, DoubleField).run(block)
|
||||
|
||||
// /**
|
||||
// * Simple boxing NDField
|
||||
// */
|
||||
// fun <T : Any> fieldFactory(field: Field<T>): NDFieldFactory<T> = { shape -> GenericNDField(shape, field) }
|
||||
|
||||
/**
|
||||
* Simple boxing NDArray
|
||||
*/
|
||||
|
@ -63,7 +63,7 @@ interface Strides {
|
||||
}
|
||||
}
|
||||
|
||||
class DefaultStrides(override val shape: IntArray) : Strides {
|
||||
class DefaultStrides private constructor(override val shape: IntArray) : Strides {
|
||||
/**
|
||||
* Strides for memory access
|
||||
*/
|
||||
@ -101,6 +101,15 @@ class DefaultStrides(override val shape: IntArray) : Strides {
|
||||
|
||||
override val linearSize: Int
|
||||
get() = strides[shape.size]
|
||||
|
||||
companion object {
|
||||
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
||||
|
||||
/**
|
||||
* Cached builder for default strides
|
||||
*/
|
||||
operator fun invoke(shape: IntArray): Strides = defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
||||
}
|
||||
}
|
||||
|
||||
abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> {
|
||||
@ -131,13 +140,20 @@ class BufferNDStructure<T>(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a most suitable nd-structure avoiding boxing if possible using given strides.
|
||||
*
|
||||
* Strides should be reused if possible
|
||||
*/
|
||||
inline fun <reified T : Any> ndStructure(strides: Strides, noinline initializer: (IntArray) -> T) =
|
||||
BufferNDStructure<T>(strides, buffer(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
/**
|
||||
* Create a most suitable nd-structure avoiding boxing if possible using default strides with given shape
|
||||
*/
|
||||
inline fun <reified T : Any> ndStructure(shape: IntArray, noinline initializer: (IntArray) -> T) =
|
||||
ndStructure(DefaultStrides(shape), initializer)
|
||||
|
||||
|
||||
/**
|
||||
* Mutable ND buffer based on linear [Buffer]
|
||||
*/
|
||||
@ -156,7 +172,7 @@ class MutableBufferNDStructure<T>(
|
||||
}
|
||||
|
||||
/**
|
||||
* Create optimized mutable structure for given type
|
||||
* The same as [ndStructure], but mutable
|
||||
*/
|
||||
inline fun <reified T : Any> mutableNdStructure(strides: Strides, noinline initializer: (IntArray) -> T) =
|
||||
MutableBufferNDStructure(strides, mutableBuffer(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
@ -164,16 +180,16 @@ inline fun <reified T : Any> mutableNdStructure(strides: Strides, noinline initi
|
||||
inline fun <reified T : Any> mutableNdStructure(shape: IntArray, noinline initializer: (IntArray) -> T) =
|
||||
mutableNdStructure(DefaultStrides(shape), initializer)
|
||||
|
||||
/**
|
||||
* Create universal mutable structure
|
||||
*/
|
||||
fun <T> genericNdStructure(shape: IntArray, initializer: (IntArray) -> T): MutableBufferNDStructure<T> {
|
||||
val strides = DefaultStrides(shape)
|
||||
val sequence = sequence {
|
||||
strides.indices().forEach {
|
||||
yield(initializer(it))
|
||||
}
|
||||
}
|
||||
val buffer = MutableListBuffer(sequence.toMutableList())
|
||||
return MutableBufferNDStructure(strides, buffer)
|
||||
}
|
||||
///**
|
||||
// * Create universal mutable structure
|
||||
// */
|
||||
//fun <T> genericNdStructure(shape: IntArray, initializer: (IntArray) -> T): MutableBufferNDStructure<T> {
|
||||
// val strides = DefaultStrides(shape)
|
||||
// val sequence = sequence {
|
||||
// strides.indices().forEach {
|
||||
// yield(initializer(it))
|
||||
// }
|
||||
// }
|
||||
// val buffer = MutableListBuffer(sequence.toMutableList())
|
||||
// return MutableBufferNDStructure(strides, buffer)
|
||||
//}
|
||||
|
Loading…
Reference in New Issue
Block a user