Added operations with external functions and elements on NDArray. Switched to kotlin 1.2.60

This commit is contained in:
Alexander Nozik 2018-08-03 09:04:39 +03:00
parent 0d1825a044
commit f28f036fec
3 changed files with 92 additions and 11 deletions

View File

@ -1,5 +1,5 @@
buildscript {
ext.kotlin_version = '1.2.51'
ext.kotlin_version = '1.2.60'
repositories {
mavenCentral()

View File

@ -14,7 +14,7 @@ class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : R
* @param field - operations field defined on individual array element
* @param T the type of the element contained in NDArray
*/
abstract class NDField<T>(val shape: List<Int>, private val field: Field<T>) : Field<NDArray<T>> {
abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
/**
* Create new instance of NDArray using field shape and given initializer
@ -74,7 +74,7 @@ abstract class NDField<T>(val shape: List<Int>, private val field: Field<T>) : F
}
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<Int>, T>> {
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>> {
/**
* The list of dimensions of this NDArray
@ -97,14 +97,14 @@ interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<
return get(*index.toIntArray())
}
override operator fun iterator(): Iterator<Pair<List<Int>, T>> {
operator fun iterator(): Iterator<Pair<List<Int>, T>> {
return iterateIndexes(shape).map { Pair(it, this[it]) }.iterator()
}
/**
* Generate new NDArray, using given transformation for each element
*/
fun transform(action: (List<Int>, T) -> T): NDArray<T> = (context as NDField<T>).produce { action(it, this[it]) }
fun transform(action: (List<Int>, T) -> T): NDArray<T> = context.produce { action(it, this[it]) }
companion object {
/**
@ -125,6 +125,79 @@ interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<
}
}
/**
* Element by element application of any operation on elements to the whole array. Just like in numpy
*/
operator fun <T> Function1<T, T>.invoke(ndArray: NDArray<T>): NDArray<T> = ndArray.transform { _, value -> this(value) }
/* plus and minus */
/**
* Summation operation for [NDArray] and single element
*/
operator fun <T> NDArray<T>.plus(arg: T): NDArray<T> = transform { _, value ->
with(context.field){
arg + value
}
}
/**
* Reverse sum operation
*/
operator fun <T> T.plus(arg: NDArray<T>): NDArray<T> = arg + this
/**
* Subtraction operation between [NDArray] and single element
*/
operator fun <T> NDArray<T>.minus(arg: T): NDArray<T> = transform { _, value ->
with(context.field){
arg - value
}
}
/**
* Reverse minus operation
*/
operator fun <T> T.minus(arg: NDArray<T>): NDArray<T> = arg.transform { _, value ->
with(arg.context.field){
this@minus - value
}
}
/* prod and div */
/**
* Product operation for [NDArray] and single element
*/
operator fun <T> NDArray<T>.times(arg: T): NDArray<T> = transform { _, value ->
with(context.field){
arg * value
}
}
/**
* Reverse product operation
*/
operator fun <T> T.times(arg: NDArray<T>): NDArray<T> = arg * this
/**
* Division operation between [NDArray] and single element
*/
operator fun <T> NDArray<T>.div(arg: T): NDArray<T> = transform { _, value ->
with(context.field){
arg / value
}
}
/**
* Reverse division operation
*/
operator fun <T> T.div(arg: NDArray<T>): NDArray<T> = arg.transform { _, value ->
with(arg.context.field){
this@div/ value
}
}
/**
* Create a platform-specific NDArray of doubles
*/

View File

@ -1,6 +1,7 @@
package scientifik.kmath.structures
import org.junit.Assert.assertEquals
import kotlin.math.pow
import kotlin.test.Test
class RealNDFieldTest {
@ -14,8 +15,8 @@ class RealNDFieldTest {
}
@Test
fun testProduct(){
val product = array1*array2
fun testProduct() {
val product = array1 * array2
assertEquals(0.0, product[2, 2], 0.1)
}
@ -24,11 +25,18 @@ class RealNDFieldTest {
val array = real2DArray(3, 3) { i, j -> (i * 10 + j).toDouble() }
for(i in 0..2){
for(j in 0..2){
val expected= (i * 10 + j).toDouble()
assertEquals("Error at index [$i, $j]", expected, array[i,j], 0.1)
for (i in 0..2) {
for (j in 0..2) {
val expected = (i * 10 + j).toDouble()
assertEquals("Error at index [$i, $j]", expected, array[i, j], 0.1)
}
}
}
@Test
fun testExternalFunction() {
val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 }
val result = function(array1) + 1.0
assertEquals(10.0, result[1,1],0.01)
}
}