Added operations with external functions and elements on NDArray. Switched to kotlin 1.2.60

This commit is contained in:
Alexander Nozik 2018-08-03 09:04:39 +03:00
parent 0d1825a044
commit f28f036fec
3 changed files with 92 additions and 11 deletions

View File

@ -1,5 +1,5 @@
buildscript { buildscript {
ext.kotlin_version = '1.2.51' ext.kotlin_version = '1.2.60'
repositories { repositories {
mavenCentral() mavenCentral()

View File

@ -14,7 +14,7 @@ class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : R
* @param field - operations field defined on individual array element * @param field - operations field defined on individual array element
* @param T the type of the element contained in NDArray * @param T the type of the element contained in NDArray
*/ */
abstract class NDField<T>(val shape: List<Int>, private val field: Field<T>) : Field<NDArray<T>> { abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
/** /**
* Create new instance of NDArray using field shape and given initializer * Create new instance of NDArray using field shape and given initializer
@ -74,7 +74,7 @@ abstract class NDField<T>(val shape: List<Int>, private val field: Field<T>) : F
} }
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<Int>, T>> { interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>> {
/** /**
* The list of dimensions of this NDArray * The list of dimensions of this NDArray
@ -97,14 +97,14 @@ interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<
return get(*index.toIntArray()) return get(*index.toIntArray())
} }
override operator fun iterator(): Iterator<Pair<List<Int>, T>> { operator fun iterator(): Iterator<Pair<List<Int>, T>> {
return iterateIndexes(shape).map { Pair(it, this[it]) }.iterator() return iterateIndexes(shape).map { Pair(it, this[it]) }.iterator()
} }
/** /**
* Generate new NDArray, using given transformation for each element * Generate new NDArray, using given transformation for each element
*/ */
fun transform(action: (List<Int>, T) -> T): NDArray<T> = (context as NDField<T>).produce { action(it, this[it]) } fun transform(action: (List<Int>, T) -> T): NDArray<T> = context.produce { action(it, this[it]) }
companion object { companion object {
/** /**
@ -125,6 +125,79 @@ interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<
} }
} }
/**
* Element by element application of any operation on elements to the whole array. Just like in numpy
*/
operator fun <T> Function1<T, T>.invoke(ndArray: NDArray<T>): NDArray<T> = ndArray.transform { _, value -> this(value) }
/* plus and minus */
/**
* Summation operation for [NDArray] and single element
*/
operator fun <T> NDArray<T>.plus(arg: T): NDArray<T> = transform { _, value ->
with(context.field){
arg + value
}
}
/**
* Reverse sum operation
*/
operator fun <T> T.plus(arg: NDArray<T>): NDArray<T> = arg + this
/**
* Subtraction operation between [NDArray] and single element
*/
operator fun <T> NDArray<T>.minus(arg: T): NDArray<T> = transform { _, value ->
with(context.field){
arg - value
}
}
/**
* Reverse minus operation
*/
operator fun <T> T.minus(arg: NDArray<T>): NDArray<T> = arg.transform { _, value ->
with(arg.context.field){
this@minus - value
}
}
/* prod and div */
/**
* Product operation for [NDArray] and single element
*/
operator fun <T> NDArray<T>.times(arg: T): NDArray<T> = transform { _, value ->
with(context.field){
arg * value
}
}
/**
* Reverse product operation
*/
operator fun <T> T.times(arg: NDArray<T>): NDArray<T> = arg * this
/**
* Division operation between [NDArray] and single element
*/
operator fun <T> NDArray<T>.div(arg: T): NDArray<T> = transform { _, value ->
with(context.field){
arg / value
}
}
/**
* Reverse division operation
*/
operator fun <T> T.div(arg: NDArray<T>): NDArray<T> = arg.transform { _, value ->
with(arg.context.field){
this@div/ value
}
}
/** /**
* Create a platform-specific NDArray of doubles * Create a platform-specific NDArray of doubles
*/ */

View File

@ -1,6 +1,7 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import kotlin.math.pow
import kotlin.test.Test import kotlin.test.Test
class RealNDFieldTest { class RealNDFieldTest {
@ -31,4 +32,11 @@ class RealNDFieldTest {
} }
} }
} }
@Test
fun testExternalFunction() {
val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 }
val result = function(array1) + 1.0
assertEquals(10.0, result[1,1],0.01)
}
} }