Major refactoring of types for algebra elements. Element type could differ from field type.

This commit is contained in:
Alexander Nozik 2019-01-03 20:48:10 +03:00
parent 622a6a7756
commit 53d752dee8
17 changed files with 378 additions and 352 deletions

View File

@ -8,12 +8,12 @@ fun main(args: Array<String>) {
val n = 6000 val n = 6000
val structure = NdStructure(intArrayOf(n, n), DoubleBufferFactory) { 1.0 } val structure = ndStructure(intArrayOf(n, n), DoubleBufferFactory) { 1.0 }
structure.map { it + 1 } // warm-up structure.mapToBuffer { it + 1 } // warm-up
val time1 = measureTimeMillis { val time1 = measureTimeMillis {
val res = structure.map { it + 1 } val res = structure.mapToBuffer { it + 1 }
} }
println("Structure mapping finished in $time1 millis") println("Structure mapping finished in $time1 millis")

View File

@ -104,7 +104,7 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
val m = matrix.numCols val m = matrix.numCols
val pivot = IntArray(matrix.numRows) val pivot = IntArray(matrix.numRows)
//TODO fix performance //TODO fix performance
val lu: MutableNDStructure<T> = MutableNdStructure(intArrayOf(matrix.numRows, matrix.numCols), ::boxingMutableBuffer) { index: IntArray -> matrix[index[0], index[1]] } val lu: MutableNDStructure<T> = mutableNdStructure(intArrayOf(matrix.numRows, matrix.numCols), ::boxingMutableBuffer) { index: IntArray -> matrix[index[0], index[1]] }
with(matrix.context.ring) { with(matrix.context.ring) {

View File

@ -122,11 +122,11 @@ data class StructureMatrixSpace<T : Any, R : Ring<T>>(
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T, R> { override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T, R> {
return if (rows == rowNum && columns == colNum) { return if (rows == rowNum && columns == colNum) {
val structure = NdStructure(strides, bufferFactory) { initializer(it[0], it[1]) } val structure = ndStructure(strides, bufferFactory) { initializer(it[0], it[1]) }
StructureMatrix(this, structure) StructureMatrix(this, structure)
} else { } else {
val context = StructureMatrixSpace(rows, columns, ring, bufferFactory) val context = StructureMatrixSpace(rows, columns, ring, bufferFactory)
val structure = NdStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) } val structure = ndStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) }
StructureMatrix(context, structure) StructureMatrix(context, structure)
} }
} }

View File

@ -16,13 +16,18 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
val space: S val space: S
fun produce(initializer: (Int) -> T): Vector<T, S> fun produce(initializer: (Int) -> T): Point<T>
override val zero: Vector<T, S> get() = produce { space.zero } /**
* Produce a space-element of this vector space for expressions
*/
fun produceElement(initializer: (Int) -> T): Vector<T, S>
override fun add(a: Point<T>, b: Point<T>): Vector<T, S> = produce { with(space) { a[it] + b[it] } } override val zero: Point<T> get() = produce { space.zero }
override fun multiply(a: Point<T>, k: Double): Vector<T, S> = produce { with(space) { a[it] * k } } override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { with(space) { a[it] + b[it] } }
override fun multiply(a: Point<T>, k: Double): Point<T> = produce { with(space) { a[it] * k } }
//TODO add basis //TODO add basis
@ -53,7 +58,7 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
/** /**
* A point coupled to the linear space * A point coupled to the linear space
*/ */
interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, VectorSpace<T, S>>, Point<T> { interface Vector<T : Any, S : Space<T>> : SpaceElement<Vector<T,S>, VectorSpace<T, S>>, Point<T> {
override val size: Int get() = context.size override val size: Int get() = context.size
override operator fun plus(b: Point<T>): Vector<T, S> = context.add(self, b) override operator fun plus(b: Point<T>): Vector<T, S> = context.add(self, b)
@ -65,11 +70,11 @@ interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, VectorSpace<T,
/** /**
* Create vector with custom field * Create vector with custom field
*/ */
fun <T : Any, F : Space<T>> generic(size: Int, field: F, initializer: (Int) -> T) = fun <T : Any, S : Space<T>> generic(size: Int, field: S, initializer: (Int) -> T): Vector<T, S> =
VectorSpace.buffered(size, field).produce(initializer) VectorSpace.buffered(size, field).produceElement(initializer)
fun real(size: Int, initializer: (Int) -> Double) = VectorSpace.real(size).produce(initializer) fun real(size: Int, initializer: (Int) -> Double): Vector<Double,DoubleField> = VectorSpace.real(size).produceElement(initializer)
fun ofReal(vararg elements: Double) = VectorSpace.real(elements.size).produce{elements[it]} fun ofReal(vararg elements: Double): Vector<Double,DoubleField> = VectorSpace.real(elements.size).produceElement { elements[it] }
} }
} }
@ -79,7 +84,8 @@ data class BufferVectorSpace<T : Any, S : Space<T>>(
override val space: S, override val space: S,
val bufferFactory: BufferFactory<T> val bufferFactory: BufferFactory<T>
) : VectorSpace<T, S> { ) : VectorSpace<T, S> {
override fun produce(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, bufferFactory(size, initializer)) override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer)
override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))
} }
@ -95,7 +101,7 @@ data class BufferVector<T : Any, S : Space<T>>(override val context: VectorSpace
return buffer[index] return buffer[index]
} }
override val self: BufferVector<T, S> get() = this override fun getSelf(): BufferVector<T, S
override fun iterator(): Iterator<T> = (0 until size).map { buffer[it] }.iterator() override fun iterator(): Iterator<T> = (0 until size).map { buffer[it] }.iterator()

View File

