Fixes to tests and lazy structures

This commit is contained in:
Alexander Nozik 2019-01-04 20:23:32 +03:00
parent c0a43c1bd1
commit 600d8a64b8
17 changed files with 185 additions and 138 deletions

View File

@ -5,6 +5,7 @@ plugins {
}
dependencies {
compile project(':kmath-core')
compile project(":kmath-core")
compile project(":kmath-coroutines")
//jmh project(':kmath-core')
}

View File

@ -1,15 +1,16 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.RealField
import kotlin.system.measureTimeMillis
fun main(args: Array<String>) {
val dim = 1000
val n = 10000
val n = 100
val bufferedField = NDField.buffered(intArrayOf(dim, dim), DoubleField)
val bufferedField = NDField.buffered(intArrayOf(dim, dim), RealField)
val specializedField = NDField.real(intArrayOf(dim, dim))
val genericField = NDField.generic(intArrayOf(dim, dim), DoubleField)
val genericField = NDField.generic(intArrayOf(dim, dim), RealField)
val lazyNDField = NDField.lazy(intArrayOf(dim, dim), RealField)
// val action: NDField<Double, DoubleField, NDStructure<Double>>.() -> Unit = {
// var res = one
@ -55,6 +56,23 @@ fun main(args: Array<String>) {
println("Specialized addition completed in $specializedTime millis")
val lazyTime = measureTimeMillis {
val tr : RealField.(Double)->Double = {arg->
var r = arg
repeat(n) {
r += 1.0
}
r
}
lazyNDField.run {
val res = one.map(tr)
res.elements().sumByDouble { it.second }
}
}
println("Lazy addition completed in $lazyTime millis")
val genericTime = measureTimeMillis {
//genericField.run(action)
genericField.run {

View File

@ -42,7 +42,7 @@ class PhantomBin<T : Comparable<T>>(val template: BinTemplate<T>, override val v
/**
* Immutable histogram with explicit structure for content and additional external bin description.
* Bin search is slow, but full histogram algebra is supported.
* @param bins map a template into structure index
* @param bins transform a template into structure index
*/
class PhantomHistogram<T : Comparable<T>>(
val bins: Map<BinTemplate<T>, IntArray>,

View File

@ -1,7 +1,7 @@
package scientifik.kmath.linear
import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.RealField
import scientifik.kmath.structures.*
import kotlin.math.absoluteValue
@ -184,7 +184,7 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
}
class RealLUDecomposition(matrix: RealMatrix, private val singularityThreshold: Double = DEFAULT_TOO_SMALL) :
LUDecomposition<Double, DoubleField>(matrix) {
LUDecomposition<Double, RealField>(matrix) {
override fun isSingular(value: Double): Boolean {
return value.absoluteValue < singularityThreshold
}
@ -197,9 +197,9 @@ class RealLUDecomposition(matrix: RealMatrix, private val singularityThreshold:
/** Specialized solver. */
object RealLUSolver : LinearSolver<Double, DoubleField> {
object RealLUSolver : LinearSolver<Double, RealField> {
fun decompose(mat: Matrix<Double, DoubleField>, threshold: Double = 1e-11): RealLUDecomposition =
fun decompose(mat: Matrix<Double, RealField>, threshold: Double = 1e-11): RealLUDecomposition =
RealLUDecomposition(mat, threshold)
override fun solve(a: RealMatrix, b: RealMatrix): RealMatrix {

View File

@ -1,8 +1,8 @@
package scientifik.kmath.linear
import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Norm
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.Ring
import scientifik.kmath.structures.asSequence
import scientifik.kmath.structures.boxingBuffer
@ -61,5 +61,5 @@ object VectorL2Norm : Norm<Vector<out Number, *>, Double> {
}
}
typealias RealVector = Vector<Double, DoubleField>
typealias RealMatrix = Matrix<Double, DoubleField>
typealias RealVector = Vector<Double, RealField>
typealias RealMatrix = Matrix<Double, RealField>

View File

@ -1,6 +1,6 @@
package scientifik.kmath.linear
import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space
import scientifik.kmath.operations.SpaceElement
@ -42,8 +42,8 @@ interface MatrixSpace<T : Any, R : Ring<T>> : Space<Matrix<T, R>> {
/**
* Non-boxing double matrix
*/
fun real(rows: Int, columns: Int): MatrixSpace<Double, DoubleField> =
StructureMatrixSpace(rows, columns, DoubleField, DoubleBufferFactory)
fun real(rows: Int, columns: Int): MatrixSpace<Double, RealField> =
StructureMatrixSpace(rows, columns, RealField, DoubleBufferFactory)
/**
* A structured matrix with custom buffer

View File

@ -1,6 +1,6 @@
package scientifik.kmath.linear
import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.RealField
import scientifik.kmath.operations.Space
import scientifik.kmath.operations.SpaceElement
import scientifik.kmath.structures.*
@ -34,13 +34,13 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
companion object {
private val realSpaceCache = HashMap<Int, BufferVectorSpace<Double, DoubleField>>()
private val realSpaceCache = HashMap<Int, BufferVectorSpace<Double, RealField>>()
/**
* Non-boxing double vector space
*/
fun real(size: Int): BufferVectorSpace<Double, DoubleField> {
return realSpaceCache.getOrPut(size) { BufferVectorSpace(size, DoubleField, DoubleBufferFactory) }
fun real(size: Int): BufferVectorSpace<Double, RealField> {
return realSpaceCache.getOrPut(size) { BufferVectorSpace(size, RealField, DoubleBufferFactory) }
}
/**
@ -79,10 +79,10 @@ interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, Vector<T, S>, V
fun <T : Any, S : Space<T>> generic(size: Int, field: S, initializer: (Int) -> T): Vector<T, S> =
VectorSpace.buffered(size, field).produceElement(initializer)
fun real(size: Int, initializer: (Int) -> Double): Vector<Double, DoubleField> =
fun real(size: Int, initializer: (Int) -> Double): Vector<Double, RealField> =
VectorSpace.real(size).produceElement(initializer)
fun ofReal(vararg elements: Double): Vector<Double, DoubleField> =
fun ofReal(vararg elements: Double): Vector<Double, RealField> =
VectorSpace.real(elements.size).produceElement { elements[it] }
}

View File

@ -17,12 +17,12 @@ interface ExtendedField<T : Any> :
*
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
*/
inline class Real(val value: Double) : FieldElement<Double, Real, DoubleField> {
inline class Real(val value: Double) : FieldElement<Double, Real, RealField> {
override fun unwrap(): Double = value
override fun Double.wrap(): Real = Real(value)
override val context get() = DoubleField
override val context get() = RealField
companion object {
@ -32,7 +32,7 @@ inline class Real(val value: Double) : FieldElement<Double, Real, DoubleField> {
/**
* A field for double without boxing. Does not produce appropriate field element
*/
object DoubleField : AbstractField<Double>(),ExtendedField<Double>, Norm<Double, Double> {
object RealField : AbstractField<Double>(),ExtendedField<Double>, Norm<Double, Double> {
override val zero: Double = 0.0
override fun add(a: Double, b: Double): Double = a + b
override fun multiply(a: Double, @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") b: Double): Double = a * b

View File

@ -5,18 +5,19 @@ import scientifik.kmath.operations.FieldElement
abstract class StridedNDField<T, F : Field<T>>(shape: IntArray, elementField: F) :
AbstractNDField<T, F, NDBuffer<T>>(shape, elementField) {
abstract val bufferFactory: BufferFactory<T>
val strides = DefaultStrides(shape)
abstract fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T>
}
class BufferNDField<T, F : Field<T>>(
shape: IntArray,
elementField: F,
override val bufferFactory: BufferFactory<T>
) :
StridedNDField<T, F>(shape, elementField) {
val bufferFactory: BufferFactory<T>
) : StridedNDField<T, F>(shape, elementField) {
override fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = bufferFactory(size, initializer)
override fun check(vararg elements: NDBuffer<T>) {
if (!elements.all { it.strides == this.strides }) error("Element strides are not the same as context strides")
@ -32,22 +33,25 @@ class BufferNDField<T, F : Field<T>>(
bufferFactory(strides.linearSize) { offset -> elementField.initializer(strides.index(offset)) })
@Suppress("OVERRIDE_BY_INLINE")
override inline fun NDBuffer<T>.map(crossinline transform: F.(T) -> T): BufferNDElement<T, F> {
check(this)
override inline fun map(arg: NDBuffer<T>, crossinline transform: F.(T) -> T): BufferNDElement<T, F> {
check(arg)
return BufferNDElement(
this@BufferNDField,
bufferFactory(strides.linearSize) { offset -> elementField.transform(buffer[offset]) })
this,
bufferFactory(arg.strides.linearSize) { offset -> elementField.transform(arg.buffer[offset]) })
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun NDBuffer<T>.mapIndexed(crossinline transform: F.(index: IntArray, T) -> T): BufferNDElement<T, F> {
check(this)
override inline fun mapIndexed(
arg: NDBuffer<T>,
crossinline transform: F.(index: IntArray, T) -> T
): BufferNDElement<T, F> {
check(arg)
return BufferNDElement(
this@BufferNDField,
bufferFactory(strides.linearSize) { offset ->
this,
bufferFactory(arg.strides.linearSize) { offset ->
elementField.transform(
strides.index(offset),
buffer[offset]
arg.strides.index(offset),
arg.buffer[offset]
)
})
}
@ -105,11 +109,11 @@ class BufferNDElement<T, F : Field<T>>(override val context: StridedNDField<T, F
override fun elements(): Sequence<Pair<IntArray, T>> =
strides.indices().map { it to get(it) }
override fun map(action: F.(T) -> T): BufferNDElement<T, F> =
context.run { map(action) }
override fun map(action: F.(T) -> T) =
context.run { map(this@BufferNDElement, action) }.wrap()
override fun mapIndexed(transform: F.(index: IntArray, T) -> T): BufferNDElement<T, F> =
context.run { mapIndexed(transform) }
override fun mapIndexed(transform: F.(index: IntArray, T) -> T) =
context.run { mapIndexed(this@BufferNDElement, transform) }.wrap()
}
/**

View File

@ -6,7 +6,7 @@ import scientifik.kmath.operations.PowerOperations
import scientifik.kmath.operations.TrigonometricOperations
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<out T>> :
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<T>> :
NDField<T, F, N>,
TrigonometricOperations<N>,
PowerOperations<N>,
@ -16,7 +16,7 @@ interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<out T>>
/**
* NDField that supports [ExtendedField] operations on its elements
*/
class ExtendedNDFieldWrapper<T : Any, F : ExtendedField<T>, N : NDStructure<out T>>(private val ndField: NDField<T, F, N>) :
class ExtendedNDFieldWrapper<T : Any, F : ExtendedField<T>, N : NDStructure<T>>(private val ndField: NDField<T, F, N>) :
ExtendedNDField<T, F, N>, NDField<T, F, N> by ndField {
override val shape: IntArray get() = ndField.shape

View File

@ -1,8 +1,8 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.FieldElement
import scientifik.kmath.operations.RealField
interface NDElement<T, F : Field<T>> : NDStructure<T> {
@ -17,7 +17,7 @@ object NDElements {
/**
* Create a optimized NDArray of doubles
*/
fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }) =
fun real(shape: IntArray, initializer: RealField.(IntArray) -> Double = { 0.0 }) =
NDField.real(shape).produce(initializer)

View File

@ -23,9 +23,9 @@ interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N> {
fun produce(initializer: F.(IntArray) -> T): N
fun N.map(transform: F.(T) -> T): N
fun map(arg: N, transform: F.(T) -> T): N
fun N.mapIndexed(transform: F.(index: IntArray, T) -> T): N
fun mapIndexed(arg: N, transform: F.(index: IntArray, T) -> T): N
fun combine(a: N, b: N, transform: F.(T, T) -> T): N
@ -87,11 +87,11 @@ abstract class AbstractNDField<T, F : Field<T>, N : NDStructure<T>>(
override val one: N by lazy { produce { one } }
final override operator fun Function1<T, T>.invoke(structure: N) = structure.map { value -> this@invoke(value) }
final override operator fun N.plus(arg: T) = this.map { value -> elementField.run { arg + value } }
final override operator fun N.minus(arg: T) = this.map { value -> elementField.run { arg - value } }
final override operator fun N.times(arg: T) = this.map { value -> elementField.run { arg * value } }
final override operator fun N.div(arg: T) = this.map { value -> elementField.run { arg / value } }
final override operator fun Function1<T, T>.invoke(structure: N) = map(structure) { value -> this@invoke(value) }
final override operator fun N.plus(arg: T) = map(this) { value -> elementField.run { arg + value } }
final override operator fun N.minus(arg: T) = map(this) { value -> elementField.run { arg - value } }
final override operator fun N.times(arg: T) = map(this) { value -> elementField.run { arg * value } }
final override operator fun N.div(arg: T) = map(this) { value -> elementField.run { arg / value } }
final override operator fun T.plus(arg: N) = arg + this
final override operator fun T.minus(arg: N) = arg - this
@ -109,7 +109,7 @@ abstract class AbstractNDField<T, F : Field<T>, N : NDStructure<T>>(
* Multiply all elements by cinstant
*/
override fun multiply(a: N, k: Double): N =
a.map { it * k }
map(a) { it * k }
/**
@ -140,17 +140,16 @@ class GenericNDField<T : Any, F : Field<T>>(
shape: IntArray,
elementField: F,
val bufferFactory: BufferFactory<T> = ::boxingBuffer
) :
AbstractNDField<T, F, NDStructure<T>>(shape, elementField) {
) : AbstractNDField<T, F, NDStructure<T>>(shape, elementField) {
override fun produce(initializer: F.(IntArray) -> T): NDStructure<T> =
ndStructure(shape, bufferFactory) { elementField.initializer(it) }
override fun NDStructure<T>.map(transform: F.(T) -> T): NDStructure<T> =
produce { index -> transform(get(index)) }
override fun map(arg: NDStructure<T>, transform: F.(T) -> T): NDStructure<T> =
produce { index -> transform(arg.get(index)) }
override fun NDStructure<T>.mapIndexed(transform: F.(index: IntArray, T) -> T): NDStructure<T> =
produce { index -> transform(index, get(index)) }
override fun mapIndexed(arg: NDStructure<T>, transform: F.(index: IntArray, T) -> T): NDStructure<T> =
produce { index -> transform(index, arg.get(index)) }
override fun combine(a: NDStructure<T>, b: NDStructure<T>, transform: F.(T, T) -> T): NDStructure<T> =
produce { index -> transform(a[index], b[index]) }

View File

@ -1,41 +1,47 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.RealField
typealias RealNDElement = BufferNDElement<Double, DoubleField>
typealias RealNDElement = BufferNDElement<Double, RealField>
class RealNDField(shape: IntArray) :
StridedNDField<Double, DoubleField>(shape, DoubleField),
ExtendedNDField<Double, DoubleField, NDBuffer<Double>> {
StridedNDField<Double, RealField>(shape, RealField),
ExtendedNDField<Double, RealField, NDBuffer<Double>> {
override val bufferFactory: BufferFactory<Double>
get() = DoubleBufferFactory
override fun buildBuffer(size: Int, initializer: (Int) -> Double): Buffer<Double> =
DoubleBuffer(DoubleArray(size, initializer))
/**
* Inline map an NDStructure to
* Inline transform an NDStructure to
*/
@Suppress("OVERRIDE_BY_INLINE")
override inline fun NDBuffer<Double>.map(crossinline transform: DoubleField.(Double) -> Double): RealNDElement {
check(this)
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.transform(buffer[offset]) }
return BufferNDElement(this@RealNDField, DoubleBuffer(array))
override inline fun map(
arg: NDBuffer<Double>,
crossinline transform: RealField.(Double) -> Double
): RealNDElement {
check(arg)
val array = DoubleArray(arg.strides.linearSize) { offset -> RealField.transform(arg.buffer[offset]) }
return BufferNDElement(this, DoubleBuffer(array))
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun produce(crossinline initializer: DoubleField.(IntArray) -> Double): RealNDElement {
override inline fun produce(crossinline initializer: RealField.(IntArray) -> Double): RealNDElement {
val array = DoubleArray(strides.linearSize) { offset -> elementField.initializer(strides.index(offset)) }
return BufferNDElement(this, DoubleBuffer(array))
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun NDBuffer<Double>.mapIndexed(crossinline transform: DoubleField.(index: IntArray, Double) -> Double): BufferNDElement<Double, DoubleField> {
check(this)
override inline fun mapIndexed(
arg: NDBuffer<Double>,
crossinline transform: RealField.(index: IntArray, Double) -> Double
): BufferNDElement<Double, RealField> {
check(arg)
return BufferNDElement(
this@RealNDField,
bufferFactory(strides.linearSize) { offset ->
this,
buildBuffer(arg.strides.linearSize) { offset ->
elementField.transform(
strides.index(offset),
buffer[offset]
arg.strides.index(offset),
arg.buffer[offset]
)
})
}
@ -44,23 +50,23 @@ class RealNDField(shape: IntArray) :
override inline fun combine(
a: NDBuffer<Double>,
b: NDBuffer<Double>,
crossinline transform: DoubleField.(Double, Double) -> Double
): BufferNDElement<Double, DoubleField> {
crossinline transform: RealField.(Double, Double) -> Double
): BufferNDElement<Double, RealField> {
check(a, b)
return BufferNDElement(
this,
bufferFactory(strides.linearSize) { offset -> elementField.transform(a.buffer[offset], b.buffer[offset]) })
buildBuffer(strides.linearSize) { offset -> elementField.transform(a.buffer[offset], b.buffer[offset]) })
}
override fun power(arg: NDBuffer<Double>, pow: Double) = arg.map { power(it, pow) }
override fun power(arg: NDBuffer<Double>, pow: Double) = map(arg) { power(it, pow) }
override fun exp(arg: NDBuffer<Double>) = arg.map { exp(it) }
override fun exp(arg: NDBuffer<Double>) = map(arg) { exp(it) }
override fun ln(arg: NDBuffer<Double>) = arg.map { ln(it) }
override fun ln(arg: NDBuffer<Double>) = map(arg) { ln(it) }
override fun sin(arg: NDBuffer<Double>) = arg.map { sin(it) }
override fun sin(arg: NDBuffer<Double>) = map(arg) { sin(it) }
override fun cos(arg: NDBuffer<Double>) = arg.map { cos(it) }
override fun cos(arg: NDBuffer<Double>) = map(arg) { cos(it) }
//
// override fun NDBuffer<Double>.times(k: Number) = mapInline { value -> value * k.toDouble() }
//
@ -75,7 +81,7 @@ class RealNDField(shape: IntArray) :
/**
* Fast element production using function inlining
*/
inline fun StridedNDField<Double, DoubleField>.produceInline(crossinline initializer: DoubleField.(Int) -> Double): RealNDElement {
inline fun StridedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): RealNDElement {
val array = DoubleArray(strides.linearSize) { offset -> elementField.initializer(offset) }
return BufferNDElement(this, DoubleBuffer(array))
}

View File

@ -2,14 +2,14 @@ package scientifik.kmath.expressions
import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
class ExpressionFieldTest {
@Test
fun testExpression() {
val context = ExpressionField(DoubleField)
val context = ExpressionField(RealField)
val expression = with(context) {
val x = variable("x", 2.0)
x * x + 2 * x + 1.0
@ -36,7 +36,7 @@ class ExpressionFieldTest {
return x * x + 2 * x + one
}
val expression = ExpressionField(DoubleField).expression()
val expression = ExpressionField(RealField).expression()
assertEquals(expression("x" to 1.0), 4.0)
}
@ -47,7 +47,7 @@ class ExpressionFieldTest {
x * x + 2 * x + 1.0
}
val expression = ExpressionField(DoubleField).expressionBuilder()
val expression = ExpressionField(RealField).expressionBuilder()
assertEquals(expression("x" to 1.0), 4.0)
}
}

View File

@ -6,7 +6,7 @@ import kotlin.test.assertEquals
class RealFieldTest {
@Test
fun testSqrt() {
val sqrt = with(DoubleField) {
val sqrt = with(RealField) {
sqrt(25 * one)
}
assertEquals(5.0, sqrt)

View File

@ -2,54 +2,80 @@ package scientifik.kmath.structures
import kotlinx.coroutines.*
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.FieldElement
class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: CoroutineScope = GlobalScope) :
BufferNDField<T, F>(shape, field, ::boxingBuffer) {
AbstractNDField<T, F, NDStructure<T>>(shape, field) {
override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElements<T, F> {
return LazyNDStructure(this) { index ->
val aDeferred = a.deferred(index)
val bDeferred = b.deferred(index)
aDeferred.await() + bDeferred.await()
override val zero by lazy { produce { zero } }
override val one by lazy { produce { one } }
override fun produce(initializer: F.(IntArray) -> T) =
LazyNDStructure(this) { elementField.initializer(it) }
override fun mapIndexed(
arg: NDStructure<T>,
transform: F.(index: IntArray, T) -> T
): LazyNDStructure<T, F> {
check(arg)
return if (arg is LazyNDStructure<T, *>) {
LazyNDStructure(this) { index ->
this.elementField.transform(index, arg.function(index))
}
} else {
LazyNDStructure(this) { elementField.transform(it, arg.await(it)) }
}
}
override fun multiply(a: NDStructure<T>, k: Double): NDElements<T, F> {
return LazyNDStructure(this) { index -> a.await(index) * k }
}
override fun map(arg: NDStructure<T>, transform: F.(T) -> T) =
mapIndexed(arg) { _, t -> transform(t) }
override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDElements<T, F> {
return LazyNDStructure(this) { index ->
val aDeferred = a.deferred(index)
val bDeferred = b.deferred(index)
aDeferred.await() * bDeferred.await()
override fun combine(a: NDStructure<T>, b: NDStructure<T>, transform: F.(T, T) -> T): LazyNDStructure<T, F> {
check(a, b)
return if (a is LazyNDStructure<T, *> && b is LazyNDStructure<T, *>) {
LazyNDStructure(this@LazyNDField) { index ->
elementField.transform(
a.function(index),
b.function(index)
)
}
} else {
LazyNDStructure(this@LazyNDField) { elementField.transform(a.await(it), b.await(it)) }
}
}
override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDElements<T, F> {
return LazyNDStructure(this) { index ->
val aDeferred = a.deferred(index)
val bDeferred = b.deferred(index)
aDeferred.await() / bDeferred.await()
fun NDStructure<T>.lazy(): LazyNDStructure<T, F> {
check(this)
return if (this is LazyNDStructure<T, *>) {
LazyNDStructure(this@LazyNDField, function)
} else {
LazyNDStructure(this@LazyNDField) { get(it) }
}
}
}
class LazyNDStructure<T, F : Field<T>>(
override val context: LazyNDField<T, F>,
val function: suspend F.(IntArray) -> T
) : NDElements<T, F>, NDStructure<T> {
override val self: NDElements<T, F> get() = this
val function: suspend (IntArray) -> T
) : FieldElement<NDStructure<T>, LazyNDStructure<T, F>, LazyNDField<T, F>>, NDElement<T, F> {
override fun unwrap(): NDStructure<T> = this
override fun NDStructure<T>.wrap(): LazyNDStructure<T, F> = LazyNDStructure(context) { await(it) }
override val shape: IntArray get() = context.shape
override val elementField: F get() = context.elementField
override fun mapIndexed(transform: F.(index: IntArray, T) -> T): NDElement<T, F> =
context.run { mapIndexed(this@LazyNDStructure, transform) }
private val cache = HashMap<IntArray, Deferred<T>>()
fun deferred(index: IntArray) = cache.getOrPut(index) {
context.scope.async(context = Dispatchers.Math) {
function.invoke(
context.elementField,
index
)
function(index)
}
}
@ -61,7 +87,10 @@ class LazyNDStructure<T, F : Field<T>>(
override fun elements(): Sequence<Pair<IntArray, T>> {
val strides = DefaultStrides(shape)
return strides.indices().map { index -> index to runBlocking { await(index) } }
val res = runBlocking {
strides.indices().toList().map { index -> index to await(index) }
}
return res.asSequence()
}
}
@ -71,21 +100,11 @@ fun <T> NDStructure<T>.deferred(index: IntArray) =
suspend fun <T> NDStructure<T>.await(index: IntArray) =
if (this is LazyNDStructure<T, *>) this.await(index) else get(index)
fun <T, F : Field<T>> NDElements<T, F>.lazy(scope: CoroutineScope = GlobalScope): LazyNDStructure<T, F> {
return if (this is LazyNDStructure<T, F>) {
this
} else {
val context = LazyNDField(context.shape, context.field, scope)
LazyNDStructure(context) { get(it) }
}
}
inline fun <T, F : Field<T>> LazyNDStructure<T, F>.mapIndexed(crossinline action: suspend F.(IntArray, T) -> T) =
LazyNDStructure(context) { index ->
action.invoke(this, index, await(index))
}
fun <T : Any, F : Field<T>> NDField.Companion.lazy(shape: IntArray, field: F, scope: CoroutineScope = GlobalScope) =
LazyNDField(shape, field, scope)
inline fun <T, F : Field<T>> LazyNDStructure<T, F>.map(crossinline action: suspend F.(T) -> T) =
LazyNDStructure(context) { index ->
action.invoke(this, await(index))
}
fun <T, F : Field<T>> NDStructure<T>.lazy(field: F, scope: CoroutineScope = GlobalScope): LazyNDStructure<T, F> {
val context: LazyNDField<T, F> = LazyNDField(shape, field, scope)
return LazyNDStructure(context) { get(it) }
}

View File

@ -10,7 +10,7 @@ class LazyNDFieldTest {
fun testLazyStructure() {
var counter = 0
val regularStructure = NDField.generic(intArrayOf(2, 2, 2), IntField).produce { it[0] + it[1] - it[2] }
val result = (regularStructure.lazy() + 2).transform {
val result = (regularStructure.lazy(IntField) + 2).map {
counter++
it * it
}