Fixed multiple problems with NDArray. Added test for summation

This commit is contained in:
Alexander Nozik 2018-04-29 09:15:30 +03:00
parent f931004b5c
commit 7a507cd281
4 changed files with 63 additions and 22 deletions

View File

@ -50,10 +50,10 @@ interface SpaceElement<S : SpaceElement<S>> {
*/
val self: S
operator fun plus(b: S): S = with(context) { self + b }
operator fun minus(b: S): S = with(context) { self - b }
operator fun times(k: Number): S = with(context) { self * k }
operator fun div(k: Number): S = with(context) { self / k }
operator fun plus(b: S): S = context.add(self, b)
operator fun minus(b: S): S = context.add(self, context.multiply(b, -1.0))
operator fun times(k: Number): S = context.multiply(self, k.toDouble())
operator fun div(k: Number): S = context.multiply(self, 1.0 / k.toDouble())
}
/**
@ -80,7 +80,7 @@ interface Ring<T> : Space<T> {
interface RingElement<S : RingElement<S>> : SpaceElement<S> {
override val context: Ring<S>
operator fun times(b: S): S = with(context) { self * b }
operator fun times(b: S): S = context.multiply(self, b)
}
/**
@ -99,5 +99,5 @@ interface Field<T> : Ring<T> {
interface FieldElement<S : FieldElement<S>> : RingElement<S> {
override val context: Field<S>
operator fun div(b: S): S = with(context) { self / b }
operator fun div(b: S): S = context.divide(self, b)
}

View File

@ -105,7 +105,7 @@ interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair
return if (shape.size == 1) {
(0 until shape[0]).asSequence().map { listOf(it) }
} else {
val tailShape = ArrayList(shape).apply { remove(0) }
val tailShape = ArrayList(shape).apply { removeAt(0) }
val tailSequence: List<List<Int>> = iterateIndexes(tailShape).toList()
(0 until shape[0]).asSequence().map { firstIndex ->
//adding first element to each of provided index lists
@ -116,5 +116,15 @@ interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair
}
}
/**
* Create a platform-specific NDArray of doubles
*/
expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Real>
expect fun RealNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real>
fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Real> {
return realNDArray(listOf(dim1, dim2)) { initializer(it[0], it[1]) }
}
fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Real> {
return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
}

View File

@ -1,6 +1,5 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Real
import scientifik.kmath.operations.RealField
import java.nio.DoubleBuffer
@ -13,8 +12,9 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
private val strides: List<Int> by lazy {
ArrayList<Int>(shape.size).apply {
var current = 1
shape.forEach{
current *=it
add(0)
shape.forEach {
current *= it
add(current)
}
}
@ -30,7 +30,7 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
}
val capacity: Int
get() = strides[shape.size - 1]
get() = strides[shape.size]
override fun produce(initializer: (List<Int>) -> Real): NDArray<Real> {
@ -39,26 +39,42 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
NDArray.iterateIndexes(shape).forEach {
buffer.put(offset(it), initializer(it).value)
}
return RealNDArray(buffer)
return RealNDArray(this, buffer)
}
inner class RealNDArray(val data: DoubleBuffer) : NDArray<Real> {
override val context: Field<NDArray<Real>>
get() = this@RealNDField
class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray<Real> {
override fun get(vararg index: Int): Real {
return Real(data.get(offset(index.asList())))
return Real(data.get(context.offset(index.asList())))
}
override val self: NDArray<Real>
get() = this
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as RealNDArray
if (context.shape != other.context.shape) return false
if (data != other.data) return false
return true
}
override fun hashCode(): Int {
var result = context.shape.hashCode()
result = 31 * result + data.hashCode()
return result
}
//TODO generate fixed hash code for quick comparison?
override val self: NDArray<Real> = this
}
}
actual fun RealNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real> {
actual fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real> {
//TODO cache fields?
return RealNDField(shape).produce { Real(initializer(it)) }
}

View File

@ -0,0 +1,15 @@
package scientifik.kmath.structures
import org.junit.Assert.assertEquals
import kotlin.test.Test
class RealNDFieldTest {
val array1 = real2DArray(3, 3) { i, j -> (i + j).toDouble() }
val array2 = real2DArray(3, 3) { i, j -> (i - j).toDouble() }
@Test
fun testSum() {
val sum = array1 + array2
assertEquals(4.0, sum[2, 2].value, 0.1)
}
}