@ -1,23 +1,6 @@
package scientifik.kmath.operations package scientifik.kmath.operations
/**
* The generic mathematics elements which is able to store its context
* @param T the self type of the element
* @param S the type of mathematical context for this element
*/
interface MathElement<T, S> {
/**
* Self value. Needed for static type checking.
*/
val self: T
/**
* The context this element belongs to
*/
val context: S
}
/** /**
* A general interface representing linear context of some kind. * A general interface representing linear context of some kind.
* The context defines sum operation for its elements and multiplication by real value. * The context defines sum operation for its elements and multiplication by real value.
@ -53,19 +36,8 @@ interface Space<T> {
//TODO move to external extensions when they are available //TODO move to external extensions when they are available
fun Iterable<T>.sum(): T = fold(zero) { left, right -> left + right } fun Iterable<T>.sum(): T = fold(zero) { left, right -> left + right }
fun Sequence<T>.sum(): T = fold(zero) { left, right -> left + right }
}
/** fun Sequence<T>.sum(): T = fold(zero) { left, right -> left + right }
* The element of linear context
* @param T self type of the element. Needed for static type checking
* @param S the type of space
*/
interface SpaceElement<T, S : Space<T>> : MathElement<T, S> {
operator fun plus(b: T): T = context.add(self, b)
operator fun minus(b: T): T = context.add(self, context.multiply(b, -1.0))
operator fun times(k: Number): T = context.multiply(self, k.toDouble())
operator fun div(k: Number): T = context.multiply(self, 1.0 / k.toDouble())
} }
/** /**
@ -86,15 +58,6 @@ interface Ring<T> : Space<T> {
} }
/**
* Ring element
*/
interface RingElement<T, S : Ring<T>> : SpaceElement<T, S> {
override val context: S
operator fun times(b: T): T = context.multiply(self, b)
}
/** /**
* Four operations algebra * Four operations algebra
*/ */
@ -110,12 +73,3 @@ interface Field<T> : Ring<T> {
operator fun T.minus(b: Number) = this.minus(b * one) operator fun T.minus(b: Number) = this.minus(b * one)
operator fun Number.minus(b: T) = -b + this operator fun Number.minus(b: T) = -b + this
} }
/**
* Field element
*/
interface FieldElement<T, F : Field<T>> : RingElement<T, F> {
override val context: F
operator fun div(b: T): T = context.divide(self, b)
}

View File

@ -0,0 +1,49 @@
package scientifik.kmath.operations
/**
* The generic mathematics elements which is able to store its context
* @param S the type of mathematical context for this element
*/
interface MathElement<S> {
/**
* The context this element belongs to
*/
val context: S
}
/**
* The element of linear context
* @param T the type of space operation results
* @param I self type of the element. Needed for static type checking
* @param S the type of space
*/
interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement<S> {
/**
* Self value. Needed for static type checking.
*/
fun unwrap(): T
fun T.wrap(): I
operator fun plus(b: T) = context.add(unwrap(), b).wrap()
operator fun minus(b: T) = context.add(unwrap(), context.multiply(b, -1.0)).wrap()
operator fun times(k: Number) = context.multiply(unwrap(), k.toDouble()).wrap()
operator fun div(k: Number) = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap()
}
/**
* Ring element
*/
interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
override val context: R
operator fun times(b: T) = context.multiply(unwrap(), b).wrap()
}
/**
* Field element
*/
interface FieldElement<T, I : FieldElement<T, I, F>, F : Field<T>> : RingElement<T, I, F> {
override val context: F
operator fun div(b: T) = context.divide(unwrap(), b).wrap()
}

View File

