Specialized BufferNDField

This commit is contained in:
Alexander Nozik 2018-12-30 19:48:32 +03:00
parent cdef4122df
commit 334dadc6bf
17 changed files with 259 additions and 139 deletions

View File

@ -0,0 +1,36 @@
package scientifik.kmath.structures
import kotlin.system.measureTimeMillis
fun main(args: Array<String>) {
val n = 6000
val array = DoubleArray(n * n) { 1.0 }
val buffer = DoubleBuffer(array)
val strides = DefaultStrides(intArrayOf(n, n))
val structure = BufferNDStructure(strides, buffer)
measureTimeMillis {
var res: Double = 0.0
strides.indices().forEach { res = structure[it] }
} // warmup
val time1 = measureTimeMillis {
var res: Double = 0.0
strides.indices().forEach { res = structure[it] }
}
println("Structure reading finished in $time1 millis")
val time2 = measureTimeMillis {
var res: Double = 0.0
strides.indices().forEach { res = buffer[strides.offset(it)] }
}
println("Buffer reading finished in $time2 millis")
val time3 = measureTimeMillis {
var res: Double = 0.0
strides.indices().forEach { res = array[strides.offset(it)] }
}
println("Array reading finished in $time3 millis")
}

View File

@ -0,0 +1,39 @@
package scientifik.kmath.structures
import kotlin.system.measureTimeMillis
fun main(args: Array<String>) {
val n = 6000
val structure = NdStructure(intArrayOf(n, n), DoubleBufferFactory) { 1.0 }
structure.map { it + 1 } // warm-up
val time1 = measureTimeMillis {
val res = structure.map { it + 1 }
}
println("Structure mapping finished in $time1 millis")
val array = DoubleArray(n*n){1.0}
val time2 = measureTimeMillis {
val target = DoubleArray(n*n)
val res = array.forEachIndexed{index, value ->
target[index] = value + 1
}
}
println("Array mapping finished in $time2 millis")
val buffer = DoubleBuffer(DoubleArray(n*n){1.0})
val time3 = measureTimeMillis {
val target = DoubleBuffer(DoubleArray(n*n))
val res = array.forEachIndexed{index, value ->
target[index] = value + 1
}
}
println("Buffer mapping finished in $time3 millis")
}

View File

@ -32,6 +32,6 @@ fun ClosedFloatingPointRange<Double>.toSequence(step: Double): Sequence<Double>
* Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints] * Convert double range to array of evenly spaced doubles, where the size of array equals [numPoints]
*/ */
fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray { fun ClosedFloatingPointRange<Double>.toGrid(numPoints: Int): DoubleArray {
if (numPoints < 2) error("Can't create grid with less than two points") if (numPoints < 2) error("Can't generic grid with less than two points")
return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i } return DoubleArray(numPoints) { i -> start + (endInclusive - start) / (numPoints - 1) * i }
} }

View File

@ -0,0 +1,22 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.Field
class BufferNDField<T : Any, F : Field<T>>(override val shape: IntArray, override val field: F, val bufferFactory: BufferFactory<T>) : NDField<T, F> {
val strides = DefaultStrides(shape)
override inline fun produce(crossinline initializer: F.(IntArray) -> T): NDElement<T, F> {
return BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) })
}
}
class BufferNDElement<T : Any, F : Field<T>>(override val context: BufferNDField<T, F>, private val buffer: Buffer<T>) : NDElement<T, F> {
override val self: NDStructure<T> get() = this
override val shape: IntArray get() = context.shape
override fun get(index: IntArray): T = buffer[context.strides.offset(index)]
override fun elements(): Sequence<Pair<IntArray, T>> = context.strides.indices().map { it to get(it) }
}

View File

