Fixed #5 by removing requirement of NDArray element to be Field element and introducing DoubleField without corresponding element.

This commit is contained in:
Alexander Nozik 2018-05-02 22:32:20 +03:00
parent bd46b66080
commit 7c3d561c63
8 changed files with 47 additions and 53 deletions

View File

@ -1,3 +1,14 @@
buildscript {
ext.kotlin_version = '1.2.41'
repositories {
mavenCentral()
}
dependencies {
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
}
}
group = 'scientifik' group = 'scientifik'
version = '0.1 - SNAPSHOT' version = '0.1 - SNAPSHOT'

View File

@ -1,14 +1,3 @@
buildscript {
ext.kotlin_version = '1.2.40'
repositories {
mavenCentral()
}
dependencies {
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
}
}
description = "Platform-independent interfaces for kotlin maths" description = "Platform-independent interfaces for kotlin maths"
apply plugin: 'kotlin-platform-common' apply plugin: 'kotlin-platform-common'

View File

@ -90,7 +90,7 @@ interface Field<T> : Ring<T> {
fun divide(a: T, b: T): T fun divide(a: T, b: T): T
operator fun T.div(b: T): T = divide(this, b) operator fun T.div(b: T): T = divide(this, b)
operator fun Double.div(b: T) = this * divide(one, b) operator fun Number.div(b: T) = this * divide(one, b)
} }
/** /**

View File

@ -76,3 +76,11 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex> {
//TODO is it convenient? //TODO is it convenient?
operator fun not() = conjugate operator fun not() = conjugate
} }
object DoubleField : Field<Double> {
override val zero: Double = 0.0
override fun add(a: Double, b: Double): Double = a + b
override fun multiply(a: Double, b: Double): Double = a * b
override val one: Double = 1.0
override fun divide(a: Double, b: Double): Double = a / b
}

View File

@ -2,23 +2,22 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.FieldElement import scientifik.kmath.operations.FieldElement
import scientifik.kmath.operations.Real
class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException() class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException()
/** /**
* Field for n-dimensional arrays. * Field for n-dimensional arrays.
* @param shape - the list of dimensions of the array * @param shape - the list of dimensions of the array
* @param elementField - operations field defined on individual array element * @param field - operations field defined on individual array element
*/ */
abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementField: Field<T>) : Field<NDArray<T>> { abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
/** /**
* Create new instance of NDArray using field shape and given initializer * Create new instance of NDArray using field shape and given initializer
*/ */
abstract fun produce(initializer: (List<Int>) -> T): NDArray<T> abstract fun produce(initializer: (List<Int>) -> T): NDArray<T>
override val zero: NDArray<T> override val zero: NDArray<T>
get() = produce { elementField.zero } get() = produce { this.field.zero }
private fun checkShape(vararg arrays: NDArray<T>) { private fun checkShape(vararg arrays: NDArray<T>) {
arrays.forEach { arrays.forEach {
@ -33,7 +32,7 @@ abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementFie
*/ */
override fun add(a: NDArray<T>, b: NDArray<T>): NDArray<T> { override fun add(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
checkShape(a, b) checkShape(a, b)
return produce { a[it] + b[it] } return produce { with(field) { a[it] + b[it] } }
} }
/** /**
@ -41,18 +40,18 @@ abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementFie
*/ */
override fun multiply(a: NDArray<T>, k: Double): NDArray<T> { override fun multiply(a: NDArray<T>, k: Double): NDArray<T> {
checkShape(a) checkShape(a)
return produce { a[it] * k } return produce { with(field) {a[it] * k} }
} }
override val one: NDArray<T> override val one: NDArray<T>
get() = produce { elementField.one } get() = produce { this.field.one }
/** /**
* Element-by-element multiplication * Element-by-element multiplication
*/ */
override fun multiply(a: NDArray<T>, b: NDArray<T>): NDArray<T> { override fun multiply(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
checkShape(a) checkShape(a)
return produce { a[it] * b[it] } return produce { with(field) {a[it] * b[it]} }
} }
/** /**
@ -60,12 +59,12 @@ abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementFie
*/ */
override fun divide(a: NDArray<T>, b: NDArray<T>): NDArray<T> { override fun divide(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
checkShape(a) checkShape(a)
return produce { a[it] / b[it] } return produce { with(field) {a[it] / b[it]} }
} }
} }
interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> { interface NDArray<T> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
/** /**
* The list of dimensions of this NDArray * The list of dimensions of this NDArray
@ -119,12 +118,12 @@ interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair
/** /**
* Create a platform-specific NDArray of doubles * Create a platform-specific NDArray of doubles
*/ */
expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Real> expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Double>
fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Real> { fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Double> {
return realNDArray(listOf(dim1, dim2)) { initializer(it[0], it[1]) } return realNDArray(listOf(dim1, dim2)) { initializer(it[0], it[1]) }
} }
fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Real> { fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Double> {
return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) } return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
} }

View File

@ -1,15 +1,3 @@
buildscript {
ext.kotlin_version = '1.2.40'
repositories {
mavenCentral()
}
dependencies {
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
}
}
apply plugin: 'kotlin-platform-jvm' apply plugin: 'kotlin-platform-jvm'
repositories { repositories {

View File

@ -1,10 +1,9 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import scientifik.kmath.operations.Real import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.RealField
import java.nio.DoubleBuffer import java.nio.DoubleBuffer
private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) { private class RealNDField(shape: List<Int>) : NDField<Double>(shape, DoubleField) {
/** /**
* Strides for memory access * Strides for memory access
@ -33,19 +32,19 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
get() = strides[shape.size] get() = strides[shape.size]
override fun produce(initializer: (List<Int>) -> Real): NDArray<Real> { override fun produce(initializer: (List<Int>) -> Double): NDArray<Double> {
//TODO use sparse arrays for large capacities //TODO use sparse arrays for large capacities
val buffer = DoubleBuffer.allocate(capacity) val buffer = DoubleBuffer.allocate(capacity)
NDArray.iterateIndexes(shape).forEach { NDArray.iterateIndexes(shape).forEach {
buffer.put(offset(it), initializer(it).value) buffer.put(offset(it), initializer(it))
} }
return RealNDArray(this, buffer) return RealNDArray(this, buffer)
} }
class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray<Real> { class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray<Double> {
override fun get(vararg index: Int): Real { override fun get(vararg index: Int): Double {
return Real(data.get(context.offset(index.asList()))) return data.get(context.offset(index.asList()))
} }
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
@ -69,12 +68,12 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
//TODO generate fixed hash code for quick comparison? //TODO generate fixed hash code for quick comparison?
override val self: NDArray<Real> = this override val self: NDArray<Double> = this
} }
} }
actual fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real> { actual fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Double> {
//TODO cache fields? //TODO cache fields?
return RealNDField(shape).produce { Real(initializer(it)) } return RealNDField(shape).produce { initializer(it) }
} }

View File

@ -10,13 +10,13 @@ class RealNDFieldTest {
@Test @Test
fun testSum() { fun testSum() {
val sum = array1 + array2 val sum = array1 + array2
assertEquals(4.0, sum[2, 2].toDouble(), 0.1) assertEquals(4.0, sum[2, 2], 0.1)
} }
@Test @Test
fun testProduct(){ fun testProduct(){
val product = array1*array2 val product = array1*array2
assertEquals(0.0, product[2, 2].toDouble(), 0.1) assertEquals(0.0, product[2, 2], 0.1)
} }
@Test @Test
@ -27,7 +27,7 @@ class RealNDFieldTest {
for(i in 0..2){ for(i in 0..2){
for(j in 0..2){ for(j in 0..2){
val expected= (i * 10 + j).toDouble() val expected= (i * 10 + j).toDouble()
assertEquals("Error at index [$i, $j]", expected, array[i,j].value, 0.1) assertEquals("Error at index [$i, $j]", expected, array[i,j], 0.1)
} }
} }
} }