@ -16,15 +16,18 @@ object ComplexField : Field<Complex> {
override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
override fun divide(a: Complex, b: Complex): Complex = Complex(a.re * b.re + a.im * b.im, a.re * b.im - a.im * b.re) / b.square override fun divide(a: Complex, b: Complex): Complex {
val norm = b.square
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
}
operator fun Double.plus(c: Complex) = this.toComplex() + c operator fun Double.plus(c: Complex) = add(this.toComplex(), c)
operator fun Double.minus(c: Complex) = this.toComplex() - c operator fun Double.minus(c: Complex) = add(this.toComplex(), -c)
operator fun Complex.plus(d: Double) = d + this operator fun Complex.plus(d: Double) = d + this
operator fun Complex.minus(d: Double) = this - d.toComplex() operator fun Complex.minus(d: Double) = add(this, -d.toComplex())
operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this) operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this)
} }
@ -34,6 +37,7 @@ object ComplexField : Field<Complex> {
*/ */
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, ComplexField> { data class Complex(val re: Double, val im: Double) : FieldElement<Complex, ComplexField> {
override val self: Complex get() = this override val self: Complex get() = this
override val context: ComplexField override val context: ComplexField
get() = ComplexField get() = ComplexField

View File

@ -12,42 +12,18 @@ interface ExtendedField<T : Any> :
ExponentialOperations<T> ExponentialOperations<T>
/**
* Field for real values
*/
object RealField : ExtendedField<Real>, Norm<Real, Real> {
override val zero: Real = Real(0.0)
override fun add(a: Real, b: Real): Real = Real(a.value + b.value)
override val one: Real = Real(1.0)
override fun multiply(a: Real, b: Real): Real = Real(a.value * b.value)
override fun multiply(a: Real, k: Double): Real = Real(a.value * k)
override fun divide(a: Real, b: Real): Real = Real(a.value / b.value)
override fun sin(arg: Real): Real = Real(kotlin.math.sin(arg.value))
override fun cos(arg: Real): Real = Real(kotlin.math.cos(arg.value))
override fun power(arg: Real, pow: Double): Real = Real(arg.value.pow(pow))
override fun exp(arg: Real): Real = Real(kotlin.math.exp(arg.value))
override fun ln(arg: Real): Real = Real(kotlin.math.ln(arg.value))
override fun norm(arg: Real): Real = Real(kotlin.math.abs(arg.value))
}
/** /**
* Real field element wrapping double. * Real field element wrapping double.
* *
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586 * TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
*/ */
inline class Real(val value: Double) : FieldElement<Real, RealField> { inline class Real(val value: Double) : FieldElement<Double, Real, DoubleField> {
override fun unwrap(): Double = value
//values are dynamically calculated to save memory override fun Double.wrap(): Real = Real(value)
override val self
get() = this
override val context override val context get() = DoubleField
get() = RealField
companion object { companion object {
@ -79,11 +55,31 @@ object DoubleField : ExtendedField<Double>, Norm<Double, Double> {
/** /**
* A field for double without boxing. Does not produce appropriate field element * A field for double without boxing. Does not produce appropriate field element
*/ */
object IntField : Field<Int>{ object IntField : Field<Int> {
override val zero: Int = 0 override val zero: Int = 0
override fun add(a: Int, b: Int): Int = a + b override fun add(a: Int, b: Int): Int = a + b
override fun multiply(a: Int, b: Int): Int = a * b override fun multiply(a: Int, b: Int): Int = a * b
override fun multiply(a: Int, k: Double): Int = (k*a).toInt() override fun multiply(a: Int, k: Double): Int = (k * a).toInt()
override val one: Int = 1 override val one: Int = 1
override fun divide(a: Int, b: Int): Int = a / b override fun divide(a: Int, b: Int): Int = a / b
} }
//interface FieldAdapter<T, R> : Field<R> {
//
// val field: Field<T>
//
// abstract fun T.evolve(): R
// abstract fun R.devolve(): T
//
// override val zero get() = field.zero.evolve()
// override val one get() = field.zero.evolve()
//
// override fun add(a: R, b: R): R = field.add(a.devolve(), b.devolve()).evolve()
//
// override fun multiply(a: R, k: Double): R = field.multiply(a.devolve(), k).evolve()
//
//
// override fun multiply(a: R, b: R): R = field.multiply(a.devolve(), b.devolve()).evolve()
//
// override fun divide(a: R, b: R): R = field.divide(a.devolve(), b.devolve()).evolve()
//}

View File

@ -1,45 +1,65 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.FieldElement
open class BufferNDField<T, F : Field<T>>(final override val shape: IntArray, final override val field: F, val bufferFactory: BufferFactory<T>) : NDField<T, F> { open class BufferNDField<T, F : Field<T>>(final override val shape: IntArray, final override val field: F, val bufferFactory: BufferFactory<T>) : NDField<T, F, NDBuffer<T>> {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
override fun produce(initializer: F.(IntArray) -> T): BufferNDElement<T, F> { override fun produce(initializer: F.(IntArray) -> T) =
return BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) }) BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) })
}
open fun produceBuffered(initializer: F.(Int) -> T) = open fun NDBuffer<T>.map(transform: F.(T) -> T) =
BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(offset) }) BufferNDElement(this@BufferNDField, bufferFactory(strides.linearSize) { offset -> field.transform(buffer[offset]) })
// override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> { open fun NDBuffer<T>.mapIndexed(transform: F.(index: IntArray, T) -> T) =
// checkShape(a, b) BufferNDElement(this@BufferNDField, bufferFactory(strides.linearSize) { offset -> field.transform(strides.index(offset), buffer[offset]) })
// return if (a is BufferNDElement<T, *> && b is BufferNDElement<T, *>) {
// BufferNDElement(this,bufferFactory(strides.linearSize){i-> field.run { a.buffer[i] + b.buffer[i]}}) open fun combine(a: NDBuffer<T>, b: NDBuffer<T>, transform: F.(T, T) -> T) =
// } else { BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.transform(a[offset], b[offset]) })
// produce { field.run { a[it] + b[it] } }
// } /**
// } * Convert any [NDStructure] to buffered structure using strides from this context.
// * If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over indexes
// override fun NDStructure<T>.plus(b: Number): NDElement<T,F> { */
// checkShape(this) fun NDStructure<T>.toBuffer(): NDBuffer<T> =
// return if (this is BufferNDElement<T, *>) { this as? NDBuffer<T> ?: produce { index -> get(index) }
// BufferNDElement(this@BufferNDField,bufferFactory(strides.linearSize){i-> field.run { this@plus.buffer[i] + b}})
// } else { override val zero: NDBuffer<T> by lazy { produce { field.zero } }
// produce {index -> field.run { this@plus[index] + b } }
// } override fun add(a: NDBuffer<T>, b: NDBuffer<T>): NDBuffer<T> = combine(a, b) { aValue, bValue -> add(aValue, bValue) }
// }
override fun multiply(a: NDBuffer<T>, k: Double): NDBuffer<T> = a.map { it * k }
override val one: NDBuffer<T> by lazy { produce { field.one } }
override fun multiply(a: NDBuffer<T>, b: NDBuffer<T>): NDBuffer<T> = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
override fun divide(a: NDBuffer<T>, b: NDBuffer<T>): NDBuffer<T> = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
} }
class BufferNDElement<T, F : Field<T>>(override val context: BufferNDField<T, F>, val buffer: Buffer<T>) : NDElement<T, F> { class BufferNDElement<T, F : Field<T>>(override val context: BufferNDField<T, F>, override val buffer: Buffer<T>) :
NDBuffer<T>,
FieldElement<NDBuffer<T>, BufferNDElement<T, F>, BufferNDField<T, F>>,
NDElement<T, F> {
override val elementField: F get() = context.field
override fun unwrap(): NDBuffer<T> = this
override fun NDBuffer<T>.wrap(): BufferNDElement<T, F> = BufferNDElement(context, this.buffer)
override val strides get() = context.strides
override val self: NDStructure<T> get() = this
override val shape: IntArray get() = context.shape override val shape: IntArray get() = context.shape
override fun get(index: IntArray): T = buffer[context.strides.offset(index)] override fun get(index: IntArray): T = buffer[strides.offset(index)]
override fun elements(): Sequence<Pair<IntArray, T>> = context.strides.indices().map { it to get(it) } override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to get(it) }
override fun map(action: F.(T) -> T): BufferNDElement<T, F> = context.run { map(action) }
override fun mapIndexed(transform: F.(index: IntArray, T) -> T): BufferNDElement<T, F> = context.run { mapIndexed(transform) }
} }
@ -47,7 +67,7 @@ class BufferNDElement<T, F : Field<T>>(override val context: BufferNDField<T, F>
* Element by element application of any operation on elements to the whole array. Just like in numpy * Element by element application of any operation on elements to the whole array. Just like in numpy
*/ */
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferNDElement<T, F>) = operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferNDElement<T, F>) =
ndElement.context.produceBuffered { i -> invoke(ndElement.buffer[i]) } ndElement.context.run { ndElement.map { invoke(it) } }
/* plus and minus */ /* plus and minus */
@ -55,24 +75,24 @@ operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferNDE
* Summation operation for [BufferNDElement] and single element * Summation operation for [BufferNDElement] and single element
*/ */
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.plus(arg: T) = operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.plus(arg: T) =
context.produceBuffered { i -> buffer[i] + arg } context.run { map { it + arg } }
/** /**
* Subtraction operation between [BufferNDElement] and single element * Subtraction operation between [BufferNDElement] and single element
*/ */
operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) = operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) =
context.produceBuffered { i -> buffer[i] - arg } context.run { map { it - arg } }
/* prod and div */ /* prod and div */
/** /**
* Product operation for [BufferNDElement] and single element * Product operation for [BufferNDElement] and single element
*/ */
operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.times(arg: T) = operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.times(arg: T) =
context.produceBuffered { i -> buffer[i] * arg } context.run { map { it * arg } }
/** /**
* Division operation between [BufferNDElement] and single element * Division operation between [BufferNDElement] and single element
*/ */
operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.div(arg: T) = operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.div(arg: T) =
context.produceBuffered { i -> buffer[i] / arg } context.run { map { it / arg } }

