Fixed multiple problems with NDArray. Added test for summation
This commit is contained in:
parent
f931004b5c
commit
7a507cd281
@ -50,10 +50,10 @@ interface SpaceElement<S : SpaceElement<S>> {
|
|||||||
*/
|
*/
|
||||||
val self: S
|
val self: S
|
||||||
|
|
||||||
operator fun plus(b: S): S = with(context) { self + b }
|
operator fun plus(b: S): S = context.add(self, b)
|
||||||
operator fun minus(b: S): S = with(context) { self - b }
|
operator fun minus(b: S): S = context.add(self, context.multiply(b, -1.0))
|
||||||
operator fun times(k: Number): S = with(context) { self * k }
|
operator fun times(k: Number): S = context.multiply(self, k.toDouble())
|
||||||
operator fun div(k: Number): S = with(context) { self / k }
|
operator fun div(k: Number): S = context.multiply(self, 1.0 / k.toDouble())
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -80,7 +80,7 @@ interface Ring<T> : Space<T> {
|
|||||||
interface RingElement<S : RingElement<S>> : SpaceElement<S> {
|
interface RingElement<S : RingElement<S>> : SpaceElement<S> {
|
||||||
override val context: Ring<S>
|
override val context: Ring<S>
|
||||||
|
|
||||||
operator fun times(b: S): S = with(context) { self * b }
|
operator fun times(b: S): S = context.multiply(self, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -99,5 +99,5 @@ interface Field<T> : Ring<T> {
|
|||||||
interface FieldElement<S : FieldElement<S>> : RingElement<S> {
|
interface FieldElement<S : FieldElement<S>> : RingElement<S> {
|
||||||
override val context: Field<S>
|
override val context: Field<S>
|
||||||
|
|
||||||
operator fun div(b: S): S = with(context) { self / b }
|
operator fun div(b: S): S = context.divide(self, b)
|
||||||
}
|
}
|
@ -105,7 +105,7 @@ interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair
|
|||||||
return if (shape.size == 1) {
|
return if (shape.size == 1) {
|
||||||
(0 until shape[0]).asSequence().map { listOf(it) }
|
(0 until shape[0]).asSequence().map { listOf(it) }
|
||||||
} else {
|
} else {
|
||||||
val tailShape = ArrayList(shape).apply { remove(0) }
|
val tailShape = ArrayList(shape).apply { removeAt(0) }
|
||||||
val tailSequence: List<List<Int>> = iterateIndexes(tailShape).toList()
|
val tailSequence: List<List<Int>> = iterateIndexes(tailShape).toList()
|
||||||
(0 until shape[0]).asSequence().map { firstIndex ->
|
(0 until shape[0]).asSequence().map { firstIndex ->
|
||||||
//adding first element to each of provided index lists
|
//adding first element to each of provided index lists
|
||||||
@ -116,5 +116,15 @@ interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a platform-specific NDArray of doubles
|
||||||
|
*/
|
||||||
|
expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Real>
|
||||||
|
|
||||||
expect fun RealNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real>
|
fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Real> {
|
||||||
|
return realNDArray(listOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Real> {
|
||||||
|
return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||||
|
}
|
@ -1,6 +1,5 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
|
||||||
import scientifik.kmath.operations.Real
|
import scientifik.kmath.operations.Real
|
||||||
import scientifik.kmath.operations.RealField
|
import scientifik.kmath.operations.RealField
|
||||||
import java.nio.DoubleBuffer
|
import java.nio.DoubleBuffer
|
||||||
@ -13,8 +12,9 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
|||||||
private val strides: List<Int> by lazy {
|
private val strides: List<Int> by lazy {
|
||||||
ArrayList<Int>(shape.size).apply {
|
ArrayList<Int>(shape.size).apply {
|
||||||
var current = 1
|
var current = 1
|
||||||
shape.forEach{
|
add(0)
|
||||||
current *=it
|
shape.forEach {
|
||||||
|
current *= it
|
||||||
add(current)
|
add(current)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -30,7 +30,7 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
val capacity: Int
|
val capacity: Int
|
||||||
get() = strides[shape.size - 1]
|
get() = strides[shape.size]
|
||||||
|
|
||||||
|
|
||||||
override fun produce(initializer: (List<Int>) -> Real): NDArray<Real> {
|
override fun produce(initializer: (List<Int>) -> Real): NDArray<Real> {
|
||||||
@ -39,26 +39,42 @@ private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
|||||||
NDArray.iterateIndexes(shape).forEach {
|
NDArray.iterateIndexes(shape).forEach {
|
||||||
buffer.put(offset(it), initializer(it).value)
|
buffer.put(offset(it), initializer(it).value)
|
||||||
}
|
}
|
||||||
return RealNDArray(buffer)
|
return RealNDArray(this, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
inner class RealNDArray(val data: DoubleBuffer) : NDArray<Real> {
|
class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray<Real> {
|
||||||
|
|
||||||
override val context: Field<NDArray<Real>>
|
|
||||||
get() = this@RealNDField
|
|
||||||
|
|
||||||
override fun get(vararg index: Int): Real {
|
override fun get(vararg index: Int): Real {
|
||||||
return Real(data.get(offset(index.asList())))
|
return Real(data.get(context.offset(index.asList())))
|
||||||
}
|
}
|
||||||
|
|
||||||
override val self: NDArray<Real>
|
override fun equals(other: Any?): Boolean {
|
||||||
get() = this
|
if (this === other) return true
|
||||||
|
if (javaClass != other?.javaClass) return false
|
||||||
|
|
||||||
|
other as RealNDArray
|
||||||
|
|
||||||
|
if (context.shape != other.context.shape) return false
|
||||||
|
if (data != other.data) return false
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = context.shape.hashCode()
|
||||||
|
result = 31 * result + data.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO generate fixed hash code for quick comparison?
|
||||||
|
|
||||||
|
|
||||||
|
override val self: NDArray<Real> = this
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
actual fun RealNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real> {
|
actual fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real> {
|
||||||
//TODO cache fields?
|
//TODO cache fields?
|
||||||
return RealNDField(shape).produce { Real(initializer(it)) }
|
return RealNDField(shape).produce { Real(initializer(it)) }
|
||||||
}
|
}
|
@ -0,0 +1,15 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import org.junit.Assert.assertEquals
|
||||||
|
import kotlin.test.Test
|
||||||
|
|
||||||
|
class RealNDFieldTest {
|
||||||
|
val array1 = real2DArray(3, 3) { i, j -> (i + j).toDouble() }
|
||||||
|
val array2 = real2DArray(3, 3) { i, j -> (i - j).toDouble() }
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testSum() {
|
||||||
|
val sum = array1 + array2
|
||||||
|
assertEquals(4.0, sum[2, 2].value, 0.1)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user