Fixed #5 by removing requirement of NDArray element to be Field element and introducing DoubleField without corresponding element.
This commit is contained in:
parent
bd46b66080
commit
7c3d561c63
15
build.gradle
15
build.gradle
@ -1,3 +1,14 @@
|
|||||||
group = 'scientifik'
|
buildscript {
|
||||||
version = '0.1-SNAPSHOT'
|
ext.kotlin_version = '1.2.41'
|
||||||
|
|
||||||
|
repositories {
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
dependencies {
|
||||||
|
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
group = 'scientifik'
|
||||||
|
version = '0.1 - SNAPSHOT'
|
||||||
|
|
||||||
|
@ -1,14 +1,3 @@
|
|||||||
buildscript {
|
|
||||||
ext.kotlin_version = '1.2.40'
|
|
||||||
|
|
||||||
repositories {
|
|
||||||
mavenCentral()
|
|
||||||
}
|
|
||||||
dependencies {
|
|
||||||
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
description = "Platform-independent interfaces for kotlin maths"
|
description = "Platform-independent interfaces for kotlin maths"
|
||||||
|
|
||||||
apply plugin: 'kotlin-platform-common'
|
apply plugin: 'kotlin-platform-common'
|
||||||
|
@ -90,7 +90,7 @@ interface Field<T> : Ring<T> {
|
|||||||
fun divide(a: T, b: T): T
|
fun divide(a: T, b: T): T
|
||||||
|
|
||||||
operator fun T.div(b: T): T = divide(this, b)
|
operator fun T.div(b: T): T = divide(this, b)
|
||||||
operator fun Double.div(b: T) = this * divide(one, b)
|
operator fun Number.div(b: T) = this * divide(one, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -76,3 +76,11 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex> {
|
|||||||
//TODO is it convenient?
|
//TODO is it convenient?
|
||||||
operator fun not() = conjugate
|
operator fun not() = conjugate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
object DoubleField : Field<Double> {
|
||||||
|
override val zero: Double = 0.0
|
||||||
|
override fun add(a: Double, b: Double): Double = a + b
|
||||||
|
override fun multiply(a: Double, b: Double): Double = a * b
|
||||||
|
override val one: Double = 1.0
|
||||||
|
override fun divide(a: Double, b: Double): Double = a / b
|
||||||
|
}
|
@ -2,23 +2,22 @@ package scientifik.kmath.structures
|
|||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.Field
|
||||||
import scientifik.kmath.operations.FieldElement
|
import scientifik.kmath.operations.FieldElement
|
||||||
import scientifik.kmath.operations.Real
|
|
||||||
|
|
||||||
class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException()
|
class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Field for n-dimensional arrays.
|
* Field for n-dimensional arrays.
|
||||||
* @param shape - the list of dimensions of the array
|
* @param shape - the list of dimensions of the array
|
||||||
* @param elementField - operations field defined on individual array element
|
* @param field - operations field defined on individual array element
|
||||||
*/
|
*/
|
||||||
abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementField: Field<T>) : Field<NDArray<T>> {
|
abstract class NDField<T>(val shape: List<Int>, val field: Field<T>) : Field<NDArray<T>> {
|
||||||
/**
|
/**
|
||||||
* Create new instance of NDArray using field shape and given initializer
|
* Create new instance of NDArray using field shape and given initializer
|
||||||
*/
|
*/
|
||||||
abstract fun produce(initializer: (List<Int>) -> T): NDArray<T>
|
abstract fun produce(initializer: (List<Int>) -> T): NDArray<T>
|
||||||
|
|
||||||
override val zero: NDArray<T>
|
override val zero: NDArray<T>
|
||||||
get() = produce { elementField.zero }
|
get() = produce { this.field.zero }
|
||||||
|
|
||||||
private fun checkShape(vararg arrays: NDArray<T>) {
|
private fun checkShape(vararg arrays: NDArray<T>) {
|
||||||
arrays.forEach {
|
arrays.forEach {
|
||||||
@ -33,7 +32,7 @@ abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementFie
|
|||||||
*/
|
*/
|
||||||
override fun add(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
override fun add(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
||||||
checkShape(a, b)
|
checkShape(a, b)
|
||||||
return produce { a[it] + b[it] }
|
return produce { with(field) { a[it] + b[it] } }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -41,18 +40,18 @@ abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementFie
|
|||||||
*/
|
*/
|
||||||
override fun multiply(a: NDArray<T>, k: Double): NDArray<T> {
|
override fun multiply(a: NDArray<T>, k: Double): NDArray<T> {
|
||||||
checkShape(a)
|
checkShape(a)
|
||||||
return produce { a[it] * k }
|
return produce { with(field) {a[it] * k} }
|
||||||
}
|
}
|
||||||
|
|
||||||
override val one: NDArray<T>
|
override val one: NDArray<T>
|
||||||
get() = produce { elementField.one }
|
get() = produce { this.field.one }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Element-by-element multiplication
|
* Element-by-element multiplication
|
||||||
*/
|
*/
|
||||||
override fun multiply(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
override fun multiply(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
||||||
checkShape(a)
|
checkShape(a)
|
||||||
return produce { a[it] * b[it] }
|
return produce { with(field) {a[it] * b[it]} }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -60,12 +59,12 @@ abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementFie
|
|||||||
*/
|
*/
|
||||||
override fun divide(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
override fun divide(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
||||||
checkShape(a)
|
checkShape(a)
|
||||||
return produce { a[it] / b[it] }
|
return produce { with(field) {a[it] / b[it]} }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
|
interface NDArray<T> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The list of dimensions of this NDArray
|
* The list of dimensions of this NDArray
|
||||||
@ -119,12 +118,12 @@ interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair
|
|||||||
/**
|
/**
|
||||||
* Create a platform-specific NDArray of doubles
|
* Create a platform-specific NDArray of doubles
|
||||||
*/
|
*/
|
||||||
expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Real>
|
expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Double>
|
||||||
|
|
||||||
fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Real> {
|
fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Double> {
|
||||||
return realNDArray(listOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
return realNDArray(listOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||||
}
|
}
|
||||||
|
|
||||||
fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Real> {
|
fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Double> {
|
||||||
return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||||
}
|
}
|
@ -1,15 +1,3 @@
|
|||||||
buildscript {
|
|
||||||
ext.kotlin_version = '1.2.40'
|
|
||||||
|
|
||||||
repositories {
|
|
||||||
mavenCentral()
|
|
||||||
}
|
|
||||||
dependencies {
|
|
||||||
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
apply plugin: 'kotlin-platform-jvm'
|
apply plugin: 'kotlin-platform-jvm'
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import scientifik.kmath.operations.Real
|
import scientifik.kmath.operations.DoubleField
|
||||||
import scientifik.kmath.operations.RealField
|
|
||||||
import java.nio.DoubleBuffer
|
import java.nio.DoubleBuffer
|
||||||
|
|
||||||
private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
private class RealNDField(shape: List<Int>) : NDField<Double>(shape, DoubleField) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Strides for memory access
|
* Strides for memory access
|
||||||
@ -33,19 +32,19 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
|||||||
get() = strides[shape.size]
|
get() = strides[shape.size]
|
||||||
|
|
||||||
|
|
||||||
override fun produce(initializer: (List<Int>) -> Real): NDArray<Real> {
|
override fun produce(initializer: (List<Int>) -> Double): NDArray<Double> {
|
||||||
//TODO use sparse arrays for large capacities
|
//TODO use sparse arrays for large capacities
|
||||||
val buffer = DoubleBuffer.allocate(capacity)
|
val buffer = DoubleBuffer.allocate(capacity)
|
||||||
NDArray.iterateIndexes(shape).forEach {
|
NDArray.iterateIndexes(shape).forEach {
|
||||||
buffer.put(offset(it), initializer(it).value)
|
buffer.put(offset(it), initializer(it))
|
||||||
}
|
}
|
||||||
return RealNDArray(this, buffer)
|
return RealNDArray(this, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray<Real> {
|
class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray<Double> {
|
||||||
|
|
||||||
override fun get(vararg index: Int): Real {
|
override fun get(vararg index: Int): Double {
|
||||||
return Real(data.get(context.offset(index.asList())))
|
return data.get(context.offset(index.asList()))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
@ -69,12 +68,12 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
|||||||
//TODO generate fixed hash code for quick comparison?
|
//TODO generate fixed hash code for quick comparison?
|
||||||
|
|
||||||
|
|
||||||
override val self: NDArray<Real> = this
|
override val self: NDArray<Double> = this
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
actual fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real> {
|
actual fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Double> {
|
||||||
//TODO cache fields?
|
//TODO cache fields?
|
||||||
return RealNDField(shape).produce { Real(initializer(it)) }
|
return RealNDField(shape).produce { initializer(it) }
|
||||||
}
|
}
|
@ -10,13 +10,13 @@ class RealNDFieldTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testSum() {
|
fun testSum() {
|
||||||
val sum = array1 + array2
|
val sum = array1 + array2
|
||||||
assertEquals(4.0, sum[2, 2].toDouble(), 0.1)
|
assertEquals(4.0, sum[2, 2], 0.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testProduct(){
|
fun testProduct(){
|
||||||
val product = array1*array2
|
val product = array1*array2
|
||||||
assertEquals(0.0, product[2, 2].toDouble(), 0.1)
|
assertEquals(0.0, product[2, 2], 0.1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -27,7 +27,7 @@ class RealNDFieldTest {
|
|||||||
for(i in 0..2){
|
for(i in 0..2){
|
||||||
for(j in 0..2){
|
for(j in 0..2){
|
||||||
val expected= (i * 10 + j).toDouble()
|
val expected= (i * 10 + j).toDouble()
|
||||||
assertEquals("Error at index [$i, $j]", expected, array[i,j].value, 0.1)
|
assertEquals("Error at index [$i, $j]", expected, array[i,j], 0.1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user