View File

@ -6,40 +6,40 @@ import scientifik.kmath.operations.PowerOperations
import scientifik.kmath.operations.TrigonometricOperations import scientifik.kmath.operations.TrigonometricOperations
interface ExtendedNDField<T : Any, F : ExtendedField<T>> : interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<out T>> :
NDField<T, F>, NDField<T, F, N>,
TrigonometricOperations<NDStructure<T>>, TrigonometricOperations<N>,
PowerOperations<NDStructure<T>>, PowerOperations<N>,
ExponentialOperations<NDStructure<T>> ExponentialOperations<N>
/** /**
* NDField that supports [ExtendedField] operations on its elements * NDField that supports [ExtendedField] operations on its elements
*/ */
inline class ExtendedNDFieldWrapper<T : Any, F : ExtendedField<T>>(private val ndField: NDField<T, F>) : ExtendedNDField<T, F> { class ExtendedNDFieldWrapper<T : Any, F : ExtendedField<T>, N : NDStructure<out T>>(private val ndField: NDField<T, F, N>) : ExtendedNDField<T, F, N>, NDField<T,F,N> by ndField {
override val shape: IntArray get() = ndField.shape override val shape: IntArray get() = ndField.shape
override val field: F get() = ndField.field override val field: F get() = ndField.field
override fun produce(initializer: F.(IntArray) -> T): NDElement<T, F> = ndField.produce(initializer) override fun produce(initializer: F.(IntArray) -> T) = ndField.produce(initializer)
override fun power(arg: NDStructure<T>, pow: Double): NDElement<T, F> { override fun power(arg: N, pow: Double): N {
return produce { with(field) { power(arg[it], pow) } } return produce { with(field) { power(arg[it], pow) } }
} }
override fun exp(arg: NDStructure<T>): NDElement<T, F> { override fun exp(arg: N): N {
return produce { with(field) { exp(arg[it]) } } return produce { with(field) { exp(arg[it]) } }
} }
override fun ln(arg: NDStructure<T>): NDElement<T, F> { override fun ln(arg: N): N {
return produce { with(field) { ln(arg[it]) } } return produce { with(field) { ln(arg[it]) } }
} }
override fun sin(arg: NDStructure<T>): NDElement<T, F> { override fun sin(arg: N): N {
return produce { with(field) { sin(arg[it]) } } return produce { with(field) { sin(arg[it]) } }
} }
override fun cos(arg: NDStructure<T>): NDElement<T, F> { override fun cos(arg: N): N {
return produce { with(field) { cos(arg[it]) } } return produce { with(field) { cos(arg[it]) } }
} }
} }

View File

