Fixed multiple problems with NDArray. Added test for summation
This commit is contained in:
parent
f931004b5c
commit
7a507cd281
@ -50,10 +50,10 @@ interface SpaceElement<S : SpaceElement<S>> {
|
||||
*/
|
||||
val self: S
|
||||
|
||||
operator fun plus(b: S): S = with(context) { self + b }
|
||||
operator fun minus(b: S): S = with(context) { self - b }
|
||||
operator fun times(k: Number): S = with(context) { self * k }
|
||||
operator fun div(k: Number): S = with(context) { self / k }
|
||||
operator fun plus(b: S): S = context.add(self, b)
|
||||
operator fun minus(b: S): S = context.add(self, context.multiply(b, -1.0))
|
||||
operator fun times(k: Number): S = context.multiply(self, k.toDouble())
|
||||
operator fun div(k: Number): S = context.multiply(self, 1.0 / k.toDouble())
|
||||
}
|
||||
|
||||
/**
|
||||
@ -80,7 +80,7 @@ interface Ring<T> : Space<T> {
|
||||
interface RingElement<S : RingElement<S>> : SpaceElement<S> {
|
||||
override val context: Ring<S>
|
||||
|
||||
operator fun times(b: S): S = with(context) { self * b }
|
||||
operator fun times(b: S): S = context.multiply(self, b)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -99,5 +99,5 @@ interface Field<T> : Ring<T> {
|
||||
interface FieldElement<S : FieldElement<S>> : RingElement<S> {
|
||||
override val context: Field<S>
|
||||
|
||||
operator fun div(b: S): S = with(context) { self / b }
|
||||
operator fun div(b: S): S = context.divide(self, b)
|
||||
}
|
@ -105,7 +105,7 @@ interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair
|
||||
return if (shape.size == 1) {
|
||||
(0 until shape[0]).asSequence().map { listOf(it) }
|
||||
} else {
|
||||
val tailShape = ArrayList(shape).apply { remove(0) }
|
||||
val tailShape = ArrayList(shape).apply { removeAt(0) }
|
||||
val tailSequence: List<List<Int>> = iterateIndexes(tailShape).toList()
|
||||
(0 until shape[0]).asSequence().map { firstIndex ->
|
||||
//adding first element to each of provided index lists
|
||||
@ -116,5 +116,15 @@ interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a platform-specific NDArray of doubles
|
||||
*/
|
||||
expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Real>
|
||||
|
||||
expect fun RealNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real>
|
||||
fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Real> {
|
||||
return realNDArray(listOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||
}
|
||||
|
||||
fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Real> {
|
||||
return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||
}
|
@ -1,6 +1,5 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Real
|
||||
import scientifik.kmath.operations.RealField
|
||||
import java.nio.DoubleBuffer
|
||||
@ -13,8 +12,9 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
||||
private val strides: List<Int> by lazy {
|
||||
ArrayList<Int>(shape.size).apply {
|
||||
var current = 1
|
||||
shape.forEach{
|
||||
current *=it
|
||||
add(0)
|
||||
shape.forEach {
|
||||
current *= it
|
||||
add(current)
|
||||
}
|
||||
}
|
||||
@ -30,7 +30,7 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
||||
}
|
||||
|
||||
val capacity: Int
|
||||
get() = strides[shape.size - 1]
|
||||
get() = strides[shape.size]
|
||||
|
||||
|
||||
override fun produce(initializer: (List<Int>) -> Real): NDArray<Real> {
|
||||
@ -39,26 +39,42 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
||||
NDArray.iterateIndexes(shape).forEach {
|
||||
buffer.put(offset(it), initializer(it).value)
|
||||
}
|
||||
return RealNDArray(buffer)
|
||||
return RealNDArray(this, buffer)
|
||||
}
|
||||
|
||||
inner class RealNDArray(val data: DoubleBuffer) : NDArray<Real> {
|
||||
|
||||
override val context: Field<NDArray<Real>>
|
||||
get() = this@RealNDField
|
||||
class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray<Real> {
|
||||
|
||||
override fun get(vararg index: Int): Real {
|
||||
return Real(data.get(offset(index.asList())))
|
||||
return Real(data.get(context.offset(index.asList())))
|
||||
}
|
||||
|
||||
override val self: NDArray<Real>
|
||||
get() = this
|
||||
}
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (javaClass != other?.javaClass) return false
|
||||
|
||||
other as RealNDArray
|
||||
|
||||
if (context.shape != other.context.shape) return false
|
||||
if (data != other.data) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = context.shape.hashCode()
|
||||
result = 31 * result + data.hashCode()
|
||||
return result
|
||||
}
|
||||
|
||||
//TODO generate fixed hash code for quick comparison?
|
||||
|
||||
|
||||
override val self: NDArray<Real> = this
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
actual fun RealNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real> {
|
||||
actual fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real> {
|
||||
//TODO cache fields?
|
||||
return RealNDField(shape).produce { Real(initializer(it)) }
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import org.junit.Assert.assertEquals
|
||||
import kotlin.test.Test
|
||||
|
||||
class RealNDFieldTest {
|
||||
val array1 = real2DArray(3, 3) { i, j -> (i + j).toDouble() }
|
||||
val array2 = real2DArray(3, 3) { i, j -> (i - j).toDouble() }
|
||||
|
||||
@Test
|
||||
fun testSum() {
|
||||
val sum = array1 + array2
|
||||
assertEquals(4.0, sum[2, 2].value, 0.1)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user