forked from kscience/kmath
Added operations with external functions and elements on NDArray. Switched to kotlin 1.2.60
This commit is contained in:
parent
0d1825a044
commit
f28f036fec
@ -1,5 +1,5 @@
|
|||||||
buildscript {
|
buildscript {
|
||||||
ext.kotlin_version = '1.2.51'
|
ext.kotlin_version = '1.2.60'
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
|
@ -14,7 +14,7 @@ class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : R
|
|||||||
* @param field - operations field defined on individual array element
|
* @param field - operations field defined on individual array element
|
||||||
* @param T the type of the element contained in NDArray
|
* @param T the type of the element contained in NDArray
|
||||||
*/
|
*/
|
||||||
abstract class NDField<T>(val shape: List<Int>, private val field: Field<T>) : Field<NDArray<T>> {
|
abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create new instance of NDArray using field shape and given initializer
|
* Create new instance of NDArray using field shape and given initializer
|
||||||
@ -74,7 +74,7 @@ abstract class NDField<T>(val shape: List<Int>, private val field: Field<T>) : F
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<Int>, T>> {
|
interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The list of dimensions of this NDArray
|
* The list of dimensions of this NDArray
|
||||||
@ -97,14 +97,14 @@ interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<
|
|||||||
return get(*index.toIntArray())
|
return get(*index.toIntArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun iterator(): Iterator<Pair<List<Int>, T>> {
|
operator fun iterator(): Iterator<Pair<List<Int>, T>> {
|
||||||
return iterateIndexes(shape).map { Pair(it, this[it]) }.iterator()
|
return iterateIndexes(shape).map { Pair(it, this[it]) }.iterator()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate new NDArray, using given transformation for each element
|
* Generate new NDArray, using given transformation for each element
|
||||||
*/
|
*/
|
||||||
fun transform(action: (List<Int>, T) -> T): NDArray<T> = (context as NDField<T>).produce { action(it, this[it]) }
|
fun transform(action: (List<Int>, T) -> T): NDArray<T> = context.produce { action(it, this[it]) }
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
/**
|
/**
|
||||||
@ -125,6 +125,79 @@ interface NDArray<T> : FieldElement<NDArray<T>, NDField<T>>, Iterable<Pair<List<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||||
|
*/
|
||||||
|
operator fun <T> Function1<T, T>.invoke(ndArray: NDArray<T>): NDArray<T> = ndArray.transform { _, value -> this(value) }
|
||||||
|
|
||||||
|
/* plus and minus */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Summation operation for [NDArray] and single element
|
||||||
|
*/
|
||||||
|
operator fun <T> NDArray<T>.plus(arg: T): NDArray<T> = transform { _, value ->
|
||||||
|
with(context.field){
|
||||||
|
arg + value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reverse sum operation
|
||||||
|
*/
|
||||||
|
operator fun <T> T.plus(arg: NDArray<T>): NDArray<T> = arg + this
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subtraction operation between [NDArray] and single element
|
||||||
|
*/
|
||||||
|
operator fun <T> NDArray<T>.minus(arg: T): NDArray<T> = transform { _, value ->
|
||||||
|
with(context.field){
|
||||||
|
arg - value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reverse minus operation
|
||||||
|
*/
|
||||||
|
operator fun <T> T.minus(arg: NDArray<T>): NDArray<T> = arg.transform { _, value ->
|
||||||
|
with(arg.context.field){
|
||||||
|
this@minus - value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* prod and div */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Product operation for [NDArray] and single element
|
||||||
|
*/
|
||||||
|
operator fun <T> NDArray<T>.times(arg: T): NDArray<T> = transform { _, value ->
|
||||||
|
with(context.field){
|
||||||
|
arg * value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reverse product operation
|
||||||
|
*/
|
||||||
|
operator fun <T> T.times(arg: NDArray<T>): NDArray<T> = arg * this
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Division operation between [NDArray] and single element
|
||||||
|
*/
|
||||||
|
operator fun <T> NDArray<T>.div(arg: T): NDArray<T> = transform { _, value ->
|
||||||
|
with(context.field){
|
||||||
|
arg / value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reverse division operation
|
||||||
|
*/
|
||||||
|
operator fun <T> T.div(arg: NDArray<T>): NDArray<T> = arg.transform { _, value ->
|
||||||
|
with(arg.context.field){
|
||||||
|
this@div/ value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a platform-specific NDArray of doubles
|
* Create a platform-specific NDArray of doubles
|
||||||
*/
|
*/
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import org.junit.Assert.assertEquals
|
import org.junit.Assert.assertEquals
|
||||||
|
import kotlin.math.pow
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
|
|
||||||
class RealNDFieldTest {
|
class RealNDFieldTest {
|
||||||
@ -14,8 +15,8 @@ class RealNDFieldTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testProduct(){
|
fun testProduct() {
|
||||||
val product = array1*array2
|
val product = array1 * array2
|
||||||
assertEquals(0.0, product[2, 2], 0.1)
|
assertEquals(0.0, product[2, 2], 0.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,11 +25,18 @@ class RealNDFieldTest {
|
|||||||
|
|
||||||
val array = real2DArray(3, 3) { i, j -> (i * 10 + j).toDouble() }
|
val array = real2DArray(3, 3) { i, j -> (i * 10 + j).toDouble() }
|
||||||
|
|
||||||
for(i in 0..2){
|
for (i in 0..2) {
|
||||||
for(j in 0..2){
|
for (j in 0..2) {
|
||||||
val expected= (i * 10 + j).toDouble()
|
val expected = (i * 10 + j).toDouble()
|
||||||
assertEquals("Error at index [$i, $j]", expected, array[i,j], 0.1)
|
assertEquals("Error at index [$i, $j]", expected, array[i, j], 0.1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testExternalFunction() {
|
||||||
|
val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 }
|
||||||
|
val result = function(array1) + 1.0
|
||||||
|
assertEquals(10.0, result[1,1],0.01)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user