From 77bf8de4f1a5f8f686b3f2c0c71805eb641c67c3 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 23 Jan 2019 13:30:26 +0300 Subject: [PATCH] Optimized mapping functions for NDElements --- .../scientifik/kmath/structures/NDAlgebra.kt | 9 +++++--- .../scientifik/kmath/structures/NDElement.kt | 23 ++++++++++++------- .../kmath/structures/RealNDField.kt | 20 +++++++++++++--- settings.gradle.kts | 3 +-- 4 files changed, 39 insertions(+), 16 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt index 097d52723..7ea768c63 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDAlgebra.kt @@ -14,7 +14,7 @@ class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : Run /** * The base interface for all nd-algebra implementations * @param T the type of nd-structure element - * @param C the type of the context + * @param C the type of the element context * @param N the type of the structure */ interface NDAlgebra> { @@ -112,10 +112,13 @@ interface NDField, N : NDStructure> : Field, NDRing() + /** - * Create a nd-field for [Double] values + * Create a nd-field for [Double] values or pull it from cache if it was created previously */ - fun real(shape: IntArray) = RealNDField(shape) + fun real(shape: IntArray) = realNDFieldCache.getOrPut(shape){RealNDField(shape)} /** * Create a nd-field with boxing generic buffer diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt index 0fdb53f07..c97f959f3 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/NDElement.kt @@ -7,6 +7,9 @@ import scientifik.kmath.operations.Space /** * The root for all [NDStructure] based algebra elements. Does not implement algebra element root because of problems with recursive self-types + * @param T the type of the element of the structure + * @param C the type of the context for the element + * @param N the type of the underlying [NDStructure] */ interface NDElement> : NDStructure { @@ -16,9 +19,6 @@ interface NDElement> : NDStructure { fun N.wrap(): NDElement - fun mapIndexed(transform: C.(index: IntArray, T) -> T) = context.mapIndexed(unwrap(), transform).wrap() - fun map(transform: C.(T) -> T) = context.map(unwrap(), transform).wrap() - companion object { /** * Create a optimized NDArray of doubles @@ -61,10 +61,17 @@ interface NDElement> : NDStructure { } } + +fun > NDElement.mapIndexed(transform: C.(index: IntArray, T) -> T) = + context.mapIndexed(unwrap(), transform).wrap() + +fun > NDElement.map(transform: C.(T) -> T) = context.map(unwrap(), transform).wrap() + + /** * Element by element application of any operation on elements to the whole [NDElement] */ -operator fun Function1.invoke(ndElement: NDElement) = +operator fun > Function1.invoke(ndElement: NDElement) = ndElement.map { value -> this@invoke(value) } /* plus and minus */ @@ -72,13 +79,13 @@ operator fun Function1.invoke(ndElement: NDElement) = /** * Summation operation for [NDElement] and single element */ -operator fun > NDElement.plus(arg: T) = +operator fun , N : NDStructure> NDElement.plus(arg: T) = map { value -> arg + value } /** * Subtraction operation between [NDElement] and single element */ -operator fun > NDElement.minus(arg: T) = +operator fun , N : NDStructure> NDElement.minus(arg: T) = map { value -> arg - value } /* prod and div */ @@ -86,13 +93,13 @@ operator fun > NDElement.minus(arg: T) = /** * Product operation for [NDElement] and single element */ -operator fun > NDElement.times(arg: T) = +operator fun , N : NDStructure> NDElement.times(arg: T) = map { value -> arg * value } /** * Division operation between [NDElement] and single element */ -operator fun > NDElement.div(arg: T) = +operator fun , N : NDStructure> NDElement.div(arg: T) = map { value -> arg / value } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index bc5832e1c..d652bb8a8 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -86,11 +86,25 @@ inline fun BufferedNDField.produceInline(crossinline initiali return BufferedNDFieldElement(this, DoubleBuffer(array)) } +/** + * Map one [RealNDElement] using function with indexes + */ +inline fun RealNDElement.mapIndexed(crossinline transform: RealField.(index: IntArray, Double) -> Double) = + context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) } + +/** + * Map one [RealNDElement] using function without indexes + */ +inline fun RealNDElement.map(crossinline transform: RealField.(Double) -> Double): RealNDElement { + val array = DoubleArray(strides.linearSize) { offset -> RealField.transform(buffer[offset]) } + return BufferedNDFieldElement(context, DoubleBuffer(array)) +} + /** * Element by element application of any operation on elements to the whole array. Just like in numpy */ operator fun Function1.invoke(ndElement: RealNDElement) = - ndElement.context.produceInline { i -> invoke(ndElement.buffer[i]) } + ndElement.map { this@invoke(it) } /* plus and minus */ @@ -99,10 +113,10 @@ operator fun Function1.invoke(ndElement: RealNDElement) = * Summation operation for [BufferedNDElement] and single element */ operator fun RealNDElement.plus(arg: Double) = - context.produceInline { i -> buffer[i] + arg } + map { it + arg } /** * Subtraction operation between [BufferedNDElement] and single element */ operator fun RealNDElement.minus(arg: Double) = - context.produceInline { i -> buffer[i] - arg } + map { it - arg } diff --git a/settings.gradle.kts b/settings.gradle.kts index a4464d01f..ca738647e 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -2,8 +2,7 @@ pluginManagement { repositories { mavenCentral() maven("https://plugins.gradle.org/m2/") - maven { setUrl("https://dl.bintray.com/kotlin/kotlin-eap") } - maven { setUrl("https://plugins.gradle.org/m2/") } + maven ("https://dl.bintray.com/kotlin/kotlin-eap") } }