forked from kscience/kmath
Fixed #5 by removing requirement of NDArray element to be Field element and introducing DoubleField without corresponding element.
This commit is contained in:
parent
bd46b66080
commit
7c3d561c63
15
build.gradle
15
build.gradle
@ -1,3 +1,14 @@
|
||||
group = 'scientifik'
|
||||
version = '0.1-SNAPSHOT'
|
||||
buildscript {
|
||||
ext.kotlin_version = '1.2.41'
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
}
|
||||
dependencies {
|
||||
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
|
||||
}
|
||||
}
|
||||
|
||||
group = 'scientifik'
|
||||
version = '0.1 - SNAPSHOT'
|
||||
|
||||
|
@ -1,14 +1,3 @@
|
||||
buildscript {
|
||||
ext.kotlin_version = '1.2.40'
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
}
|
||||
dependencies {
|
||||
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
|
||||
}
|
||||
}
|
||||
|
||||
description = "Platform-independent interfaces for kotlin maths"
|
||||
|
||||
apply plugin: 'kotlin-platform-common'
|
||||
|
@ -90,7 +90,7 @@ interface Field<T> : Ring<T> {
|
||||
fun divide(a: T, b: T): T
|
||||
|
||||
operator fun T.div(b: T): T = divide(this, b)
|
||||
operator fun Double.div(b: T) = this * divide(one, b)
|
||||
operator fun Number.div(b: T) = this * divide(one, b)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -76,3 +76,11 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex> {
|
||||
//TODO is it convenient?
|
||||
operator fun not() = conjugate
|
||||
}
|
||||
|
||||
object DoubleField : Field<Double> {
|
||||
override val zero: Double = 0.0
|
||||
override fun add(a: Double, b: Double): Double = a + b
|
||||
override fun multiply(a: Double, b: Double): Double = a * b
|
||||
override val one: Double = 1.0
|
||||
override fun divide(a: Double, b: Double): Double = a / b
|
||||
}
|
@ -2,23 +2,22 @@ package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.FieldElement
|
||||
import scientifik.kmath.operations.Real
|
||||
|
||||
class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException()
|
||||
|
||||
/**
|
||||
* Field for n-dimensional arrays.
|
||||
* @param shape - the list of dimensions of the array
|
||||
* @param elementField - operations field defined on individual array element
|
||||
* @param field - operations field defined on individual array element
|
||||
*/
|
||||
abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementField: Field<T>) : Field<NDArray<T>> {
|
||||
abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
|
||||
/**
|
||||
* Create new instance of NDArray using field shape and given initializer
|
||||
*/
|
||||
abstract fun produce(initializer: (List<Int>) -> T): NDArray<T>
|
||||
|
||||
override val zero: NDArray<T>
|
||||
get() = produce { elementField.zero }
|
||||
get() = produce { this.field.zero }
|
||||
|
||||
private fun checkShape(vararg arrays: NDArray<T>) {
|
||||
arrays.forEach {
|
||||
@ -33,7 +32,7 @@ abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementFie
|
||||
*/
|
||||
override fun add(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
||||
checkShape(a, b)
|
||||
return produce { a[it] + b[it] }
|
||||
return produce { with(field) { a[it] + b[it] } }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -41,18 +40,18 @@ abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementFie
|
||||
*/
|
||||
override fun multiply(a: NDArray<T>, k: Double): NDArray<T> {
|
||||
checkShape(a)
|
||||
return produce { a[it] * k }
|
||||
return produce { with(field) {a[it] * k} }
|
||||
}
|
||||
|
||||
override val one: NDArray<T>
|
||||
get() = produce { elementField.one }
|
||||
get() = produce { this.field.one }
|
||||
|
||||
/**
|
||||
* Element-by-element multiplication
|
||||
*/
|
||||
override fun multiply(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
||||
checkShape(a)
|
||||
return produce { a[it] * b[it] }
|
||||
return produce { with(field) {a[it] * b[it]} }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -60,12 +59,12 @@ abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementFie
|
||||
*/
|
||||
override fun divide(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
||||
checkShape(a)
|
||||
return produce { a[it] / b[it] }
|
||||
return produce { with(field) {a[it] / b[it]} }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
|
||||
interface NDArray<T> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
|
||||
|
||||
/**
|
||||
* The list of dimensions of this NDArray
|
||||
@ -119,12 +118,12 @@ interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair
|
||||
/**
|
||||
* Create a platform-specific NDArray of doubles
|
||||
*/
|
||||
expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Real>
|
||||
expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Double>
|
||||
|
||||
fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Real> {
|
||||
fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Double> {
|
||||
return realNDArray(listOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||
}
|
||||
|
||||
fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Real> {
|
||||
fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Double> {
|
||||
return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||
}
|
@ -1,15 +1,3 @@
|
||||
buildscript {
|
||||
ext.kotlin_version = '1.2.40'
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
}
|
||||
dependencies {
|
||||
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
apply plugin: 'kotlin-platform-jvm'
|
||||
|
||||
repositories {
|
||||
|
@ -1,10 +1,9 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.Real
|
||||
import scientifik.kmath.operations.RealField
|
||||
import scientifik.kmath.operations.DoubleField
|
||||
import java.nio.DoubleBuffer
|
||||
|
||||
private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
||||
private class RealNDField(shape: List<Int>) : NDField<Double>(shape, DoubleField) {
|
||||
|
||||
/**
|
||||
* Strides for memory access
|
||||
@ -33,19 +32,19 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
||||
get() = strides[shape.size]
|
||||
|
||||
|
||||
override fun produce(initializer: (List<Int>) -> Real): NDArray<Real> {
|
||||
override fun produce(initializer: (List<Int>) -> Double): NDArray<Double> {
|
||||
//TODO use sparse arrays for large capacities
|
||||
val buffer = DoubleBuffer.allocate(capacity)
|
||||
NDArray.iterateIndexes(shape).forEach {
|
||||
buffer.put(offset(it), initializer(it).value)
|
||||
buffer.put(offset(it), initializer(it))
|
||||
}
|
||||
return RealNDArray(this, buffer)
|
||||
}
|
||||
|
||||
class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray<Real> {
|
||||
class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray<Double> {
|
||||
|
||||
override fun get(vararg index: Int): Real {
|
||||
return Real(data.get(context.offset(index.asList())))
|
||||
override fun get(vararg index: Int): Double {
|
||||
return data.get(context.offset(index.asList()))
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
@ -69,12 +68,12 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
||||
//TODO generate fixed hash code for quick comparison?
|
||||
|
||||
|
||||
override val self: NDArray<Real> = this
|
||||
override val self: NDArray<Double> = this
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
actual fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real> {
|
||||
actual fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Double> {
|
||||
//TODO cache fields?
|
||||
return RealNDField(shape).produce { Real(initializer(it)) }
|
||||
return RealNDField(shape).produce { initializer(it) }
|
||||
}
|
@ -10,13 +10,13 @@ class RealNDFieldTest {
|
||||
@Test
|
||||
fun testSum() {
|
||||
val sum = array1 + array2
|
||||
assertEquals(4.0, sum[2, 2].toDouble(), 0.1)
|
||||
assertEquals(4.0, sum[2, 2], 0.1)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testProduct(){
|
||||
val product = array1*array2
|
||||
assertEquals(0.0, product[2, 2].toDouble(), 0.1)
|
||||
assertEquals(0.0, product[2, 2], 0.1)
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -27,7 +27,7 @@ class RealNDFieldTest {
|
||||
for(i in 0..2){
|
||||
for(j in 0..2){
|
||||
val expected= (i * 10 + j).toDouble()
|
||||
assertEquals("Error at index [$i, $j]", expected, array[i,j].value, 0.1)
|
||||
assertEquals("Error at index [$i, $j]", expected, array[i,j], 0.1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user