@ -9,14 +9,15 @@ import scientifik.kmath.operations.TrigonometricOperations
/** /**
* NDField that supports [ExtendedField] operations on its elements * NDField that supports [ExtendedField] operations on its elements
*/ */
class ExtendedNDField<T : Any, F : ExtendedField<T>>(shape: IntArray, field: F) : NDField<T, F>(shape, field), inline class ExtendedNDField<T : Any, F : ExtendedField<T>>(private val ndField: NDField<T, F>) : NDField<T, F>,
TrigonometricOperations<NDStructure<T>>, TrigonometricOperations<NDStructure<T>>,
PowerOperations<NDStructure<T>>, PowerOperations<NDStructure<T>>,
ExponentialOperations<NDStructure<T>> { ExponentialOperations<NDStructure<T>> {
override fun produceStructure(initializer: F.(IntArray) -> T): NDStructure<T> { override val shape: IntArray get() = ndField.shape
return NdStructure(shape, ::boxingBuffer) { field.initializer(it) } override val field: F get() = ndField.field
}
override fun produce(initializer: F.(IntArray) -> T): NDElement<T, F> = ndField.produce(initializer)
override fun power(arg: NDStructure<T>, pow: Double): NDElement<T, F> { override fun power(arg: NDStructure<T>, pow: Double): NDElement<T, F> {
return produce { with(field) { power(arg[it], pow) } } return produce { with(field) { power(arg[it], pow) } }

View File

@ -15,28 +15,25 @@ class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : Run
* @param field - operations field defined on individual array element * @param field - operations field defined on individual array element
* @param T the type of the element contained in NDArray * @param T the type of the element contained in NDArray
*/ */
abstract class NDField<T, F : Field<T>>(val shape: IntArray, val field: F) : Field<NDStructure<T>> { interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
abstract fun produceStructure(initializer: F.(IntArray) -> T): NDStructure<T> val shape: IntArray
val field: F
/** /**
* Create new instance of NDArray using field shape and given initializer * Create new instance of NDArray using field shape and given initializer
* The producer takes list of indices as argument and returns contained value * The producer takes list of indices as argument and returns contained value
*/ */
fun produce(initializer: F.(IntArray) -> T): NDElement<T, F> = NDStructureElement(this, produceStructure(initializer)) fun produce(initializer: F.(IntArray) -> T): NDElement<T, F>
override val zero: NDElement<T, F> by lazy { override val zero: NDElement<T, F> get() = produce { zero }
produce { zero }
}
override val one: NDElement<T, F> by lazy { override val one: NDElement<T, F> get() = produce { one }
produce { one }
}
/** /**
* Check the shape of given NDArray and throw exception if it does not coincide with shape of the field * Check the shape of given NDArray and throw exception if it does not coincide with shape of the field
*/ */
private fun checkShape(vararg elements: NDStructure<T>) { fun checkShape(vararg elements: NDStructure<T>) {
elements.forEach { elements.forEach {
if (!shape.contentEquals(it.shape)) { if (!shape.contentEquals(it.shape)) {
throw ShapeMismatchException(shape, it.shape) throw ShapeMismatchException(shape, it.shape)
@ -49,7 +46,7 @@ abstract class NDField<T, F : Field<T>>(val shape: IntArray, val field: F) : Fie
*/ */
override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> { override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
checkShape(a, b) checkShape(a, b)
return produce { with(field) { a[it] + b[it] } } return produce { field.run { a[it] + b[it] } }
} }
/** /**
@ -57,7 +54,7 @@ abstract class NDField<T, F : Field<T>>(val shape: IntArray, val field: F) : Fie
*/ */
override fun multiply(a: NDStructure<T>, k: Double): NDElement<T, F> { override fun multiply(a: NDStructure<T>, k: Double): NDElement<T, F> {
checkShape(a) checkShape(a)
return produce { with(field) { a[it] * k } } return produce { field.run { a[it] * k } }
} }
/** /**
@ -65,7 +62,7 @@ abstract class NDField<T, F : Field<T>>(val shape: IntArray, val field: F) : Fie
*/ */
override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> { override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
checkShape(a) checkShape(a)
return produce { with(field) { a[it] * b[it] } } return produce { field.run { a[it] * b[it] } }
} }
/** /**
@ -73,55 +70,72 @@ abstract class NDField<T, F : Field<T>>(val shape: IntArray, val field: F) : Fie
*/ */
override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> { override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
checkShape(a) checkShape(a)
return produce { with(field) { a[it] / b[it] } } return produce { field.run { a[it] / b[it] } }
} }
// /**
// * Reverse sum operation companion object {
// */ /**
// operator fun T.plus(arg: NDElement<T, F>): NDElement<T, F> = arg + this * Create a nd-field for [Double] values
// */
// /** fun real(shape: IntArray) = ExtendedNDField(BufferNDField(shape, DoubleField, DoubleBufferFactory))
// * Reverse minus operation
// */ /**
// operator fun T.minus(arg: NDElement<T, F>): NDElement<T, F> = arg.transformIndexed { _, value -> * Create a nd-field with boxing generic buffer
// with(arg.context.field) { */
// this@minus - value fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F) = BufferNDField(shape, field, ::boxingBuffer)
// }
// } /**
// * Create a most suitable implementation for nd-field using reified class
// /** */
// * Reverse product operation inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F) = BufferNDField(shape, field, ::inlineBuffer)
// */ }
// operator fun T.times(arg: NDElement<T, F>): NDElement<T, F> = arg * this
//
// /**
// * Reverse division operation
// */
// operator fun T.div(arg: NDElement<T, F>): NDElement<T, F> = arg.transformIndexed { _, value ->
// with(arg.context.field) {
// this@div / value
// }
// }
} }
interface NDElement<T, F : Field<T>> : FieldElement<NDStructure<T>, NDField<T, F>>, NDStructure<T> interface NDElement<T, F : Field<T>> : FieldElement<NDStructure<T>, NDField<T, F>>, NDStructure<T> {
companion object {
/**
* Create a platform-optimized NDArray of doubles
*/
fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }): NDElement<Double, DoubleField> {
return NDField.real(shape).produce(initializer)
}
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement<Double, DoubleField> {
return real(intArrayOf(dim)) { initializer(it[0]) }
}
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement<Double, DoubleField> {
return real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
}
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement<Double, DoubleField> {
return real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
}
// inline fun real(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDStructure<Double>): NDElement<Double, DoubleField> {
// val field = NDField.real(shape)
// return GenericNDElement(field, field.run(block))
// }
/**
* Simple boxing NDArray
*/
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): NDElement<T, F> {
return NDField.generic(shape,field).produce(initializer)
}
inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F, crossinline initializer: F.(IntArray) -> T): NDElement<T, F> {
return NDField.inline(shape,field).produce(initializer)
}
}
}
inline fun <T, F : Field<T>> NDElement<T, F>.transformIndexed(crossinline action: F.(IntArray, T) -> T): NDElement<T, F> = context.produce { action(it, get(*it)) } inline fun <T, F : Field<T>> NDElement<T, F>.transformIndexed(crossinline action: F.(IntArray, T) -> T): NDElement<T, F> = context.produce { action(it, get(*it)) }
inline fun <T, F : Field<T>> NDElement<T, F>.transform(crossinline action: F.(T) -> T): NDElement<T, F> = context.produce { action(get(*it)) } inline fun <T, F : Field<T>> NDElement<T, F>.transform(crossinline action: F.(T) -> T): NDElement<T, F> = context.produce { action(get(*it)) }
/**
* Read-only [NDStructure] coupled to the context.
*/
class NDStructureElement<T, F : Field<T>>(override val context: NDField<T, F>, private val structure: NDStructure<T>) : NDElement<T, F>, NDStructure<T> by structure {
//TODO ensure structure is immutable
override val self: NDElement<T, F> get() = this
}
/** /**
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy
*/ */
@ -133,18 +147,14 @@ operator fun <T, F : Field<T>> Function1<T, T>.invoke(ndElement: NDElement<T, F>
* Summation operation for [NDElement] and single element * Summation operation for [NDElement] and single element
*/ */
operator fun <T, F : Field<T>> NDElement<T, F>.plus(arg: T): NDElement<T, F> = transform { value -> operator fun <T, F : Field<T>> NDElement<T, F>.plus(arg: T): NDElement<T, F> = transform { value ->
with(context.field) { context.field.run { arg + value }
arg + value
}
} }
/** /**
* Subtraction operation between [NDElement] and single element * Subtraction operation between [NDElement] and single element
*/ */
operator fun <T, F : Field<T>> NDElement<T, F>.minus(arg: T): NDElement<T, F> = transform { value -> operator fun <T, F : Field<T>> NDElement<T, F>.minus(arg: T): NDElement<T, F> = transform { value ->
with(context.field) { context.field.run { arg - value }
arg - value
}
} }
/* prod and div */ /* prod and div */
@ -153,55 +163,53 @@ operator fun <T, F : Field<T>> NDElement<T, F>.minus(arg: T): NDElement<T, F> =
* Product operation for [NDElement] and single element * Product operation for [NDElement] and single element
*/ */
operator fun <T, F : Field<T>> NDElement<T, F>.times(arg: T): NDElement<T, F> = transform { value -> operator fun <T, F : Field<T>> NDElement<T, F>.times(arg: T): NDElement<T, F> = transform { value ->
with(context.field) { context.field.run { arg * value }
arg * value
}
} }
/** /**
* Division operation between [NDElement] and single element * Division operation between [NDElement] and single element
*/ */
operator fun <T, F : Field<T>> NDElement<T, F>.div(arg: T): NDElement<T, F> = transform { value -> operator fun <T, F : Field<T>> NDElement<T, F>.div(arg: T): NDElement<T, F> = transform { value ->
with(context.field) { context.field.run { arg / value }
arg / value
}
} }
class GenericNDField<T : Any, F : Field<T>>(shape: IntArray, field: F) : NDField<T, F>(shape, field) {
override fun produceStructure(initializer: F.(IntArray) -> T): NDStructure<T> = NdStructure(shape, ::boxingBuffer) { field.initializer(it) } // /**
// * Reverse sum operation
// */
// operator fun T.plus(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
// field.run { this@plus + arg[index] }
// }
//
// /**
// * Reverse minus operation
// */
// operator fun T.minus(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
// field.run { this@minus - arg[index] }
// }
//
// /**
// * Reverse product operation
// */
// operator fun T.times(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
// field.run { this@times * arg[index] }
// }
//
// /**
// * Reverse division operation
// */
// operator fun T.div(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
// field.run { this@div / arg[index] }
// }
class GenericNDField<T : Any, F : Field<T>>(override val shape: IntArray, override val field: F) : NDField<T, F> {
override fun produce(initializer: F.(IntArray) -> T): NDElement<T, F> = GenericNDElement(this, produceStructure(initializer))
private inline fun produceStructure(crossinline initializer: F.(IntArray) -> T): NDStructure<T> = NdStructure(shape, ::boxingBuffer) { field.initializer(it) }
} }
//typealias NDFieldFactory<T> = (IntArray)->NDField<T> /**
* Read-only [NDStructure] coupled to the context.
object NDElements { */
/** class GenericNDElement<T, F : Field<T>>(override val context: NDField<T, F>, private val structure: NDStructure<T>) : NDElement<T, F>, NDStructure<T> by structure {
* Create a platform-optimized NDArray of doubles override val self: NDElement<T, F> get() = this
*/
fun realNDElement(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }): NDElement<Double, DoubleField> {
return ExtendedNDField(shape, DoubleField).produce(initializer)
}
fun real1DElement(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement<Double, DoubleField> {
return realNDElement(intArrayOf(dim)) { initializer(it[0]) }
}
fun real2DElement(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement<Double, DoubleField> {
return realNDElement(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
}
fun real3DElement(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement<Double, DoubleField> {
return realNDElement(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
}
inline fun real(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDStructure<Double>): NDElement<Double, DoubleField> {
val field = ExtendedNDField(shape, DoubleField)
return NDStructureElement(field, field.run(block))
}
/**
* Simple boxing NDArray
*/
fun <T : Any, F : Field<T>> create(field: F, shape: IntArray, initializer: (IntArray) -> T): NDElement<T, F> {
return GenericNDField(shape, field).produce { initializer(it) }
}
} }

View File

@ -19,7 +19,7 @@ interface MutableNDStructure<T> : NDStructure<T> {
operator fun set(index: IntArray, value: T) operator fun set(index: IntArray, value: T)
} }
fun <T> MutableNDStructure<T>.transformInPlace(action: (IntArray, T) -> T) { fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T) {
elements().forEach { (index, oldValue) -> elements().forEach { (index, oldValue) ->
this[index] = action(index, oldValue) this[index] = action(index, oldValue)
} }
@ -113,8 +113,8 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
} }
abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> { abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> {
protected abstract val buffer: B abstract val buffer: B
protected abstract val strides: Strides abstract val strides: Strides
override fun get(index: IntArray): T = buffer[strides.offset(index)] override fun get(index: IntArray): T = buffer[strides.offset(index)]
@ -153,7 +153,18 @@ class BufferNDStructure<T>(
result = 31 * result + buffer.hashCode() result = 31 * result + buffer.hashCode()
return result return result
} }
}
/**
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure]
*/
inline fun <T, reified R : Any> NDStructure<T>.map(factory: BufferFactory<R> = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure<R> {
return if (this is BufferNDStructure<T>) {
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
} else {
val strides = DefaultStrides(shape)
BufferNDStructure(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
}
} }
/** /**

View File

@ -1,16 +0,0 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.DoubleField
import scientifik.kmath.structures.NDElements.create
import kotlin.test.Test
import kotlin.test.assertEquals
class GenericNDFieldTest{
@Test
fun testStrides(){
val ndArray = create(DoubleField, intArrayOf(10,10)){(it[0]+it[1]).toDouble()}
assertEquals(ndArray[5,5], 10.0)
}
}

View File

@ -0,0 +1,13 @@
package scientifik.kmath.structures
import kotlin.test.Test
import kotlin.test.assertEquals
class NDFieldTest {
@Test
fun testStrides() {
val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() }
assertEquals(ndArray[5, 5], 10.0)
}
}

View File

@ -1,16 +1,15 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import scientifik.kmath.operations.Norm import scientifik.kmath.operations.Norm
import scientifik.kmath.structures.NDElements.real import scientifik.kmath.structures.NDElement.Companion.real2D
import scientifik.kmath.structures.NDElements.real2DElement
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow import kotlin.math.pow
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class NumberNDFieldTest { class NumberNDFieldTest {
val array1 = real2DElement(3, 3) { i, j -> (i + j).toDouble() } val array1 = real2D(3, 3) { i, j -> (i + j).toDouble() }
val array2 = real2DElement(3, 3) { i, j -> (i - j).toDouble() } val array2 = real2D(3, 3) { i, j -> (i - j).toDouble() }
@Test @Test
fun testSum() { fun testSum() {
@ -27,7 +26,7 @@ class NumberNDFieldTest {
@Test @Test
fun testGeneration() { fun testGeneration() {
val array = real2DElement(3, 3) { i, j -> (i * 10 + j).toDouble() } val array = real2D(3, 3) { i, j -> (i * 10 + j).toDouble() }
for (i in 0..2) { for (i in 0..2) {
for (j in 0..2) { for (j in 0..2) {
@ -52,7 +51,7 @@ class NumberNDFieldTest {
} }
@Test @Test
fun combineTest(){ fun combineTest() {
val division = array1.combine(array2, Double::div) val division = array1.combine(array2, Double::div)
} }
@ -64,7 +63,7 @@ class NumberNDFieldTest {
@Test @Test
fun testInternalContext() { fun testInternalContext() {
real(array1.shape) { NDField.real(array1.shape).run {
with(L2Norm) { with(L2Norm) {
1 + norm(array1) + exp(array2) 1 + norm(array1) + exp(array2)
} }

View File

@ -0,0 +1,7 @@
package scientifik.kmath.structures
//
//class LazyField<T>: Field<T> {
//}
//
//class LazyValue<T: Any>():

View File

@ -8,7 +8,7 @@ class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: Corouti
override fun produceStructure(initializer: F.(IntArray) -> T): NDStructure<T> = LazyNDStructure(this) { initializer(field, it) } override fun produceStructure(initializer: F.(IntArray) -> T): NDStructure<T> = LazyNDStructure(this) { initializer(field, it) }
override fun add(a: NDElement<T, F>, b: NDElement<T, F>): NDElement<T, F> { override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
return LazyNDStructure(this) { index -> return LazyNDStructure(this) { index ->
val aDeferred = a.deferred(index) val aDeferred = a.deferred(index)
val bDeferred = b.deferred(index) val bDeferred = b.deferred(index)
@ -16,11 +16,11 @@ class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: Corouti
} }
} }
override fun multiply(a: NDElement<T, F>, k: Double): NDElement<T, F> { override fun multiply(a: NDStructure<T>, k: Double): NDElement<T, F> {
return LazyNDStructure(this) { index -> a.await(index) * k } return LazyNDStructure(this) { index -> a.await(index) * k }
} }
override fun multiply(a: NDElement<T, F>, b: NDElement<T, F>): NDElement<T, F> { override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
return LazyNDStructure(this) { index -> return LazyNDStructure(this) { index ->
val aDeferred = a.deferred(index) val aDeferred = a.deferred(index)
val bDeferred = b.deferred(index) val bDeferred = b.deferred(index)
@ -28,7 +28,7 @@ class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: Corouti
} }
} }
override fun divide(a: NDElement<T, F>, b: NDElement<T, F>): NDElement<T, F> { override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
return LazyNDStructure(this) { index -> return LazyNDStructure(this) { index ->
val aDeferred = a.deferred(index) val aDeferred = a.deferred(index)
val bDeferred = b.deferred(index) val bDeferred = b.deferred(index)
@ -57,15 +57,15 @@ class LazyNDStructure<T, F : Field<T>>(override val context: LazyNDField<T, F>,
} }
} }
fun <T> NDElement<T, *>.deferred(index: IntArray) = if (this is LazyNDStructure<T, *>) this.deferred(index) else CompletableDeferred(get(index)) fun <T> NDStructure<T>.deferred(index: IntArray) = if (this is LazyNDStructure<T, *>) this.deferred(index) else CompletableDeferred(get(index))
suspend fun <T> NDElement<T, *>.await(index: IntArray) = if (this is LazyNDStructure<T, *>) this.await(index) else get(index) suspend fun <T> NDStructure<T>.await(index: IntArray) = if (this is LazyNDStructure<T, *>) this.await(index) else get(index)
fun <T, F : Field<T>> NDElement<T, F>.lazy(scope: CoroutineScope = GlobalScope): LazyNDStructure<T, F> { fun <T, F : Field<T>> NDElement<T, F>.lazy(scope: CoroutineScope = GlobalScope): LazyNDStructure<T, F> {
return if (this is LazyNDStructure<T, F>) { return if (this is LazyNDStructure<T, F>) {
this this
} else { } else {
val context = LazyNDField(context.shape, context.field) val context = LazyNDField(context.shape, context.field, scope)
LazyNDStructure(context) { get(it) } LazyNDStructure(context) { get(it) }
} }
} }

View File

@ -9,7 +9,7 @@ class LazyNDFieldTest {
@Test @Test
fun testLazyStructure() { fun testLazyStructure() {
var counter = 0 var counter = 0
val regularStructure = NDElements.create(IntField, intArrayOf(2, 2, 2)) { it[0] + it[1] - it[2] } val regularStructure = NDElements.generic(intArrayOf(2, 2, 2), IntField) { it[0] + it[1] - it[2] }
val result = (regularStructure.lazy() + 2).transform { val result = (regularStructure.lazy() + 2).transform {
counter++ counter++
it * it it * it

View File

@ -12,5 +12,5 @@ include(
":kmath-core", ":kmath-core",
":kmath-io", ":kmath-io",
":kmath-coroutines", ":kmath-coroutines",
":kmath-jmh" ":benchmarks"
) )