Generic definition for NDArray
This commit is contained in:
parent
6c25042f0f
commit
ba63b2e373
@ -1,6 +1,8 @@
|
||||
description = "Platform-independent interfaces for kotlin maths"
|
||||
plugins{
|
||||
id "kotlin-platform-common"
|
||||
}
|
||||
|
||||
apply plugin: 'kotlin-platform-common'
|
||||
description = "Platform-independent interfaces for kotlin maths"
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
@ -12,3 +14,9 @@ dependencies {
|
||||
testCompile "org.jetbrains.kotlin:kotlin-test-common:$kotlin_version"
|
||||
}
|
||||
|
||||
kotlin {
|
||||
experimental {
|
||||
coroutines "enable"
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,93 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import kotlin.coroutines.experimental.buildSequence
|
||||
|
||||
|
||||
/**
|
||||
* A generic buffer for both primitives and objects
|
||||
*/
|
||||
interface Buffer<T> {
|
||||
operator fun get(index: Int): T
|
||||
operator fun set(index: Int, value: T)
|
||||
}
|
||||
|
||||
/**
|
||||
* Generic implementation of NDField based on continuous buffer
|
||||
*/
|
||||
abstract class BufferNDField<T>(shape: List<Int>, field: Field<T>) : NDField<T>(shape, field) {
|
||||
|
||||
/**
|
||||
* Strides for memory access
|
||||
*/
|
||||
private val strides: List<Int> by lazy {
|
||||
ArrayList<Int>(shape.size).apply {
|
||||
var current = 1
|
||||
add(1)
|
||||
shape.forEach {
|
||||
current *= it
|
||||
add(current)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected fun offset(index: List<Int>): Int {
|
||||
return index.mapIndexed { i, value ->
|
||||
if (value < 0 || value >= shape[i]) {
|
||||
throw RuntimeException("Index out of shape bounds: ($i,$value)")
|
||||
}
|
||||
value * strides[i]
|
||||
}.sum()
|
||||
}
|
||||
|
||||
protected fun index(offset: Int): List<Int>{
|
||||
return buildSequence {
|
||||
var current = offset
|
||||
var strideIndex = strides.size-2
|
||||
while (strideIndex>=0){
|
||||
yield(current / strides[strideIndex])
|
||||
current %= strides[strideIndex]
|
||||
strideIndex--
|
||||
}
|
||||
}.toList().reversed()
|
||||
}
|
||||
|
||||
private val capacity: Int
|
||||
get() = strides[shape.size]
|
||||
|
||||
|
||||
protected abstract fun createBuffer(capacity: Int, initializer: (Int) -> T): Buffer<T>
|
||||
|
||||
override fun produce(initializer: (List<Int>) -> T): NDArray<T> {
|
||||
val buffer = createBuffer(capacity){initializer(index(it))}
|
||||
return BufferNDArray(this, buffer)
|
||||
}
|
||||
|
||||
|
||||
class BufferNDArray<T>(override val context: BufferNDField<T>, val data: Buffer<T>) : NDArray<T> {
|
||||
|
||||
override fun get(vararg index: Int): T {
|
||||
return data[context.offset(index.asList())]
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is BufferNDArray<*>) return false
|
||||
|
||||
if (context != other.context) return false
|
||||
if (data != other.data) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = context.hashCode()
|
||||
result = 31 * result + data.hashCode()
|
||||
return result
|
||||
}
|
||||
|
||||
override val self: NDArray<T> get() = this
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -198,15 +198,3 @@ operator fun <T> T.div(arg: NDArray<T>): NDArray<T> = arg.transform { _, value -
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a platform-specific NDArray of doubles
|
||||
*/
|
||||
expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Double>
|
||||
|
||||
fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Double> {
|
||||
return realNDArray(listOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||
}
|
||||
|
||||
fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Double> {
|
||||
return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
|
||||
/**
|
||||
* Create a platform-optimized NDArray of doubles
|
||||
*/
|
||||
expect fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double = { 0.0 }): NDArray<Double>
|
||||
|
||||
fun real2DArray(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDArray<Double> {
|
||||
return realNDArray(listOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||
}
|
||||
|
||||
fun real3DArray(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDArray<Double> {
|
||||
return realNDArray(listOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||
}
|
||||
|
||||
|
||||
class SimpleNDField<T: Any>(field: Field<T>, shape: List<Int>) : BufferNDField<T>(shape, field) {
|
||||
override fun createBuffer(capacity: Int, initializer: (Int) -> T): Buffer<T> {
|
||||
val array = ArrayList<T>(capacity)
|
||||
(0 until capacity).forEach {
|
||||
array.add(initializer(it))
|
||||
}
|
||||
|
||||
return object : Buffer<T> {
|
||||
override fun get(index: Int): T = array[index]
|
||||
|
||||
override fun set(index: Int, value: T) {
|
||||
array[index] = initializer(index)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun <T: Any> simpleNDArray(field: Field<T>, shape: List<Int>, initializer: (List<Int>) -> T): NDArray<T> {
|
||||
return SimpleNDField(field, shape).produce { initializer(it) }
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.DoubleField
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
|
||||
class SimpleNDFieldTest{
|
||||
@Test
|
||||
fun testStrides(){
|
||||
val ndArray = simpleNDArray(DoubleField, listOf(10,10)){(it[0]+it[1]).toDouble()}
|
||||
assertEquals(ndArray[5,5], 10.0)
|
||||
}
|
||||
|
||||
}
|
@ -1,4 +1,7 @@
|
||||
apply plugin: 'kotlin-platform-jvm'
|
||||
plugins{
|
||||
id "kotlin-platform-jvm"
|
||||
id "me.champeau.gradle.jmh" version "0.4.5"
|
||||
}
|
||||
|
||||
repositories {
|
||||
mavenCentral()
|
||||
|
@ -0,0 +1,53 @@
|
||||
package scietifik.kmath.structures
|
||||
|
||||
import org.openjdk.jmh.annotations.*
|
||||
import java.nio.IntBuffer
|
||||
|
||||
|
||||
@Fork(1)
|
||||
@Warmup(iterations = 2)
|
||||
@Measurement(iterations = 50)
|
||||
@State(Scope.Benchmark)
|
||||
open class ArrayBenchmark {
|
||||
|
||||
lateinit var array: IntArray
|
||||
lateinit var arrayBuffer: IntBuffer
|
||||
lateinit var nativeBuffer: IntBuffer
|
||||
|
||||
@Setup
|
||||
fun setup() {
|
||||
array = IntArray(10000) { it }
|
||||
arrayBuffer = IntBuffer.wrap(array)
|
||||
nativeBuffer = IntBuffer.allocate(10000)
|
||||
for (i in 0 until 10000) {
|
||||
nativeBuffer.put(i,i)
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun benchmarkArrayRead() {
|
||||
var res = 0
|
||||
for (i in 1..10000) {
|
||||
res += array[10000 - i]
|
||||
}
|
||||
print(res)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun benchmarkBufferRead() {
|
||||
var res = 0
|
||||
for (i in 1..10000) {
|
||||
res += arrayBuffer.get(10000 - i)
|
||||
}
|
||||
print(res)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun nativeBufferRead() {
|
||||
var res = 0
|
||||
for (i in 1..10000) {
|
||||
res += nativeBuffer.get(10000 - i)
|
||||
}
|
||||
print(res)
|
||||
}
|
||||
}
|
@ -3,78 +3,22 @@ package scientifik.kmath.structures
|
||||
import scientifik.kmath.operations.DoubleField
|
||||
import java.nio.DoubleBuffer
|
||||
|
||||
private class RealNDField(shape: List<Int>) : NDField<Double>(shape, DoubleField) {
|
||||
private class RealNDField(shape: List<Int>) : BufferNDField<Double>(shape, DoubleField) {
|
||||
override fun createBuffer(capacity: Int, initializer: (Int) -> Double): Buffer<Double> {
|
||||
val array = DoubleArray(capacity, initializer)
|
||||
val buffer = DoubleBuffer.wrap(array)
|
||||
return object : Buffer<Double> {
|
||||
override fun get(index: Int): Double = buffer.get(index)
|
||||
|
||||
/**
|
||||
* Strides for memory access
|
||||
*/
|
||||
private val strides: List<Int> by lazy {
|
||||
ArrayList<Int>(shape.size).apply {
|
||||
var current = 1
|
||||
add(1)
|
||||
shape.forEach {
|
||||
current *= it
|
||||
add(current)
|
||||
override fun set(index: Int, value: Double) {
|
||||
buffer.put(index, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun offset(index: List<Int>): Int {
|
||||
return index.mapIndexed { i, value ->
|
||||
if (value < 0 || value >= shape[i]) {
|
||||
throw RuntimeException("Index out of shape bounds: ($i,$value)")
|
||||
}
|
||||
value * strides[i]
|
||||
}.sum()
|
||||
}
|
||||
|
||||
val capacity: Int
|
||||
get() = strides[shape.size]
|
||||
|
||||
|
||||
override fun produce(initializer: (List<Int>) -> Double): NDArray<Double> {
|
||||
//TODO use sparse arrays for large capacities
|
||||
val buffer = DoubleBuffer.allocate(capacity)
|
||||
//FIXME there could be performance degradation due to iteration procedure. Replace by straight iteration
|
||||
NDArray.iterateIndexes(shape).forEach {
|
||||
buffer.put(offset(it), initializer(it))
|
||||
}
|
||||
return RealNDArray(this, buffer)
|
||||
}
|
||||
|
||||
class RealNDArray(override val context: RealNDField, val data: DoubleBuffer) : NDArray<Double> {
|
||||
|
||||
override fun get(vararg index: Int): Double {
|
||||
return data.get(context.offset(index.asList()))
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (javaClass != other?.javaClass) return false
|
||||
|
||||
other as RealNDArray
|
||||
|
||||
if (context.shape != other.context.shape) return false
|
||||
if (data != other.data) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var result = context.shape.hashCode()
|
||||
result = 31 * result + data.hashCode()
|
||||
return result
|
||||
}
|
||||
|
||||
//TODO generate fixed hash code for quick comparison?
|
||||
|
||||
|
||||
override val self: NDArray<Double> get() = this
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
actual fun realNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Double> {
|
||||
//TODO cache fields?
|
||||
//TODO create a cache for fields to save time generating strides?
|
||||
|
||||
return RealNDField(shape).produce { initializer(it) }
|
||||
}
|
Loading…
Reference in New Issue
Block a user