Fixed #5 by removing requirement of NDArray element to be Field element and introducing DoubleField without corresponding element.

This commit is contained in:
Alexander Nozik 2018-05-02 22:32:20 +03:00
parent bd46b66080
commit 7c3d561c63
8 changed files with 47 additions and 53 deletions

View File

@ -1,3 +1,14 @@
buildscript {
ext.kotlin_version = '1.2.41'
repositories {
mavenCentral()
}
dependencies {
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
}
}
group = 'scientifik'
version = '0.1 - SNAPSHOT'

View File

@ -1,14 +1,3 @@
buildscript {
ext.kotlin_version = '1.2.40'
repositories {
mavenCentral()
}
dependencies {
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
}
}
description = "Platform-independent interfaces for kotlin maths"
apply plugin: 'kotlin-platform-common'

View File

@ -90,7 +90,7 @@ interface Field<T> : Ring<T> {
fun divide(a: T, b: T): T
operator fun T.div(b: T): T = divide(this, b)
operator fun Double.div(b: T) = this * divide(one, b)
operator fun Number.div(b: T) = this * divide(one, b)
}
/**

View File

@ -76,3 +76,11 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex> {
//TODO is it convenient?
operator fun not() = conjugate
}
object DoubleField : Field<Double> {
override val zero: Double = 0.0
override fun add(a: Double, b: Double): Double = a + b
override fun multiply(a: Double, b: Double): Double = a * b
override val one: Double = 1.0
override fun divide(a: Double, b: Double): Double = a / b
}

View File

@ -2,23 +2,22 @@ package scientifik.kmath.structures
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.FieldElement
import scientifik.kmath.operations.Real
class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException()
/**
* Field for n-dimensional arrays.
* @param shape - the list of dimensions of the array
* @param elementField - operations field defined on individual array element
* @param field - operations field defined on individual array element
*/
abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementField: Field<T>) : Field<NDArray<T>> {
abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
/**
* Create new instance of NDArray using field shape and given initializer
*/
abstract fun produce(initializer: (List<Int>) -> T): NDArray<T>
override val zero: NDArray<T>
get() = produce { elementField.zero }
get() = produce { this.field.zero }
private fun checkShape(vararg arrays: NDArray<T>) {
arrays.forEach {
@ -33,7 +32,7 @@ abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementFie
*/
override fun add(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
checkShape(a, b)
return produce { a[it] + b[it] }
return produce { with(field) { a[it] + b[it] } }
}
/**
@ -41,18 +40,18 @@ abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementFie
*/
override fun multiply(a: NDArray<T>, k: Double): NDArray<T> {
checkShape(a)
return produce { a[it] * k }
return produce { with(field) {a[it] * k} }
}
override val one: NDArray<T>
get() = produce { elementField.one }
get() = produce { this.field.one }
/**
* Element-by-element multiplication
*/
override fun multiply(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
checkShape(a)
return produce { a[it] * b[it] }
return produce { with(field) {a[it] * b[it]} }
}
/**
@ -60,12 +59,12 @@ abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementFie
*/
override fun divide(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
checkShape(a)
return produce { a[it] / b[it] }
return produce { with(field) {a[it] / b[it]} }
}
}
interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
interface NDArray<T> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
/**
* The list of dimensions of this NDArray
@ -119,12 +118,12 @@ interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair
/**
* Create a platform-specific NDArray of doubles
*/
expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Real>
expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Double>
fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Real> {
fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Double> {
return realNDArray(listOf(dim1, dim2)) { initializer(it[0], it[1]) }
}
fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Real> {
fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Double> {
return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
}

View File

@ -1,15 +1,3 @@
buildscript {
ext.kotlin_version = '1.2.40'
repositories {
mavenCentral()
}
dependencies {
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
}
}
apply plugin: 'kotlin-platform-jvm'
repositories {

View File

@ -1,10 +1,9 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.Real
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.DoubleField
import java.nio.DoubleBuffer
private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
private class RealNDField(shape: List<Int>) : NDField<Double>(shape, DoubleField) {
/**
* Strides for memory access
@ -33,19 +32,19 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
get() = strides[shape.size]
override fun produce(initializer: (List<Int>) -> Real): NDArray<Real> {
override fun produce(initializer: (List<Int>) -> Double): NDArray<Double> {
//TODO use sparse arrays for large capacities
val buffer = DoubleBuffer.allocate(capacity)
NDArray.iterateIndexes(shape).forEach {
buffer.put(offset(it), initializer(it).value)
buffer.put(offset(it), initializer(it))
}
return RealNDArray(this, buffer)
}
class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray<Real> {
class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray<Double> {
override fun get(vararg index: Int): Real {
return Real(data.get(context.offset(index.asList())))
override fun get(vararg index: Int): Double {
return data.get(context.offset(index.asList()))
}
override fun equals(other: Any?): Boolean {
@ -69,12 +68,12 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
//TODO generate fixed hash code for quick comparison?
override val self: NDArray<Real> = this
override val self: NDArray<Double> = this
}
}
actual fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real> {
actual fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Double> {
//TODO cache fields?
return RealNDField(shape).produce { Real(initializer(it)) }
return RealNDField(shape).produce { initializer(it) }
}

View File

@ -10,13 +10,13 @@ class RealNDFieldTest {
@Test
fun testSum() {
val sum = array1 + array2
assertEquals(4.0, sum[2, 2].toDouble(), 0.1)
assertEquals(4.0, sum[2, 2], 0.1)
}
@Test
fun testProduct(){
val product = array1*array2
assertEquals(0.0, product[2, 2].toDouble(), 0.1)
assertEquals(0.0, product[2, 2], 0.1)
}
@Test
@ -27,7 +27,7 @@ class RealNDFieldTest {
for(i in 0..2){
for(j in 0..2){
val expected= (i * 10 + j).toDouble()
assertEquals("Error at index [$i, $j]", expected, array[i,j].value, 0.1)
assertEquals("Error at index [$i, $j]", expected, array[i,j], 0.1)
}
}
}