@ -0,0 +1,126 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.FieldElement
interface NDElement<T, F : Field<T>> : NDStructure<T> {
val elementField: F
fun mapIndexed(transform: F.(index: IntArray, T) -> T): NDElement<T, F>
fun map(action: F.(T) -> T) = mapIndexed { _, value -> action(value) }
}
object NDElements {
/**
* Create a optimized NDArray of doubles
*/
fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }) =
NDField.real(shape).produce(initializer)
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }) =
real(intArrayOf(dim)) { initializer(it[0]) }
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }) =
real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }) =
real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
/**
* Simple boxing NDArray
*/
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): GenericNDElement<T, F> {
val ndField = GenericNDField(shape, field)
val structure = ndStructure(shape) { index -> field.initializer(index) }
return GenericNDElement(ndField, structure)
}
inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): GenericNDElement<T, F> {
val ndField = GenericNDField(shape, field)
val structure = ndStructure(shape, ::inlineBuffer) { index -> field.initializer(index) }
return GenericNDElement(ndField, structure)
}
}
/**
* Element by element application of any operation on elements to the whole array. Just like in numpy
*/
operator fun <T, F : Field<T>> Function1<T, T>.invoke(ndElement: NDElement<T, F>) = ndElement.map { value -> this@invoke(value) }
/* plus and minus */
/**
* Summation operation for [NDElements] and single element
*/
operator fun <T, F : Field<T>> NDElement<T, F>.plus(arg: T): NDElement<T, F> = this.map { value -> elementField.run { arg + value } }
/**
* Subtraction operation between [NDElements] and single element
*/
operator fun <T, F : Field<T>> NDElement<T, F>.minus(arg: T): NDElement<T, F> = this.map { value -> elementField.run { arg - value } }
/* prod and div */
/**
* Product operation for [NDElements] and single element
*/
operator fun <T, F : Field<T>> NDElement<T, F>.times(arg: T): NDElement<T, F> = this.map { value -> elementField.run { arg * value } }
/**
* Division operation between [NDElements] and single element
*/
operator fun <T, F : Field<T>> NDElement<T, F>.div(arg: T): NDElement<T, F> = this.map { value -> elementField.run { arg / value } }
// /**
// * Reverse sum operation
// */
// operator fun T.plus(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
// field.run { this@plus + arg[index] }
// }
//
// /**
// * Reverse minus operation
// */
// operator fun T.minus(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
// field.run { this@minus - arg[index] }
// }
//
// /**
// * Reverse product operation
// */
// operator fun T.times(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
// field.run { this@times * arg[index] }
// }
//
// /**
// * Reverse division operation
// */
// operator fun T.div(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
// field.run { this@div / arg[index] }
// }
/**
* Read-only [NDStructure] coupled to the context.
*/
class GenericNDElement<T, F : Field<T>>(override val context: NDField<T, F, NDStructure<T>>, private val structure: NDStructure<T>) :
NDStructure<T> by structure,
NDElement<T, F>,
FieldElement<NDStructure<T>, GenericNDElement<T, F>, NDField<T, F, NDStructure<T>>> {
override val elementField: F get() = context.field
override fun unwrap(): NDStructure<T> = structure
override fun NDStructure<T>.wrap() = GenericNDElement(context, this)
override fun mapIndexed(transform: F.(index: IntArray, T) -> T) =
ndStructure(context.shape) { index: IntArray -> context.field.transform(index, get(index)) }.wrap()
}

View File

