Optimized mapping functions for NDElements
This commit is contained in:
parent
2b4419823b
commit
77bf8de4f1
@ -14,7 +14,7 @@ class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : Run
|
||||
/**
|
||||
* The base interface for all nd-algebra implementations
|
||||
* @param T the type of nd-structure element
|
||||
* @param C the type of the context
|
||||
* @param C the type of the element context
|
||||
* @param N the type of the structure
|
||||
*/
|
||||
interface NDAlgebra<T, C, N : NDStructure<T>> {
|
||||
@ -112,10 +112,13 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F,
|
||||
operator fun T.div(arg: N) = map(arg) { divide(it, this@div) }
|
||||
|
||||
companion object {
|
||||
|
||||
private val realNDFieldCache = HashMap<IntArray, RealNDField>()
|
||||
|
||||
/**
|
||||
* Create a nd-field for [Double] values
|
||||
* Create a nd-field for [Double] values or pull it from cache if it was created previously
|
||||
*/
|
||||
fun real(shape: IntArray) = RealNDField(shape)
|
||||
fun real(shape: IntArray) = realNDFieldCache.getOrPut(shape){RealNDField(shape)}
|
||||
|
||||
/**
|
||||
* Create a nd-field with boxing generic buffer
|
||||
|
@ -7,6 +7,9 @@ import scientifik.kmath.operations.Space
|
||||
|
||||
/**
|
||||
* The root for all [NDStructure] based algebra elements. Does not implement algebra element root because of problems with recursive self-types
|
||||
* @param T the type of the element of the structure
|
||||
* @param C the type of the context for the element
|
||||
* @param N the type of the underlying [NDStructure]
|
||||
*/
|
||||
interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
||||
|
||||
@ -16,9 +19,6 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
||||
|
||||
fun N.wrap(): NDElement<T, C, N>
|
||||
|
||||
fun mapIndexed(transform: C.(index: IntArray, T) -> T) = context.mapIndexed(unwrap(), transform).wrap()
|
||||
fun map(transform: C.(T) -> T) = context.map(unwrap(), transform).wrap()
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* Create a optimized NDArray of doubles
|
||||
@ -61,10 +61,17 @@ interface NDElement<T, C, N : NDStructure<T>> : NDStructure<T> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.mapIndexed(transform: C.(index: IntArray, T) -> T) =
|
||||
context.mapIndexed(unwrap(), transform).wrap()
|
||||
|
||||
fun <T, C, N : NDStructure<T>> NDElement<T, C, N>.map(transform: C.(T) -> T) = context.map(unwrap(), transform).wrap()
|
||||
|
||||
|
||||
/**
|
||||
* Element by element application of any operation on elements to the whole [NDElement]
|
||||
*/
|
||||
operator fun <T, C> Function1<T, T>.invoke(ndElement: NDElement<T, C, *>) =
|
||||
operator fun <T, C, N : NDStructure<T>> Function1<T, T>.invoke(ndElement: NDElement<T, C, N>) =
|
||||
ndElement.map { value -> this@invoke(value) }
|
||||
|
||||
/* plus and minus */
|
||||
@ -72,13 +79,13 @@ operator fun <T, C> Function1<T, T>.invoke(ndElement: NDElement<T, C, *>) =
|
||||
/**
|
||||
* Summation operation for [NDElement] and single element
|
||||
*/
|
||||
operator fun <T, S : Space<T>> NDElement<T, S, *>.plus(arg: T) =
|
||||
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.plus(arg: T) =
|
||||
map { value -> arg + value }
|
||||
|
||||
/**
|
||||
* Subtraction operation between [NDElement] and single element
|
||||
*/
|
||||
operator fun <T, S : Space<T>> NDElement<T, S, *>.minus(arg: T) =
|
||||
operator fun <T, S : Space<T>, N : NDStructure<T>> NDElement<T, S, N>.minus(arg: T) =
|
||||
map { value -> arg - value }
|
||||
|
||||
/* prod and div */
|
||||
@ -86,13 +93,13 @@ operator fun <T, S : Space<T>> NDElement<T, S, *>.minus(arg: T) =
|
||||
/**
|
||||
* Product operation for [NDElement] and single element
|
||||
*/
|
||||
operator fun <T, R : Ring<T>> NDElement<T, R, *>.times(arg: T) =
|
||||
operator fun <T, R : Ring<T>, N : NDStructure<T>> NDElement<T, R, N>.times(arg: T) =
|
||||
map { value -> arg * value }
|
||||
|
||||
/**
|
||||
* Division operation between [NDElement] and single element
|
||||
*/
|
||||
operator fun <T, F : Field<T>> NDElement<T, F, *>.div(arg: T) =
|
||||
operator fun <T, F : Field<T>, N : NDStructure<T>> NDElement<T, F, N>.div(arg: T) =
|
||||
map { value -> arg / value }
|
||||
|
||||
|
||||
|
@ -86,11 +86,25 @@ inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initiali
|
||||
return BufferedNDFieldElement(this, DoubleBuffer(array))
|
||||
}
|
||||
|
||||
/**
|
||||
* Map one [RealNDElement] using function with indexes
|
||||
*/
|
||||
inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double) =
|
||||
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
|
||||
|
||||
/**
|
||||
* Map one [RealNDElement] using function without indexes
|
||||
*/
|
||||
inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement {
|
||||
val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) }
|
||||
return BufferedNDFieldElement(context, DoubleBuffer(array))
|
||||
}
|
||||
|
||||
/**
|
||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||
*/
|
||||
operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
|
||||
ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) }
|
||||
ndElement.map { this@invoke(it) }
|
||||
|
||||
|
||||
/* plus and minus */
|
||||
@ -99,10 +113,10 @@ operator fun Function1<Double, Double>.invoke(ndElement: RealNDElement) =
|
||||
* Summation operation for [BufferedNDElement] and single element
|
||||
*/
|
||||
operator fun RealNDElement.plus(arg: Double) =
|
||||
context.produceInline { i -> buffer[i] + arg }
|
||||
map { it + arg }
|
||||
|
||||
/**
|
||||
* Subtraction operation between [BufferedNDElement] and single element
|
||||
*/
|
||||
operator fun RealNDElement.minus(arg: Double) =
|
||||
context.produceInline { i -> buffer[i] - arg }
|
||||
map { it - arg }
|
||||
|
@ -2,8 +2,7 @@ pluginManagement {
|
||||
repositories {
|
||||
mavenCentral()
|
||||
maven("https://plugins.gradle.org/m2/")
|
||||
maven { setUrl("https://dl.bintray.com/kotlin/kotlin-eap") }
|
||||
maven { setUrl("https://plugins.gradle.org/m2/") }
|
||||
maven ("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user