Cached strides creation

This commit is contained in:
Alexander Nozik 2018-12-17 12:44:03 +03:00
parent bb64012992
commit 280bc9e507
4 changed files with 37 additions and 26 deletions

View File

@ -4,8 +4,8 @@ import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.Field
import scientifik.kmath.structures.MutableNDStructure
import scientifik.kmath.structures.NDStructure
import scientifik.kmath.structures.genericNdStructure
import scientifik.kmath.structures.get
import scientifik.kmath.structures.mutableNdStructure
import kotlin.math.absoluteValue
/**
@ -107,7 +107,7 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
val m = matrix.columns
val pivot = IntArray(matrix.rows)
//TODO fix performance
val lu: MutableNDStructure<T> = genericNdStructure(intArrayOf(matrix.rows, matrix.columns)) { index -> matrix[index[0], index[1]] }
val lu: MutableNDStructure<T> = mutableNdStructure(intArrayOf(matrix.rows, matrix.columns)) { index -> matrix[index[0], index[1]] }
with(matrix.context.field) {

View File

@ -15,7 +15,7 @@ class ExtendedNDField<N : Any, F : ExtendedField<N>>(shape: IntArray, field: F)
ExponentialOperations<NDElement<N, F>> {
override fun produceStructure(initializer: F.(IntArray) -> N): NDStructure<N> {
return genericNdStructure(shape) { field.initializer(it) }
return ndStructure(shape) { field.initializer(it) }
}
override fun power(arg: NDElement<N, F>, pow: Double): NDElement<N, F> {

View File

@ -167,7 +167,7 @@ operator fun <T, F : Field<T>> NDElement<T, F>.div(arg: T): NDElement<T, F> = tr
}
class GenericNDField<T : Any, F : Field<T>>(shape: IntArray, field: F) : NDField<T, F>(shape, field) {
override fun produceStructure(initializer: F.(IntArray) -> T): NDStructure<T> = genericNdStructure(shape) { field.initializer(it) }
override fun produceStructure(initializer: F.(IntArray) -> T): NDStructure<T> = ndStructure(shape) { field.initializer(it) }
}
//typealias NDFieldFactory<T> = (IntArray)->NDField<T>
@ -195,11 +195,6 @@ object NDArrays {
inline fun produceReal(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDElement<Double, DoubleField>) =
ExtendedNDField(shape, DoubleField).run(block)
// /**
// * Simple boxing NDField
// */
// fun <T : Any> fieldFactory(field: Field<T>): NDFieldFactory<T> = { shape -> GenericNDField(shape, field) }
/**
* Simple boxing NDArray
*/

View File

@ -63,7 +63,7 @@ interface Strides {
}
}
class DefaultStrides(override val shape: IntArray) : Strides {
class DefaultStrides private constructor(override val shape: IntArray) : Strides {
/**
* Strides for memory access
*/
@ -101,6 +101,15 @@ class DefaultStrides(override val shape: IntArray) : Strides {
override val linearSize: Int
get() = strides[shape.size]
companion object {
private val defaultStridesCache = HashMap<IntArray, Strides>()
/**
* Cached builder for default strides
*/
operator fun invoke(shape: IntArray): Strides = defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
}
}
abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> {
@ -112,7 +121,7 @@ abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> {
override val shape: IntArray
get() = strides.shape
override fun elements()=
override fun elements() =
strides.indices().map { it to this[it] }
}
@ -131,13 +140,20 @@ class BufferNDStructure<T>(
}
}
/**
* Create a most suitable nd-structure avoiding boxing if possible using given strides.
*
* Strides should be reused if possible
*/
inline fun <reified T : Any> ndStructure(strides: Strides, noinline initializer: (IntArray) -> T) =
BufferNDStructure<T>(strides, buffer(strides.linearSize) { i -> initializer(strides.index(i)) })
/**
* Create a most suitable nd-structure avoiding boxing if possible using default strides with given shape
*/
inline fun <reified T : Any> ndStructure(shape: IntArray, noinline initializer: (IntArray) -> T) =
ndStructure(DefaultStrides(shape), initializer)
/**
* Mutable ND buffer based on linear [Buffer]
*/
@ -156,7 +172,7 @@ class MutableBufferNDStructure<T>(
}
/**
* Create optimized mutable structure for given type
* The same as [ndStructure], but mutable
*/
inline fun <reified T : Any> mutableNdStructure(strides: Strides, noinline initializer: (IntArray) -> T) =
MutableBufferNDStructure(strides, mutableBuffer(strides.linearSize) { i -> initializer(strides.index(i)) })
@ -164,16 +180,16 @@ inline fun <reified T : Any> mutableNdStructure(strides: Strides, noinline initi
inline fun <reified T : Any> mutableNdStructure(shape: IntArray, noinline initializer: (IntArray) -> T) =
mutableNdStructure(DefaultStrides(shape), initializer)
/**
* Create universal mutable structure
*/
fun <T> genericNdStructure(shape: IntArray, initializer: (IntArray) -> T): MutableBufferNDStructure<T> {
val strides = DefaultStrides(shape)
val sequence = sequence {
strides.indices().forEach {
yield(initializer(it))
}
}
val buffer = MutableListBuffer(sequence.toMutableList())
return MutableBufferNDStructure(strides, buffer)
}
///**
// * Create universal mutable structure
// */
//fun <T> genericNdStructure(shape: IntArray, initializer: (IntArray) -> T): MutableBufferNDStructure<T> {
// val strides = DefaultStrides(shape)
// val sequence = sequence {
// strides.indices().forEach {
// yield(initializer(it))
// }
// }
// val buffer = MutableListBuffer(sequence.toMutableList())
// return MutableBufferNDStructure(strides, buffer)
//}