@ -1,8 +1,6 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import scientifik.kmath.operations.DoubleField
import scientifik.kmath.operations.Field import scientifik.kmath.operations.Field
import scientifik.kmath.operations.FieldElement
/** /**
* An exception is thrown when the expected ans actual shape of NDArray differs * An exception is thrown when the expected ans actual shape of NDArray differs
@ -13,27 +11,19 @@ class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : Run
* Field for n-dimensional arrays. * Field for n-dimensional arrays.
* @param shape - the list of dimensions of the array * @param shape - the list of dimensions of the array
* @param field - operations field defined on individual array element * @param field - operations field defined on individual array element
* @param T the type of the element contained in NDArray * @param T - the type of the element contained in ND structure
* @param F - field over structure elements
* @param R - actual nd-element type of this field
*/ */
interface NDField<T, F : Field<T>> : Field<NDStructure<T>> { interface NDField<T, F : Field<T>, N : NDStructure<out T>> : Field<N> {
val shape: IntArray val shape: IntArray
val field: F val field: F
/**
* Create new instance of NDArray using field shape and given initializer
* The producer takes list of indices as argument and returns contained value
*/
fun produce(initializer: F.(IntArray) -> T): NDElement<T, F>
override val zero: NDElement<T, F> get() = produce { zero }
override val one: NDElement<T, F> get() = produce { one }
/** /**
* Check the shape of given NDArray and throw exception if it does not coincide with shape of the field * Check the shape of given NDArray and throw exception if it does not coincide with shape of the field
*/ */
fun checkShape(vararg elements: NDStructure<T>) { fun checkShape(vararg elements: N) {
elements.forEach { elements.forEach {
if (!shape.contentEquals(it.shape)) { if (!shape.contentEquals(it.shape)) {
throw ShapeMismatchException(shape, it.shape) throw ShapeMismatchException(shape, it.shape)
@ -41,37 +31,8 @@ interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
} }
} }
/** fun produce(initializer: F.(IntArray) -> T): N
* Element-by-element addition
*/
override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
checkShape(a, b)
return produce { field.run { a[it] + b[it] } }
}
/**
* Multiply all elements by cinstant
*/
override fun multiply(a: NDStructure<T>, k: Double): NDElement<T, F> {
checkShape(a)
return produce { field.run { a[it] * k } }
}
/**
* Element-by-element multiplication
*/
override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
checkShape(a)
return produce { field.run { a[it] * b[it] } }
}
/**
* Element-by-element division
*/
override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
checkShape(a)
return produce { field.run { a[it] / b[it] } }
}
companion object { companion object {
/** /**
@ -82,7 +43,7 @@ interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
/** /**
* Create a nd-field with boxing generic buffer * Create a nd-field with boxing generic buffer
*/ */
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F) = BufferNDField(shape, field, ::boxingBuffer) fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F) = GenericNDField(shape, field)
/** /**
* Create a most suitable implementation for nd-field using reified class * Create a most suitable implementation for nd-field using reified class
@ -92,123 +53,43 @@ interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
} }
interface NDElement<T, F : Field<T>> : FieldElement<NDStructure<T>, NDField<T, F>>, NDStructure<T> { class GenericNDField<T : Any, F : Field<T>>(override val shape: IntArray, override val field: F, val bufferFactory: BufferFactory<T> = ::boxingBuffer) : NDField<T, F, NDStructure<T>> {
companion object { override fun produce(initializer: F.(IntArray) -> T): NDStructure<T> = ndStructure(shape, bufferFactory) { field.initializer(it) }
/**
* Create a platform-optimized NDArray of doubles
*/
fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }): NDElement<Double, DoubleField> {
return NDField.real(shape).produce(initializer)
}
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement<Double, DoubleField> { override val zero: NDStructure<T> by lazy { produce { zero } }
return real(intArrayOf(dim)) { initializer(it[0]) }
}
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement<Double, DoubleField> { override val one: NDStructure<T> by lazy { produce { one } }
return real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
}
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement<Double, DoubleField> {
return real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
}
// inline fun real(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDStructure<Double>): NDElement<Double, DoubleField> { /**
// val field = NDField.real(shape) * Element-by-element addition
// return GenericNDElement(field, field.run(block)) */
// } override fun add(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> {
checkShape(a, b)
return produce { field.run { a[it] + b[it] } }
}
/** /**
* Simple boxing NDArray * Multiply all elements by cinstant
*/ */
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): NDElement<T, F> { override fun multiply(a: NDStructure<T>, k: Double): NDStructure<T> {
return NDField.generic(shape, field).produce(initializer) checkShape(a)
} return produce { field.run { a[it] * k } }
}
inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): NDElement<T, F> { /**
return NDField.inline(shape, field).produce(initializer) * Element-by-element multiplication
} */
override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> {
checkShape(a)
return produce { field.run { a[it] * b[it] } }
}
/**
* Element-by-element division
*/
override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> {
checkShape(a)
return produce { field.run { a[it] / b[it] } }
} }
} }
inline fun <T, F : Field<T>> NDElement<T, F>.transformIndexed(crossinline action: F.(IntArray, T) -> T): NDElement<T, F> = context.produce { action(it, get(*it)) }
inline fun <T, F : Field<T>> NDElement<T, F>.transform(crossinline action: F.(T) -> T): NDElement<T, F> = context.produce { action(get(*it)) }
/**
* Element by element application of any operation on elements to the whole array. Just like in numpy
*/
operator fun <T, F : Field<T>> Function1<T, T>.invoke(ndElement: NDElement<T, F>): NDElement<T, F> = ndElement.transform { value -> this@invoke(value) }
/* plus and minus */
/**
* Summation operation for [NDElement] and single element
*/
operator fun <T, F : Field<T>> NDElement<T, F>.plus(arg: T): NDElement<T, F> = transform { value ->
context.field.run { arg + value }
}
/**
* Subtraction operation between [NDElement] and single element
*/
operator fun <T, F : Field<T>> NDElement<T, F>.minus(arg: T): NDElement<T, F> = transform { value ->
context.field.run { arg - value }
}
/* prod and div */
/**
* Product operation for [NDElement] and single element
*/
operator fun <T, F : Field<T>> NDElement<T, F>.times(arg: T): NDElement<T, F> = transform { value ->
context.field.run { arg * value }
}
/**
* Division operation between [NDElement] and single element
*/
operator fun <T, F : Field<T>> NDElement<T, F>.div(arg: T): NDElement<T, F> = transform { value ->
context.field.run { arg / value }
}
// /**
// * Reverse sum operation
// */
// operator fun T.plus(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
// field.run { this@plus + arg[index] }
// }
//
// /**
// * Reverse minus operation
// */
// operator fun T.minus(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
// field.run { this@minus - arg[index] }
// }
//
// /**
// * Reverse product operation
// */
// operator fun T.times(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
// field.run { this@times * arg[index] }
// }
//
// /**
// * Reverse division operation
// */
// operator fun T.div(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
// field.run { this@div / arg[index] }
// }
class GenericNDField<T : Any, F : Field<T>>(override val shape: IntArray, override val field: F) : NDField<T, F> {
override fun produce(initializer: F.(IntArray) -> T): NDElement<T, F> = GenericNDElement(this, produceStructure(initializer))
private inline fun produceStructure(crossinline initializer: F.(IntArray) -> T): NDStructure<T> = NdStructure(shape, ::boxingBuffer) { field.initializer(it) }
}
/**
* Read-only [NDStructure] coupled to the context.
*/
class GenericNDElement<T, F : Field<T>>(override val context: NDField<T, F>, private val structure: NDStructure<T>) : NDElement<T, F>, NDStructure<T> by structure {
override val self: NDElement<T, F> get() = this
}

View File

