From f28f036fec0c6fa709b9a3337d247f9cf85a68ad Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Fri, 3 Aug 2018 09:04:39 +0300 Subject: [PATCH] Added operations with external functions and elements on NDArray. Switched to kotlin 1.2.60 --- build.gradle | 2 +- .../scientifik/kmath/structures/NDArray.kt | 81 ++++++++++++++++++- .../kmath/structures/RealNDFieldTest.kt | 20 +++-- 3 files changed, 92 insertions(+), 11 deletions(-) diff --git a/build.gradle b/build.gradle index 226cff021..bce9c1492 100644 --- a/build.gradle +++ b/build.gradle @@ -1,5 +1,5 @@ buildscript { - ext.kotlin_version = '1.2.51' + ext.kotlin_version = '1.2.60' repositories { mavenCentral() diff --git a/common/src/main/kotlin/scientifik/kmath/structures/NDArray.kt b/common/src/main/kotlin/scientifik/kmath/structures/NDArray.kt index 4ae164fcf..5acac4ff1 100644 --- a/common/src/main/kotlin/scientifik/kmath/structures/NDArray.kt +++ b/common/src/main/kotlin/scientifik/kmath/structures/NDArray.kt @@ -14,7 +14,7 @@ class ShapeMismatchException(val expected: List, val actual: List) : R * @param field - operations field defined on individual array element * @param T the type of the element contained in NDArray */ -abstract class NDField(val shape: List, private val field: Field) : Field> { +abstract class NDField(val shape: List, val field: Field) : Field> { /** * Create new instance of NDArray using field shape and given initializer @@ -74,7 +74,7 @@ abstract class NDField(val shape: List, private val field: Field) : F } -interface NDArray : FieldElement, NDField>, Iterable, T>> { +interface NDArray : FieldElement, NDField> { /** * The list of dimensions of this NDArray @@ -97,14 +97,14 @@ interface NDArray : FieldElement, NDField>, Iterable, T>> { + operator fun iterator(): Iterator, T>> { return iterateIndexes(shape).map { Pair(it, this[it]) }.iterator() } /** * Generate new NDArray, using given transformation for each element */ - fun transform(action: (List, T) -> T): NDArray = (context as NDField).produce { action(it, this[it]) } + fun transform(action: (List, T) -> T): NDArray = context.produce { action(it, this[it]) } companion object { /** @@ -125,6 +125,79 @@ interface NDArray : FieldElement, NDField>, Iterable Function1.invoke(ndArray: NDArray): NDArray = ndArray.transform { _, value -> this(value) } + +/* plus and minus */ + +/** + * Summation operation for [NDArray] and single element + */ +operator fun NDArray.plus(arg: T): NDArray = transform { _, value -> + with(context.field){ + arg + value + } +} + +/** + * Reverse sum operation + */ +operator fun T.plus(arg: NDArray): NDArray = arg + this + +/** + * Subtraction operation between [NDArray] and single element + */ +operator fun NDArray.minus(arg: T): NDArray = transform { _, value -> + with(context.field){ + arg - value + } +} + +/** + * Reverse minus operation + */ +operator fun T.minus(arg: NDArray): NDArray = arg.transform { _, value -> + with(arg.context.field){ + this@minus - value + } +} + +/* prod and div */ + +/** + * Product operation for [NDArray] and single element + */ +operator fun NDArray.times(arg: T): NDArray = transform { _, value -> + with(context.field){ + arg * value + } +} + +/** + * Reverse product operation + */ +operator fun T.times(arg: NDArray): NDArray = arg * this + +/** + * Division operation between [NDArray] and single element + */ +operator fun NDArray.div(arg: T): NDArray = transform { _, value -> + with(context.field){ + arg / value + } +} + +/** + * Reverse division operation + */ +operator fun T.div(arg: NDArray): NDArray = arg.transform { _, value -> + with(arg.context.field){ + this@div/ value + } +} + /** * Create a platform-specific NDArray of doubles */ diff --git a/jvm/src/test/kotlin/scientifik/kmath/structures/RealNDFieldTest.kt b/jvm/src/test/kotlin/scientifik/kmath/structures/RealNDFieldTest.kt index 00e1fd7b4..9bed502f9 100644 --- a/jvm/src/test/kotlin/scientifik/kmath/structures/RealNDFieldTest.kt +++ b/jvm/src/test/kotlin/scientifik/kmath/structures/RealNDFieldTest.kt @@ -1,6 +1,7 @@ package scientifik.kmath.structures import org.junit.Assert.assertEquals +import kotlin.math.pow import kotlin.test.Test class RealNDFieldTest { @@ -14,8 +15,8 @@ class RealNDFieldTest { } @Test - fun testProduct(){ - val product = array1*array2 + fun testProduct() { + val product = array1 * array2 assertEquals(0.0, product[2, 2], 0.1) } @@ -24,11 +25,18 @@ class RealNDFieldTest { val array = real2DArray(3, 3) { i, j -> (i * 10 + j).toDouble() } - for(i in 0..2){ - for(j in 0..2){ - val expected= (i * 10 + j).toDouble() - assertEquals("Error at index [$i, $j]", expected, array[i,j], 0.1) + for (i in 0..2) { + for (j in 0..2) { + val expected = (i * 10 + j).toDouble() + assertEquals("Error at index [$i, $j]", expected, array[i, j], 0.1) } } } + + @Test + fun testExternalFunction() { + val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 } + val result = function(array1) + 1.0 + assertEquals(10.0, result[1,1],0.01) + } }