Cached strides creation

This commit is contained in:
Alexander Nozik 2018-12-17 12:44:03 +03:00
parent bb64012992
commit 280bc9e507
4 changed files with 37 additions and 26 deletions

View File

@ -4,8 +4,8 @@ import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.structures.MutableNDStructure import scientifik.kmath.structures.MutableNDStructure
import scientifik.kmath.structures.NDStructure import scientifik.kmath.structures.NDStructure
import scientifik.kmath.structures.genericNdStructure
import scientifik.kmath.structures.get import scientifik.kmath.structures.get
import scientifik.kmath.structures.mutableNdStructure
import kotlin.math.absoluteValue import kotlin.math.absoluteValue
/** /**
@ -107,7 +107,7 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
val m = matrix.columns val m = matrix.columns
val pivot = IntArray(matrix.rows) val pivot = IntArray(matrix.rows)
//TODO fix performance //TODO fix performance
val lu: MutableNDStructure<T> = genericNdStructure(intArrayOf(matrix.rows, matrix.columns)) { index -> matrix[index[0], index[1]] } val lu: MutableNDStructure<T> = mutableNdStructure(intArrayOf(matrix.rows, matrix.columns)) { index -> matrix[index[0], index[1]] }
with(matrix.context.field) { with(matrix.context.field) {

View File

@ -15,7 +15,7 @@ class ExtendedNDField<N : Any, F : ExtendedField<N>>(shape: IntArray, field: F)
ExponentialOperations<NDElement<N, F>> { ExponentialOperations<NDElement<N, F>> {
override fun produceStructure(initializer: F.(IntArray) -> N): NDStructure<N> { override fun produceStructure(initializer: F.(IntArray) -> N): NDStructure<N> {
return genericNdStructure(shape) { field.initializer(it) } return ndStructure(shape) { field.initializer(it) }
} }
override fun power(arg: NDElement<N, F>, pow: Double): NDElement<N, F> { override fun power(arg: NDElement<N, F>, pow: Double): NDElement<N, F> {

View File

@ -167,7 +167,7 @@ operator fun <T, F : Field<T>> NDElement<T, F>.div(arg: T): NDElement<T, F> = tr
} }
class GenericNDField<T : Any, F : Field<T>>(shape: IntArray, field: F) : NDField<T, F>(shape, field) { class GenericNDField<T : Any, F : Field<T>>(shape: IntArray, field: F) : NDField<T, F>(shape, field) {
override fun produceStructure(initializer: F.(IntArray) -> T): NDStructure<T> = genericNdStructure(shape) { field.initializer(it) } override fun produceStructure(initializer: F.(IntArray) -> T): NDStructure<T> = ndStructure(shape) { field.initializer(it) }
} }
//typealias NDFieldFactory<T> = (IntArray)->NDField<T> //typealias NDFieldFactory<T> = (IntArray)->NDField<T>
@ -195,11 +195,6 @@ object NDArrays {
inline fun produceReal(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDElement<Double, DoubleField>) = inline fun produceReal(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDElement<Double, DoubleField>) =
ExtendedNDField(shape, DoubleField).run(block) ExtendedNDField(shape, DoubleField).run(block)
// /**
// * Simple boxing NDField
// */
// fun <T : Any> fieldFactory(field: Field<T>): NDFieldFactory<T> = { shape -> GenericNDField(shape, field) }
/** /**
* Simple boxing NDArray * Simple boxing NDArray
*/ */

View File

@ -63,7 +63,7 @@ interface Strides {
} }
} }
class DefaultStrides(override val shape: IntArray) : Strides { class DefaultStrides private constructor(override val shape: IntArray) : Strides {
/** /**
* Strides for memory access * Strides for memory access
*/ */
@ -101,6 +101,15 @@ class DefaultStrides(override val shape: IntArray) : Strides {
override val linearSize: Int override val linearSize: Int
get() = strides[shape.size] get() = strides[shape.size]
companion object {
private val defaultStridesCache = HashMap<IntArray, Strides>()
/**
* Cached builder for default strides
*/
operator fun invoke(shape: IntArray): Strides = defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
}
} }
abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> { abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> {
@ -112,7 +121,7 @@ abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> {
override val shape: IntArray override val shape: IntArray
get() = strides.shape get() = strides.shape
override fun elements()= override fun elements() =
strides.indices().map { it to this[it] } strides.indices().map { it to this[it] }
} }
@ -131,13 +140,20 @@ class BufferNDStructure<T>(
} }
} }
/**
* Create a most suitable nd-structure avoiding boxing if possible using given strides.
*
* Strides should be reused if possible
*/
inline fun <reified T : Any> ndStructure(strides: Strides, noinline initializer: (IntArray) -> T) = inline fun <reified T : Any> ndStructure(strides: Strides, noinline initializer: (IntArray) -> T) =
BufferNDStructure<T>(strides, buffer(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure<T>(strides, buffer(strides.linearSize) { i -> initializer(strides.index(i)) })
/**
* Create a most suitable nd-structure avoiding boxing if possible using default strides with given shape
*/
inline fun <reified T : Any> ndStructure(shape: IntArray, noinline initializer: (IntArray) -> T) = inline fun <reified T : Any> ndStructure(shape: IntArray, noinline initializer: (IntArray) -> T) =
ndStructure(DefaultStrides(shape), initializer) ndStructure(DefaultStrides(shape), initializer)
/** /**
* Mutable ND buffer based on linear [Buffer] * Mutable ND buffer based on linear [Buffer]
*/ */
@ -156,7 +172,7 @@ class MutableBufferNDStructure<T>(
} }
/** /**
* Create optimized mutable structure for given type * The same as [ndStructure], but mutable
*/ */
inline fun <reified T : Any> mutableNdStructure(strides: Strides, noinline initializer: (IntArray) -> T) = inline fun <reified T : Any> mutableNdStructure(strides: Strides, noinline initializer: (IntArray) -> T) =
MutableBufferNDStructure(strides, mutableBuffer(strides.linearSize) { i -> initializer(strides.index(i)) }) MutableBufferNDStructure(strides, mutableBuffer(strides.linearSize) { i -> initializer(strides.index(i)) })
@ -164,16 +180,16 @@ inline fun <reified T : Any> mutableNdStructure(strides: Strides, noinline initi
inline fun <reified T : Any> mutableNdStructure(shape: IntArray, noinline initializer: (IntArray) -> T) = inline fun <reified T : Any> mutableNdStructure(shape: IntArray, noinline initializer: (IntArray) -> T) =
mutableNdStructure(DefaultStrides(shape), initializer) mutableNdStructure(DefaultStrides(shape), initializer)
/** ///**
* Create universal mutable structure // * Create universal mutable structure
*/ // */
fun <T> genericNdStructure(shape: IntArray, initializer: (IntArray) -> T): MutableBufferNDStructure<T> { //fun <T> genericNdStructure(shape: IntArray, initializer: (IntArray) -> T): MutableBufferNDStructure<T> {
val strides = DefaultStrides(shape) // val strides = DefaultStrides(shape)
val sequence = sequence { // val sequence = sequence {
strides.indices().forEach { // strides.indices().forEach {
yield(initializer(it)) // yield(initializer(it))
} // }
} // }
val buffer = MutableListBuffer(sequence.toMutableList()) // val buffer = MutableListBuffer(sequence.toMutableList())
return MutableBufferNDStructure(strides, buffer) // return MutableBufferNDStructure(strides, buffer)
} //}