@ -112,26 +112,24 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
} }
} }
abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> { interface NDBuffer<T> : NDStructure<T> {
abstract val buffer: B val buffer: Buffer<T>
abstract val strides: Strides val strides: Strides
override fun get(index: IntArray): T = buffer[strides.offset(index)] override fun get(index: IntArray): T = buffer[strides.offset(index)]
override val shape: IntArray override val shape: IntArray get() = strides.shape
get() = strides.shape
override fun elements() = override fun elements() = strides.indices().map { it to this[it] }
strides.indices().map { it to this[it] }
} }
/** /**
* Boxing generic [NDStructure] * Boxing generic [NDStructure]
*/ */
class BufferNDStructure<T>( data class BufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
override val buffer: Buffer<T> override val buffer: Buffer<T>
) : GenericNDStructure<T, Buffer<T>>() { ) : NDBuffer<T> {
init { init {
if (strides.linearSize != buffer.size) { if (strides.linearSize != buffer.size) {
@ -158,7 +156,7 @@ class BufferNDStructure<T>(
/** /**
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure] * Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure]
*/ */
inline fun <T, reified R : Any> NDStructure<T>.map(factory: BufferFactory<R> = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure<R> { inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(factory: BufferFactory<R> = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure<R> {
return if (this is BufferNDStructure<T>) { return if (this is BufferNDStructure<T>) {
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
} else { } else {
@ -172,8 +170,7 @@ inline fun <T, reified R : Any> NDStructure<T>.map(factory: BufferFactory<R> = :
* *
* Strides should be reused if possible * Strides should be reused if possible
*/ */
@Suppress("FunctionName") fun <T> ndStructure(strides: Strides, bufferFactory: BufferFactory<T> = ::boxingBuffer, initializer: (IntArray) -> T) =
fun <T : Any> NdStructure(strides: Strides, bufferFactory: BufferFactory<T>, initializer: (IntArray) -> T) =
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
/** /**
@ -182,9 +179,8 @@ fun <T : Any> NdStructure(strides: Strides, bufferFactory: BufferFactory<T>, ini
inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
BufferNDStructure(strides, inlineBuffer(strides.linearSize) { i -> initializer(strides.index(i)) }) BufferNDStructure(strides, inlineBuffer(strides.linearSize) { i -> initializer(strides.index(i)) })
@Suppress("FunctionName") fun <T> ndStructure(shape: IntArray, bufferFactory: BufferFactory<T> = ::boxingBuffer, initializer: (IntArray) -> T) =
fun <T : Any> NdStructure(shape: IntArray, bufferFactory: BufferFactory<T>, initializer: (IntArray) -> T) = ndStructure(DefaultStrides(shape), bufferFactory, initializer)
NdStructure(DefaultStrides(shape), bufferFactory, initializer)
inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
inlineNDStructure(DefaultStrides(shape), initializer) inlineNDStructure(DefaultStrides(shape), initializer)
@ -195,7 +191,7 @@ inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline init
class MutableBufferNDStructure<T>( class MutableBufferNDStructure<T>(
override val strides: Strides, override val strides: Strides,
override val buffer: MutableBuffer<T> override val buffer: MutableBuffer<T>
) : GenericNDStructure<T, MutableBuffer<T>>(), MutableNDStructure<T> { ) : NDBuffer<T>, MutableNDStructure<T> {
init { init {
if (strides.linearSize != buffer.size) { if (strides.linearSize != buffer.size) {
@ -209,16 +205,14 @@ class MutableBufferNDStructure<T>(
/** /**
* The same as [inlineNDStructure], but mutable * The same as [inlineNDStructure], but mutable
*/ */
@Suppress("FunctionName") fun <T : Any> mutableNdStructure(strides: Strides, bufferFactory: MutableBufferFactory<T> = ::boxingMutableBuffer, initializer: (IntArray) -> T) =
fun <T : Any> MutableNdStructure(strides: Strides, bufferFactory: MutableBufferFactory<T>, initializer: (IntArray) -> T) =
MutableBufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) }) MutableBufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
inline fun <reified T : Any> inlineMutableNdStructure(strides: Strides, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> inlineMutableNdStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
MutableBufferNDStructure(strides, inlineMutableBuffer(strides.linearSize) { i -> initializer(strides.index(i)) }) MutableBufferNDStructure(strides, inlineMutableBuffer(strides.linearSize) { i -> initializer(strides.index(i)) })
@Suppress("FunctionName") fun <T : Any> mutableNdStructure(shape: IntArray, bufferFactory: MutableBufferFactory<T> = ::boxingMutableBuffer, initializer: (IntArray) -> T) =
fun <T : Any> MutableNdStructure(shape: IntArray, bufferFactory: MutableBufferFactory<T>, initializer: (IntArray) -> T) = mutableNdStructure(DefaultStrides(shape), bufferFactory, initializer)
MutableNdStructure(DefaultStrides(shape), bufferFactory, initializer)
inline fun <reified T : Any> inlineMutableNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) = inline fun <reified T : Any> inlineMutableNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
inlineMutableNdStructure(DefaultStrides(shape), initializer) inlineMutableNdStructure(DefaultStrides(shape), initializer)

View File

@ -4,44 +4,41 @@ import scientifik.kmath.operations.DoubleField
typealias RealNDElement = BufferNDElement<Double, DoubleField> typealias RealNDElement = BufferNDElement<Double, DoubleField>
class RealNDField(shape: IntArray) : BufferNDField<Double, DoubleField>(shape, DoubleField, DoubleBufferFactory), ExtendedNDField<Double, DoubleField> { class RealNDField(shape: IntArray) :
BufferNDField<Double, DoubleField>(shape, DoubleField, DoubleBufferFactory),
ExtendedNDField<Double, DoubleField, NDBuffer<Double>> {
/** /**
* Inline map an NDStructure to * Inline map an NDStructure to
*/ */
private inline fun NDStructure<Double>.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement { private inline fun NDBuffer<Double>.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement {
return if (this is BufferNDElement<Double, *>) { val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) }
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) } return BufferNDElement(this@RealNDField, DoubleBuffer(array))
BufferNDElement(this@RealNDField, DoubleBuffer(array))
} else {
produce { index -> DoubleField.operation(get(index)) }
}
} }
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
override inline fun produce(initializer: DoubleField.(IntArray) -> Double): RealNDElement { override inline fun produce(initializer: DoubleField.(IntArray) -> Double): RealNDElement {
val array = DoubleArray(strides.linearSize) { offset -> field.initializer(strides.index(offset)) } val array = DoubleArray(strides.linearSize) { offset -> field.initializer(strides.index(offset)) }
return BufferNDElement(this, DoubleBuffer(array)) return BufferNDElement(this, DoubleBuffer(array))
} }
override fun power(arg: NDStructure<Double>, pow: Double) = arg.mapInline { power(it, pow) } override fun power(arg: NDBuffer<Double>, pow: Double) = arg.mapInline { power(it, pow) }
override fun exp(arg: NDStructure<Double>) = arg.mapInline { exp(it) } override fun exp(arg: NDBuffer<Double>) = arg.mapInline { exp(it) }
override fun ln(arg: NDStructure<Double>) = arg.mapInline { ln(it) } override fun ln(arg: NDBuffer<Double>) = arg.mapInline { ln(it) }
override fun sin(arg: NDStructure<Double>) = arg.mapInline { sin(it) } override fun sin(arg: NDBuffer<Double>) = arg.mapInline { sin(it) }
override fun cos(arg: NDStructure<Double>) = arg.mapInline { cos(it) } override fun cos(arg: NDBuffer<Double>) = arg.mapInline { cos(it) }
override fun NDStructure<Double>.times(k: Number) = mapInline { value -> value * k.toDouble() } override fun NDBuffer<Double>.times(k: Number) = mapInline { value -> value * k.toDouble() }
override fun NDStructure<Double>.div(k: Number) = mapInline { value -> value / k.toDouble() } override fun NDBuffer<Double>.div(k: Number) = mapInline { value -> value / k.toDouble() }
override fun Number.times(b: NDStructure<Double>) = b * this override fun Number.times(b: NDBuffer<Double>) = b * this
override fun Number.div(b: NDStructure<Double>) = b * (1.0 / this.toDouble()) override fun Number.div(b: NDBuffer<Double>) = b * (1.0 / this.toDouble())
} }
/** /**

View File

@ -7,7 +7,7 @@ import kotlin.test.assertEquals
class NDFieldTest { class NDFieldTest {
@Test @Test
fun testStrides() { fun testStrides() {
val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() } val ndArray = NDElements.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() }
assertEquals(ndArray[5, 5], 10.0) assertEquals(ndArray[5, 5], 10.0)
} }
} }

View File

@ -1,7 +1,6 @@
package scientifik.kmath.structures package scientifik.kmath.structures
import scientifik.kmath.operations.Norm import scientifik.kmath.operations.Norm
import scientifik.kmath.structures.NDElement.Companion.real2D
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow import kotlin.math.pow
import kotlin.test.Test import kotlin.test.Test

View File

@ -5,7 +5,7 @@ import scientifik.kmath.operations.Field
class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: CoroutineScope = GlobalScope) : BufferNDField<T, F>(shape,field, ::boxingBuffer) { class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: CoroutineScope = GlobalScope) : BufferNDField<T, F>(shape,field, ::boxingBuffer) {
override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> { override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElements<T, F> {
return LazyNDStructure(this) { index -> return LazyNDStructure(this) { index ->
val aDeferred = a.deferred(index) val aDeferred = a.deferred(index)
val bDeferred = b.deferred(index) val bDeferred = b.deferred(index)
@ -13,11 +13,11 @@ class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: Corouti
} }
} }
override fun multiply(a: NDStructure<T>, k: Double): NDElement<T, F> { override fun multiply(a: NDStructure<T>, k: Double): NDElements<T, F> {
return LazyNDStructure(this) { index -> a.await(index) * k } return LazyNDStructure(this) { index -> a.await(index) * k }
} }
override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> { override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDElements<T, F> {
return LazyNDStructure(this) { index -> return LazyNDStructure(this) { index ->
val aDeferred = a.deferred(index) val aDeferred = a.deferred(index)
val bDeferred = b.deferred(index) val bDeferred = b.deferred(index)
@ -25,7 +25,7 @@ class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: Corouti
} }
} }
override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> { override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDElements<T, F> {
return LazyNDStructure(this) { index -> return LazyNDStructure(this) { index ->
val aDeferred = a.deferred(index) val aDeferred = a.deferred(index)
val bDeferred = b.deferred(index) val bDeferred = b.deferred(index)
@ -34,8 +34,8 @@ class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: Corouti
} }
} }
class LazyNDStructure<T, F : Field<T>>(override val context: LazyNDField<T, F>, val function: suspend F.(IntArray) -> T) : NDElement<T, F>, NDStructure<T> { class LazyNDStructure<T, F : Field<T>>(override val context: LazyNDField<T, F>, val function: suspend F.(IntArray) -> T) : NDElements<T, F>, NDStructure<T> {
override val self: NDElement<T, F> get() = this override val self: NDElements<T, F> get() = this
override val shape: IntArray get() = context.shape override val shape: IntArray get() = context.shape
private val cache = HashMap<IntArray, Deferred<T>>() private val cache = HashMap<IntArray, Deferred<T>>()
@ -58,7 +58,7 @@ fun <T> NDStructure<T>.deferred(index: IntArray) = if (this is LazyNDStructure<T
suspend fun <T> NDStructure<T>.await(index: IntArray) = if (this is LazyNDStructure<T, *>) this.await(index) else get(index) suspend fun <T> NDStructure<T>.await(index: IntArray) = if (this is LazyNDStructure<T, *>) this.await(index) else get(index)
fun <T, F : Field<T>> NDElement<T, F>.lazy(scope: CoroutineScope = GlobalScope): LazyNDStructure<T, F> { fun <T, F : Field<T>> NDElements<T, F>.lazy(scope: CoroutineScope = GlobalScope): LazyNDStructure<T, F> {
return if (this is LazyNDStructure<T, F>) { return if (this is LazyNDStructure<T, F>) {
this this
} else { } else {