Added operations with external functions and elements on NDArray. Switched to kotlin 1.2.60
This commit is contained in:
parent
0d1825a044
commit
f28f036fec
@ -1,5 +1,5 @@
|
||||
buildscript {
|
||||
ext.kotlin_version = '1.2.51'
|
||||
ext.kotlin_version = '1.2.60'
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
|
@ -14,7 +14,7 @@ class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : R
|
||||
* @param field - operations field defined on individual array element
|
||||
* @param T the type of the element contained in NDArray
|
||||
*/
|
||||
abstract class NDField<T>(val shape: List<Int>, private val field: Field<T>) : Field<NDArray<T>> {
|
||||
abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
|
||||
|
||||
/**
|
||||
* Create new instance of NDArray using field shape and given initializer
|
||||
@ -74,7 +74,7 @@ abstract class NDField<T>(val shape: List<Int>, private val field: Field<T>) : F
|
||||
}
|
||||
|
||||
|
||||
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<Int>, T>> {
|
||||
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>> {
|
||||
|
||||
/**
|
||||
* The list of dimensions of this NDArray
|
||||
@ -97,14 +97,14 @@ interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<
|
||||
return get(*index.toIntArray())
|
||||
}
|
||||
|
||||
override operator fun iterator(): Iterator<Pair<List<Int>, T>> {
|
||||
operator fun iterator(): Iterator<Pair<List<Int>, T>> {
|
||||
return iterateIndexes(shape).map { Pair(it, this[it]) }.iterator()
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate new NDArray, using given transformation for each element
|
||||
*/
|
||||
fun transform(action: (List<Int>, T) -> T): NDArray<T> = (context as NDField<T>).produce { action(it, this[it]) }
|
||||
fun transform(action: (List<Int>, T) -> T): NDArray<T> = context.produce { action(it, this[it]) }
|
||||
|
||||
companion object {
|
||||
/**
|
||||
@ -125,6 +125,79 @@ interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||
*/
|
||||
operator fun <T> Function1<T, T>.invoke(ndArray: NDArray<T>): NDArray<T> = ndArray.transform { _, value -> this(value) }
|
||||
|
||||
/* plus and minus */
|
||||
|
||||
/**
|
||||
* Summation operation for [NDArray] and single element
|
||||
*/
|
||||
operator fun <T> NDArray<T>.plus(arg: T): NDArray<T> = transform { _, value ->
|
||||
with(context.field){
|
||||
arg + value
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reverse sum operation
|
||||
*/
|
||||
operator fun <T> T.plus(arg: NDArray<T>): NDArray<T> = arg + this
|
||||
|
||||
/**
|
||||
* Subtraction operation between [NDArray] and single element
|
||||
*/
|
||||
operator fun <T> NDArray<T>.minus(arg: T): NDArray<T> = transform { _, value ->
|
||||
with(context.field){
|
||||
arg - value
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reverse minus operation
|
||||
*/
|
||||
operator fun <T> T.minus(arg: NDArray<T>): NDArray<T> = arg.transform { _, value ->
|
||||
with(arg.context.field){
|
||||
this@minus - value
|
||||
}
|
||||
}
|
||||
|
||||
/* prod and div */
|
||||
|
||||
/**
|
||||
* Product operation for [NDArray] and single element
|
||||
*/
|
||||
operator fun <T> NDArray<T>.times(arg: T): NDArray<T> = transform { _, value ->
|
||||
with(context.field){
|
||||
arg * value
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reverse product operation
|
||||
*/
|
||||
operator fun <T> T.times(arg: NDArray<T>): NDArray<T> = arg * this
|
||||
|
||||
/**
|
||||
* Division operation between [NDArray] and single element
|
||||
*/
|
||||
operator fun <T> NDArray<T>.div(arg: T): NDArray<T> = transform { _, value ->
|
||||
with(context.field){
|
||||
arg / value
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reverse division operation
|
||||
*/
|
||||
operator fun <T> T.div(arg: NDArray<T>): NDArray<T> = arg.transform { _, value ->
|
||||
with(arg.context.field){
|
||||
this@div/ value
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a platform-specific NDArray of doubles
|
||||
*/
|
||||
|
@ -1,6 +1,7 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import org.junit.Assert.assertEquals
|
||||
import kotlin.math.pow
|
||||
import kotlin.test.Test
|
||||
|
||||
class RealNDFieldTest {
|
||||
@ -14,8 +15,8 @@ class RealNDFieldTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testProduct(){
|
||||
val product = array1*array2
|
||||
fun testProduct() {
|
||||
val product = array1 * array2
|
||||
assertEquals(0.0, product[2, 2], 0.1)
|
||||
}
|
||||
|
||||
@ -24,11 +25,18 @@ class RealNDFieldTest {
|
||||
|
||||
val array = real2DArray(3, 3) { i, j -> (i * 10 + j).toDouble() }
|
||||
|
||||
for(i in 0..2){
|
||||
for(j in 0..2){
|
||||
val expected= (i * 10 + j).toDouble()
|
||||
assertEquals("Error at index [$i, $j]", expected, array[i,j], 0.1)
|
||||
for (i in 0..2) {
|
||||
for (j in 0..2) {
|
||||
val expected = (i * 10 + j).toDouble()
|
||||
assertEquals("Error at index [$i, $j]", expected, array[i, j], 0.1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testExternalFunction() {
|
||||
val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 }
|
||||
val result = function(array1) + 1.0
|
||||
assertEquals(10.0, result[1,1],0.01)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user