forked from kscience/kmath
Major refactoring of types for algebra elements. Element type could differ from field type.
This commit is contained in:
parent
622a6a7756
commit
53d752dee8
@ -8,12 +8,12 @@ fun main(args: Array<String>) {
|
||||
|
||||
val n = 6000
|
||||
|
||||
val structure = NdStructure(intArrayOf(n, n), DoubleBufferFactory) { 1.0 }
|
||||
val structure = ndStructure(intArrayOf(n, n), DoubleBufferFactory) { 1.0 }
|
||||
|
||||
structure.map { it + 1 } // warm-up
|
||||
structure.mapToBuffer { it + 1 } // warm-up
|
||||
|
||||
val time1 = measureTimeMillis {
|
||||
val res = structure.map { it + 1 }
|
||||
val res = structure.mapToBuffer { it + 1 }
|
||||
}
|
||||
println("Structure mapping finished in $time1 millis")
|
||||
|
||||
|
@ -104,7 +104,7 @@ abstract class LUDecomposition<T : Comparable<T>, F : Field<T>>(val matrix: Matr
|
||||
val m = matrix.numCols
|
||||
val pivot = IntArray(matrix.numRows)
|
||||
//TODO fix performance
|
||||
val lu: MutableNDStructure<T> = MutableNdStructure(intArrayOf(matrix.numRows, matrix.numCols), ::boxingMutableBuffer) { index: IntArray -> matrix[index[0], index[1]] }
|
||||
val lu: MutableNDStructure<T> = mutableNdStructure(intArrayOf(matrix.numRows, matrix.numCols), ::boxingMutableBuffer) { index: IntArray -> matrix[index[0], index[1]] }
|
||||
|
||||
|
||||
with(matrix.context.ring) {
|
||||
|
@ -122,11 +122,11 @@ data class StructureMatrixSpace<T : Any, R : Ring<T>>(
|
||||
|
||||
override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix<T, R> {
|
||||
return if (rows == rowNum && columns == colNum) {
|
||||
val structure = NdStructure(strides, bufferFactory) { initializer(it[0], it[1]) }
|
||||
val structure = ndStructure(strides, bufferFactory) { initializer(it[0], it[1]) }
|
||||
StructureMatrix(this, structure)
|
||||
} else {
|
||||
val context = StructureMatrixSpace(rows, columns, ring, bufferFactory)
|
||||
val structure = NdStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) }
|
||||
val structure = ndStructure(context.strides, bufferFactory) { initializer(it[0], it[1]) }
|
||||
StructureMatrix(context, structure)
|
||||
}
|
||||
}
|
||||
|
@ -16,13 +16,18 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
||||
|
||||
val space: S
|
||||
|
||||
fun produce(initializer: (Int) -> T): Vector<T, S>
|
||||
fun produce(initializer: (Int) -> T): Point<T>
|
||||
|
||||
override val zero: Vector<T, S> get() = produce { space.zero }
|
||||
/**
|
||||
* Produce a space-element of this vector space for expressions
|
||||
*/
|
||||
fun produceElement(initializer: (Int) -> T): Vector<T, S>
|
||||
|
||||
override fun add(a: Point<T>, b: Point<T>): Vector<T, S> = produce { with(space) { a[it] + b[it] } }
|
||||
override val zero: Point<T> get() = produce { space.zero }
|
||||
|
||||
override fun multiply(a: Point<T>, k: Double): Vector<T, S> = produce { with(space) { a[it] * k } }
|
||||
override fun add(a: Point<T>, b: Point<T>): Point<T> = produce { with(space) { a[it] + b[it] } }
|
||||
|
||||
override fun multiply(a: Point<T>, k: Double): Point<T> = produce { with(space) { a[it] * k } }
|
||||
|
||||
//TODO add basis
|
||||
|
||||
@ -53,7 +58,7 @@ interface VectorSpace<T : Any, S : Space<T>> : Space<Point<T>> {
|
||||
/**
|
||||
* A point coupled to the linear space
|
||||
*/
|
||||
interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, VectorSpace<T, S>>, Point<T> {
|
||||
interface Vector<T : Any, S : Space<T>> : SpaceElement<Vector<T,S>, VectorSpace<T, S>>, Point<T> {
|
||||
override val size: Int get() = context.size
|
||||
|
||||
override operator fun plus(b: Point<T>): Vector<T, S> = context.add(self, b)
|
||||
@ -65,11 +70,11 @@ interface Vector<T : Any, S : Space<T>> : SpaceElement<Point<T>, VectorSpace<T,
|
||||
/**
|
||||
* Create vector with custom field
|
||||
*/
|
||||
fun <T : Any, F : Space<T>> generic(size: Int, field: F, initializer: (Int) -> T) =
|
||||
VectorSpace.buffered(size, field).produce(initializer)
|
||||
fun <T : Any, S : Space<T>> generic(size: Int, field: S, initializer: (Int) -> T): Vector<T, S> =
|
||||
VectorSpace.buffered(size, field).produceElement(initializer)
|
||||
|
||||
fun real(size: Int, initializer: (Int) -> Double) = VectorSpace.real(size).produce(initializer)
|
||||
fun ofReal(vararg elements: Double) = VectorSpace.real(elements.size).produce{elements[it]}
|
||||
fun real(size: Int, initializer: (Int) -> Double): Vector<Double,DoubleField> = VectorSpace.real(size).produceElement(initializer)
|
||||
fun ofReal(vararg elements: Double): Vector<Double,DoubleField> = VectorSpace.real(elements.size).produceElement { elements[it] }
|
||||
|
||||
}
|
||||
}
|
||||
@ -79,7 +84,8 @@ data class BufferVectorSpace<T : Any, S : Space<T>>(
|
||||
override val space: S,
|
||||
val bufferFactory: BufferFactory<T>
|
||||
) : VectorSpace<T, S> {
|
||||
override fun produce(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, bufferFactory(size, initializer))
|
||||
override fun produce(initializer: (Int) -> T) = bufferFactory(size, initializer)
|
||||
override fun produceElement(initializer: (Int) -> T): Vector<T, S> = BufferVector(this, produce(initializer))
|
||||
}
|
||||
|
||||
|
||||
@ -95,7 +101,7 @@ data class BufferVector<T : Any, S : Space<T>>(override val context: VectorSpace
|
||||
return buffer[index]
|
||||
}
|
||||
|
||||
override val self: BufferVector<T, S> get() = this
|
||||
override fun getSelf(): BufferVector<T, S
|
||||
|
||||
override fun iterator(): Iterator<T> = (0 until size).map { buffer[it] }.iterator()
|
||||
|
||||
|
@ -1,23 +1,6 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
|
||||
/**
|
||||
* The generic mathematics elements which is able to store its context
|
||||
* @param T the self type of the element
|
||||
* @param S the type of mathematical context for this element
|
||||
*/
|
||||
interface MathElement<T, S> {
|
||||
/**
|
||||
* Self value. Needed for static type checking.
|
||||
*/
|
||||
val self: T
|
||||
|
||||
/**
|
||||
* The context this element belongs to
|
||||
*/
|
||||
val context: S
|
||||
}
|
||||
|
||||
/**
|
||||
* A general interface representing linear context of some kind.
|
||||
* The context defines sum operation for its elements and multiplication by real value.
|
||||
@ -53,19 +36,8 @@ interface Space<T> {
|
||||
|
||||
//TODO move to external extensions when they are available
|
||||
fun Iterable<T>.sum(): T = fold(zero) { left, right -> left + right }
|
||||
fun Sequence<T>.sum(): T = fold(zero) { left, right -> left + right }
|
||||
}
|
||||
|
||||
/**
|
||||
* The element of linear context
|
||||
* @param T self type of the element. Needed for static type checking
|
||||
* @param S the type of space
|
||||
*/
|
||||
interface SpaceElement<T, S : Space<T>> : MathElement<T, S> {
|
||||
operator fun plus(b: T): T = context.add(self, b)
|
||||
operator fun minus(b: T): T = context.add(self, context.multiply(b, -1.0))
|
||||
operator fun times(k: Number): T = context.multiply(self, k.toDouble())
|
||||
operator fun div(k: Number): T = context.multiply(self, 1.0 / k.toDouble())
|
||||
fun Sequence<T>.sum(): T = fold(zero) { left, right -> left + right }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -86,15 +58,6 @@ interface Ring<T> : Space<T> {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Ring element
|
||||
*/
|
||||
interface RingElement<T, S : Ring<T>> : SpaceElement<T, S> {
|
||||
override val context: S
|
||||
|
||||
operator fun times(b: T): T = context.multiply(self, b)
|
||||
}
|
||||
|
||||
/**
|
||||
* Four operations algebra
|
||||
*/
|
||||
@ -110,12 +73,3 @@ interface Field<T> : Ring<T> {
|
||||
operator fun T.minus(b: Number) = this.minus(b * one)
|
||||
operator fun Number.minus(b: T) = -b + this
|
||||
}
|
||||
|
||||
/**
|
||||
* Field element
|
||||
*/
|
||||
interface FieldElement<T, F : Field<T>> : RingElement<T, F> {
|
||||
override val context: F
|
||||
|
||||
operator fun div(b: T): T = context.divide(self, b)
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
package scientifik.kmath.operations
|
||||
|
||||
/**
|
||||
* The generic mathematics elements which is able to store its context
|
||||
* @param S the type of mathematical context for this element
|
||||
*/
|
||||
interface MathElement<S> {
|
||||
/**
|
||||
* The context this element belongs to
|
||||
*/
|
||||
val context: S
|
||||
}
|
||||
|
||||
/**
|
||||
* The element of linear context
|
||||
* @param T the type of space operation results
|
||||
* @param I self type of the element. Needed for static type checking
|
||||
* @param S the type of space
|
||||
*/
|
||||
interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement<S> {
|
||||
/**
|
||||
* Self value. Needed for static type checking.
|
||||
*/
|
||||
fun unwrap(): T
|
||||
|
||||
fun T.wrap(): I
|
||||
|
||||
operator fun plus(b: T) = context.add(unwrap(), b).wrap()
|
||||
operator fun minus(b: T) = context.add(unwrap(), context.multiply(b, -1.0)).wrap()
|
||||
operator fun times(k: Number) = context.multiply(unwrap(), k.toDouble()).wrap()
|
||||
operator fun div(k: Number) = context.multiply(unwrap(), 1.0 / k.toDouble()).wrap()
|
||||
}
|
||||
|
||||
/**
|
||||
* Ring element
|
||||
*/
|
||||
interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
|
||||
override val context: R
|
||||
operator fun times(b: T) = context.multiply(unwrap(), b).wrap()
|
||||
}
|
||||
|
||||
/**
|
||||
* Field element
|
||||
*/
|
||||
interface FieldElement<T, I : FieldElement<T, I, F>, F : Field<T>> : RingElement<T, I, F> {
|
||||
override val context: F
|
||||
|
||||
operator fun div(b: T) = context.divide(unwrap(), b).wrap()
|
||||
}
|
@ -16,15 +16,18 @@ object ComplexField : Field<Complex> {
|
||||
|
||||
override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
|
||||
|
||||
override fun divide(a: Complex, b: Complex): Complex = Complex(a.re * b.re + a.im * b.im, a.re * b.im - a.im * b.re) / b.square
|
||||
override fun divide(a: Complex, b: Complex): Complex {
|
||||
val norm = b.square
|
||||
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
|
||||
}
|
||||
|
||||
operator fun Double.plus(c: Complex) = this.toComplex() + c
|
||||
operator fun Double.plus(c: Complex) = add(this.toComplex(), c)
|
||||
|
||||
operator fun Double.minus(c: Complex) = this.toComplex() - c
|
||||
operator fun Double.minus(c: Complex) = add(this.toComplex(), -c)
|
||||
|
||||
operator fun Complex.plus(d: Double) = d + this
|
||||
|
||||
operator fun Complex.minus(d: Double) = this - d.toComplex()
|
||||
operator fun Complex.minus(d: Double) = add(this, -d.toComplex())
|
||||
|
||||
operator fun Double.times(c: Complex) = Complex(c.re * this, c.im * this)
|
||||
}
|
||||
@ -34,6 +37,7 @@ object ComplexField : Field<Complex> {
|
||||
*/
|
||||
data class Complex(val re: Double, val im: Double) : FieldElement<Complex, ComplexField> {
|
||||
override val self: Complex get() = this
|
||||
|
||||
override val context: ComplexField
|
||||
get() = ComplexField
|
||||
|
||||
|
@ -12,42 +12,18 @@ interface ExtendedField<T : Any> :
|
||||
ExponentialOperations<T>
|
||||
|
||||
|
||||
/**
|
||||
* Field for real values
|
||||
*/
|
||||
object RealField : ExtendedField<Real>, Norm<Real, Real> {
|
||||
override val zero: Real = Real(0.0)
|
||||
override fun add(a: Real, b: Real): Real = Real(a.value + b.value)
|
||||
override val one: Real = Real(1.0)
|
||||
override fun multiply(a: Real, b: Real): Real = Real(a.value * b.value)
|
||||
override fun multiply(a: Real, k: Double): Real = Real(a.value * k)
|
||||
override fun divide(a: Real, b: Real): Real = Real(a.value / b.value)
|
||||
|
||||
override fun sin(arg: Real): Real = Real(kotlin.math.sin(arg.value))
|
||||
override fun cos(arg: Real): Real = Real(kotlin.math.cos(arg.value))
|
||||
|
||||
override fun power(arg: Real, pow: Double): Real = Real(arg.value.pow(pow))
|
||||
|
||||
override fun exp(arg: Real): Real = Real(kotlin.math.exp(arg.value))
|
||||
|
||||
override fun ln(arg: Real): Real = Real(kotlin.math.ln(arg.value))
|
||||
|
||||
override fun norm(arg: Real): Real = Real(kotlin.math.abs(arg.value))
|
||||
}
|
||||
|
||||
/**
|
||||
* Real field element wrapping double.
|
||||
*
|
||||
* TODO inline does not work due to compiler bug. Waiting for fix for KT-27586
|
||||
*/
|
||||
inline class Real(val value: Double) : FieldElement<Real, RealField> {
|
||||
inline class Real(val value: Double) : FieldElement<Double, Real, DoubleField> {
|
||||
override fun unwrap(): Double = value
|
||||
|
||||
//values are dynamically calculated to save memory
|
||||
override val self
|
||||
get() = this
|
||||
override fun Double.wrap(): Real = Real(value)
|
||||
|
||||
override val context
|
||||
get() = RealField
|
||||
override val context get() = DoubleField
|
||||
|
||||
companion object {
|
||||
|
||||
@ -79,11 +55,31 @@ object DoubleField : ExtendedField<Double>, Norm<Double, Double> {
|
||||
/**
|
||||
* A field for double without boxing. Does not produce appropriate field element
|
||||
*/
|
||||
object IntField : Field<Int>{
|
||||
object IntField : Field<Int> {
|
||||
override val zero: Int = 0
|
||||
override fun add(a: Int, b: Int): Int = a + b
|
||||
override fun multiply(a: Int, b: Int): Int = a * b
|
||||
override fun multiply(a: Int, k: Double): Int = (k*a).toInt()
|
||||
override fun multiply(a: Int, k: Double): Int = (k * a).toInt()
|
||||
override val one: Int = 1
|
||||
override fun divide(a: Int, b: Int): Int = a / b
|
||||
}
|
||||
|
||||
//interface FieldAdapter<T, R> : Field<R> {
|
||||
//
|
||||
// val field: Field<T>
|
||||
//
|
||||
// abstract fun T.evolve(): R
|
||||
// abstract fun R.devolve(): T
|
||||
//
|
||||
// override val zero get() = field.zero.evolve()
|
||||
// override val one get() = field.zero.evolve()
|
||||
//
|
||||
// override fun add(a: R, b: R): R = field.add(a.devolve(), b.devolve()).evolve()
|
||||
//
|
||||
// override fun multiply(a: R, k: Double): R = field.multiply(a.devolve(), k).evolve()
|
||||
//
|
||||
//
|
||||
// override fun multiply(a: R, b: R): R = field.multiply(a.devolve(), b.devolve()).evolve()
|
||||
//
|
||||
// override fun divide(a: R, b: R): R = field.divide(a.devolve(), b.devolve()).evolve()
|
||||
//}
|
@ -1,45 +1,65 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.FieldElement
|
||||
|
||||
open class BufferNDField<T, F : Field<T>>(final override val shape: IntArray, final override val field: F, val bufferFactory: BufferFactory<T>) : NDField<T, F> {
|
||||
open class BufferNDField<T, F : Field<T>>(final override val shape: IntArray, final override val field: F, val bufferFactory: BufferFactory<T>) : NDField<T, F, NDBuffer<T>> {
|
||||
val strides = DefaultStrides(shape)
|
||||
|
||||
override fun produce(initializer: F.(IntArray) -> T): BufferNDElement<T, F> {
|
||||
return BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) })
|
||||
}
|
||||
override fun produce(initializer: F.(IntArray) -> T) =
|
||||
BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(strides.index(offset)) })
|
||||
|
||||
open fun produceBuffered(initializer: F.(Int) -> T) =
|
||||
BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.initializer(offset) })
|
||||
open fun NDBuffer<T>.map(transform: F.(T) -> T) =
|
||||
BufferNDElement(this@BufferNDField, bufferFactory(strides.linearSize) { offset -> field.transform(buffer[offset]) })
|
||||
|
||||
// override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
|
||||
// checkShape(a, b)
|
||||
// return if (a is BufferNDElement<T, *> && b is BufferNDElement<T, *>) {
|
||||
// BufferNDElement(this,bufferFactory(strides.linearSize){i-> field.run { a.buffer[i] + b.buffer[i]}})
|
||||
// } else {
|
||||
// produce { field.run { a[it] + b[it] } }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// override fun NDStructure<T>.plus(b: Number): NDElement<T,F> {
|
||||
// checkShape(this)
|
||||
// return if (this is BufferNDElement<T, *>) {
|
||||
// BufferNDElement(this@BufferNDField,bufferFactory(strides.linearSize){i-> field.run { this@plus.buffer[i] + b}})
|
||||
// } else {
|
||||
// produce {index -> field.run { this@plus[index] + b } }
|
||||
// }
|
||||
// }
|
||||
open fun NDBuffer<T>.mapIndexed(transform: F.(index: IntArray, T) -> T) =
|
||||
BufferNDElement(this@BufferNDField, bufferFactory(strides.linearSize) { offset -> field.transform(strides.index(offset), buffer[offset]) })
|
||||
|
||||
open fun combine(a: NDBuffer<T>, b: NDBuffer<T>, transform: F.(T, T) -> T) =
|
||||
BufferNDElement(this, bufferFactory(strides.linearSize) { offset -> field.transform(a[offset], b[offset]) })
|
||||
|
||||
/**
|
||||
* Convert any [NDStructure] to buffered structure using strides from this context.
|
||||
* If the structure is already [NDBuffer], conversion is free. If not, it could be expensive because iteration over indexes
|
||||
*/
|
||||
fun NDStructure<T>.toBuffer(): NDBuffer<T> =
|
||||
this as? NDBuffer<T> ?: produce { index -> get(index) }
|
||||
|
||||
override val zero: NDBuffer<T> by lazy { produce { field.zero } }
|
||||
|
||||
override fun add(a: NDBuffer<T>, b: NDBuffer<T>): NDBuffer<T> = combine(a, b) { aValue, bValue -> add(aValue, bValue) }
|
||||
|
||||
override fun multiply(a: NDBuffer<T>, k: Double): NDBuffer<T> = a.map { it * k }
|
||||
|
||||
override val one: NDBuffer<T> by lazy { produce { field.one } }
|
||||
|
||||
override fun multiply(a: NDBuffer<T>, b: NDBuffer<T>): NDBuffer<T> = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
||||
|
||||
override fun divide(a: NDBuffer<T>, b: NDBuffer<T>): NDBuffer<T> = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
||||
}
|
||||
|
||||
class BufferNDElement<T, F : Field<T>>(override val context: BufferNDField<T, F>, val buffer: Buffer<T>) : NDElement<T, F> {
|
||||
class BufferNDElement<T, F : Field<T>>(override val context: BufferNDField<T, F>, override val buffer: Buffer<T>) :
|
||||
NDBuffer<T>,
|
||||
FieldElement<NDBuffer<T>, BufferNDElement<T, F>, BufferNDField<T, F>>,
|
||||
NDElement<T, F> {
|
||||
|
||||
override val elementField: F get() = context.field
|
||||
|
||||
override fun unwrap(): NDBuffer<T> = this
|
||||
|
||||
override fun NDBuffer<T>.wrap(): BufferNDElement<T, F> = BufferNDElement(context, this.buffer)
|
||||
|
||||
override val strides get() = context.strides
|
||||
|
||||
override val self: NDStructure<T> get() = this
|
||||
override val shape: IntArray get() = context.shape
|
||||
|
||||
override fun get(index: IntArray): T = buffer[context.strides.offset(index)]
|
||||
override fun get(index: IntArray): T = buffer[strides.offset(index)]
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = context.strides.indices().map { it to get(it) }
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { it to get(it) }
|
||||
|
||||
override fun map(action: F.(T) -> T): BufferNDElement<T, F> = context.run { map(action) }
|
||||
|
||||
override fun mapIndexed(transform: F.(index: IntArray, T) -> T): BufferNDElement<T, F> = context.run { mapIndexed(transform) }
|
||||
}
|
||||
|
||||
|
||||
@ -47,7 +67,7 @@ class BufferNDElement<T, F : Field<T>>(override val context: BufferNDField<T, F>
|
||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||
*/
|
||||
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferNDElement<T, F>) =
|
||||
ndElement.context.produceBuffered { i -> invoke(ndElement.buffer[i]) }
|
||||
ndElement.context.run { ndElement.map { invoke(it) } }
|
||||
|
||||
/* plus and minus */
|
||||
|
||||
@ -55,24 +75,24 @@ operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferNDE
|
||||
* Summation operation for [BufferNDElement] and single element
|
||||
*/
|
||||
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.plus(arg: T) =
|
||||
context.produceBuffered { i -> buffer[i] + arg }
|
||||
context.run { map { it + arg } }
|
||||
|
||||
/**
|
||||
* Subtraction operation between [BufferNDElement] and single element
|
||||
*/
|
||||
operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) =
|
||||
context.produceBuffered { i -> buffer[i] - arg }
|
||||
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.minus(arg: T) =
|
||||
context.run { map { it - arg } }
|
||||
|
||||
/* prod and div */
|
||||
|
||||
/**
|
||||
* Product operation for [BufferNDElement] and single element
|
||||
*/
|
||||
operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.times(arg: T) =
|
||||
context.produceBuffered { i -> buffer[i] * arg }
|
||||
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.times(arg: T) =
|
||||
context.run { map { it * arg } }
|
||||
|
||||
/**
|
||||
* Division operation between [BufferNDElement] and single element
|
||||
*/
|
||||
operator fun <T: Any, F : Field<T>> BufferNDElement<T, F>.div(arg: T) =
|
||||
context.produceBuffered { i -> buffer[i] / arg }
|
||||
operator fun <T : Any, F : Field<T>> BufferNDElement<T, F>.div(arg: T) =
|
||||
context.run { map { it / arg } }
|
@ -6,40 +6,40 @@ import scientifik.kmath.operations.PowerOperations
|
||||
import scientifik.kmath.operations.TrigonometricOperations
|
||||
|
||||
|
||||
interface ExtendedNDField<T : Any, F : ExtendedField<T>> :
|
||||
NDField<T, F>,
|
||||
TrigonometricOperations<NDStructure<T>>,
|
||||
PowerOperations<NDStructure<T>>,
|
||||
ExponentialOperations<NDStructure<T>>
|
||||
interface ExtendedNDField<T : Any, F : ExtendedField<T>, N : NDStructure<out T>> :
|
||||
NDField<T, F, N>,
|
||||
TrigonometricOperations<N>,
|
||||
PowerOperations<N>,
|
||||
ExponentialOperations<N>
|
||||
|
||||
|
||||
/**
|
||||
* NDField that supports [ExtendedField] operations on its elements
|
||||
*/
|
||||
inline class ExtendedNDFieldWrapper<T : Any, F : ExtendedField<T>>(private val ndField: NDField<T, F>) : ExtendedNDField<T, F> {
|
||||
class ExtendedNDFieldWrapper<T : Any, F : ExtendedField<T>, N : NDStructure<out T>>(private val ndField: NDField<T, F, N>) : ExtendedNDField<T, F, N>, NDField<T,F,N> by ndField {
|
||||
|
||||
override val shape: IntArray get() = ndField.shape
|
||||
override val field: F get() = ndField.field
|
||||
|
||||
override fun produce(initializer: F.(IntArray) -> T): NDElement<T, F> = ndField.produce(initializer)
|
||||
override fun produce(initializer: F.(IntArray) -> T) = ndField.produce(initializer)
|
||||
|
||||
override fun power(arg: NDStructure<T>, pow: Double): NDElement<T, F> {
|
||||
override fun power(arg: N, pow: Double): N {
|
||||
return produce { with(field) { power(arg[it], pow) } }
|
||||
}
|
||||
|
||||
override fun exp(arg: NDStructure<T>): NDElement<T, F> {
|
||||
override fun exp(arg: N): N {
|
||||
return produce { with(field) { exp(arg[it]) } }
|
||||
}
|
||||
|
||||
override fun ln(arg: NDStructure<T>): NDElement<T, F> {
|
||||
override fun ln(arg: N): N {
|
||||
return produce { with(field) { ln(arg[it]) } }
|
||||
}
|
||||
|
||||
override fun sin(arg: NDStructure<T>): NDElement<T, F> {
|
||||
override fun sin(arg: N): N {
|
||||
return produce { with(field) { sin(arg[it]) } }
|
||||
}
|
||||
|
||||
override fun cos(arg: NDStructure<T>): NDElement<T, F> {
|
||||
override fun cos(arg: N): N {
|
||||
return produce { with(field) { cos(arg[it]) } }
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,126 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.DoubleField
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.FieldElement
|
||||
|
||||
|
||||
interface NDElement<T, F : Field<T>> : NDStructure<T> {
|
||||
val elementField: F
|
||||
|
||||
fun mapIndexed(transform: F.(index: IntArray, T) -> T): NDElement<T, F>
|
||||
fun map(action: F.(T) -> T) = mapIndexed { _, value -> action(value) }
|
||||
}
|
||||
|
||||
|
||||
object NDElements {
|
||||
/**
|
||||
* Create a optimized NDArray of doubles
|
||||
*/
|
||||
fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }) =
|
||||
NDField.real(shape).produce(initializer)
|
||||
|
||||
|
||||
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }) =
|
||||
real(intArrayOf(dim)) { initializer(it[0]) }
|
||||
|
||||
|
||||
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }) =
|
||||
real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||
|
||||
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }) =
|
||||
real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||
|
||||
|
||||
/**
|
||||
* Simple boxing NDArray
|
||||
*/
|
||||
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): GenericNDElement<T, F> {
|
||||
val ndField = GenericNDField(shape, field)
|
||||
val structure = ndStructure(shape) { index -> field.initializer(index) }
|
||||
return GenericNDElement(ndField, structure)
|
||||
}
|
||||
|
||||
inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): GenericNDElement<T, F> {
|
||||
val ndField = GenericNDField(shape, field)
|
||||
val structure = ndStructure(shape, ::inlineBuffer) { index -> field.initializer(index) }
|
||||
return GenericNDElement(ndField, structure)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||
*/
|
||||
operator fun <T, F : Field<T>> Function1<T, T>.invoke(ndElement: NDElement<T, F>) = ndElement.map { value -> this@invoke(value) }
|
||||
|
||||
/* plus and minus */
|
||||
|
||||
/**
|
||||
* Summation operation for [NDElements] and single element
|
||||
*/
|
||||
operator fun <T, F : Field<T>> NDElement<T, F>.plus(arg: T): NDElement<T, F> = this.map { value -> elementField.run { arg + value } }
|
||||
|
||||
/**
|
||||
* Subtraction operation between [NDElements] and single element
|
||||
*/
|
||||
operator fun <T, F : Field<T>> NDElement<T, F>.minus(arg: T): NDElement<T, F> = this.map { value -> elementField.run { arg - value } }
|
||||
|
||||
/* prod and div */
|
||||
|
||||
/**
|
||||
* Product operation for [NDElements] and single element
|
||||
*/
|
||||
operator fun <T, F : Field<T>> NDElement<T, F>.times(arg: T): NDElement<T, F> = this.map { value -> elementField.run { arg * value } }
|
||||
|
||||
/**
|
||||
* Division operation between [NDElements] and single element
|
||||
*/
|
||||
operator fun <T, F : Field<T>> NDElement<T, F>.div(arg: T): NDElement<T, F> = this.map { value -> elementField.run { arg / value } }
|
||||
|
||||
|
||||
// /**
|
||||
// * Reverse sum operation
|
||||
// */
|
||||
// operator fun T.plus(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
|
||||
// field.run { this@plus + arg[index] }
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * Reverse minus operation
|
||||
// */
|
||||
// operator fun T.minus(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
|
||||
// field.run { this@minus - arg[index] }
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * Reverse product operation
|
||||
// */
|
||||
// operator fun T.times(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
|
||||
// field.run { this@times * arg[index] }
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * Reverse division operation
|
||||
// */
|
||||
// operator fun T.div(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
|
||||
// field.run { this@div / arg[index] }
|
||||
// }
|
||||
|
||||
|
||||
/**
|
||||
* Read-only [NDStructure] coupled to the context.
|
||||
*/
|
||||
class GenericNDElement<T, F : Field<T>>(override val context: NDField<T, F, NDStructure<T>>, private val structure: NDStructure<T>) :
|
||||
NDStructure<T> by structure,
|
||||
NDElement<T, F>,
|
||||
FieldElement<NDStructure<T>, GenericNDElement<T, F>, NDField<T, F, NDStructure<T>>> {
|
||||
override val elementField: F get() = context.field
|
||||
|
||||
override fun unwrap(): NDStructure<T> = structure
|
||||
|
||||
override fun NDStructure<T>.wrap() = GenericNDElement(context, this)
|
||||
|
||||
override fun mapIndexed(transform: F.(index: IntArray, T) -> T) =
|
||||
ndStructure(context.shape) { index: IntArray -> context.field.transform(index, get(index)) }.wrap()
|
||||
}
|
@ -1,8 +1,6 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.DoubleField
|
||||
import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.FieldElement
|
||||
|
||||
/**
|
||||
* An exception is thrown when the expected ans actual shape of NDArray differs
|
||||
@ -13,27 +11,19 @@ class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : Run
|
||||
* Field for n-dimensional arrays.
|
||||
* @param shape - the list of dimensions of the array
|
||||
* @param field - operations field defined on individual array element
|
||||
* @param T the type of the element contained in NDArray
|
||||
* @param T - the type of the element contained in ND structure
|
||||
* @param F - field over structure elements
|
||||
* @param R - actual nd-element type of this field
|
||||
*/
|
||||
interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
|
||||
interface NDField<T, F : Field<T>, N : NDStructure<out T>> : Field<N> {
|
||||
|
||||
val shape: IntArray
|
||||
val field: F
|
||||
|
||||
/**
|
||||
* Create new instance of NDArray using field shape and given initializer
|
||||
* The producer takes list of indices as argument and returns contained value
|
||||
*/
|
||||
fun produce(initializer: F.(IntArray) -> T): NDElement<T, F>
|
||||
|
||||
override val zero: NDElement<T, F> get() = produce { zero }
|
||||
|
||||
override val one: NDElement<T, F> get() = produce { one }
|
||||
|
||||
/**
|
||||
* Check the shape of given NDArray and throw exception if it does not coincide with shape of the field
|
||||
*/
|
||||
fun checkShape(vararg elements: NDStructure<T>) {
|
||||
fun checkShape(vararg elements: N) {
|
||||
elements.forEach {
|
||||
if (!shape.contentEquals(it.shape)) {
|
||||
throw ShapeMismatchException(shape, it.shape)
|
||||
@ -41,37 +31,8 @@ interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Element-by-element addition
|
||||
*/
|
||||
override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
|
||||
checkShape(a, b)
|
||||
return produce { field.run { a[it] + b[it] } }
|
||||
}
|
||||
fun produce(initializer: F.(IntArray) -> T): N
|
||||
|
||||
/**
|
||||
* Multiply all elements by cinstant
|
||||
*/
|
||||
override fun multiply(a: NDStructure<T>, k: Double): NDElement<T, F> {
|
||||
checkShape(a)
|
||||
return produce { field.run { a[it] * k } }
|
||||
}
|
||||
|
||||
/**
|
||||
* Element-by-element multiplication
|
||||
*/
|
||||
override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
|
||||
checkShape(a)
|
||||
return produce { field.run { a[it] * b[it] } }
|
||||
}
|
||||
|
||||
/**
|
||||
* Element-by-element division
|
||||
*/
|
||||
override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
|
||||
checkShape(a)
|
||||
return produce { field.run { a[it] / b[it] } }
|
||||
}
|
||||
|
||||
companion object {
|
||||
/**
|
||||
@ -82,7 +43,7 @@ interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
|
||||
/**
|
||||
* Create a nd-field with boxing generic buffer
|
||||
*/
|
||||
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F) = BufferNDField(shape, field, ::boxingBuffer)
|
||||
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F) = GenericNDField(shape, field)
|
||||
|
||||
/**
|
||||
* Create a most suitable implementation for nd-field using reified class
|
||||
@ -92,123 +53,43 @@ interface NDField<T, F : Field<T>> : Field<NDStructure<T>> {
|
||||
}
|
||||
|
||||
|
||||
interface NDElement<T, F : Field<T>> : FieldElement<NDStructure<T>, NDField<T, F>>, NDStructure<T> {
|
||||
companion object {
|
||||
/**
|
||||
* Create a platform-optimized NDArray of doubles
|
||||
*/
|
||||
fun real(shape: IntArray, initializer: DoubleField.(IntArray) -> Double = { 0.0 }): NDElement<Double, DoubleField> {
|
||||
return NDField.real(shape).produce(initializer)
|
||||
}
|
||||
class GenericNDField<T : Any, F : Field<T>>(override val shape: IntArray, override val field: F, val bufferFactory: BufferFactory<T> = ::boxingBuffer) : NDField<T, F, NDStructure<T>> {
|
||||
override fun produce(initializer: F.(IntArray) -> T): NDStructure<T> = ndStructure(shape, bufferFactory) { field.initializer(it) }
|
||||
|
||||
fun real1D(dim: Int, initializer: (Int) -> Double = { _ -> 0.0 }): NDElement<Double, DoubleField> {
|
||||
return real(intArrayOf(dim)) { initializer(it[0]) }
|
||||
}
|
||||
override val zero: NDStructure<T> by lazy { produce { zero } }
|
||||
|
||||
fun real2D(dim1: Int, dim2: Int, initializer: (Int, Int) -> Double = { _, _ -> 0.0 }): NDElement<Double, DoubleField> {
|
||||
return real(intArrayOf(dim1, dim2)) { initializer(it[0], it[1]) }
|
||||
}
|
||||
override val one: NDStructure<T> by lazy { produce { one } }
|
||||
|
||||
fun real3D(dim1: Int, dim2: Int, dim3: Int, initializer: (Int, Int, Int) -> Double = { _, _, _ -> 0.0 }): NDElement<Double, DoubleField> {
|
||||
return real(intArrayOf(dim1, dim2, dim3)) { initializer(it[0], it[1], it[2]) }
|
||||
}
|
||||
|
||||
// inline fun real(shape: IntArray, block: ExtendedNDField<Double, DoubleField>.() -> NDStructure<Double>): NDElement<Double, DoubleField> {
|
||||
// val field = NDField.real(shape)
|
||||
// return GenericNDElement(field, field.run(block))
|
||||
// }
|
||||
|
||||
/**
|
||||
* Simple boxing NDArray
|
||||
* Element-by-element addition
|
||||
*/
|
||||
fun <T : Any, F : Field<T>> generic(shape: IntArray, field: F, initializer: F.(IntArray) -> T): NDElement<T, F> {
|
||||
return NDField.generic(shape, field).produce(initializer)
|
||||
override fun add(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> {
|
||||
checkShape(a, b)
|
||||
return produce { field.run { a[it] + b[it] } }
|
||||
}
|
||||
|
||||
inline fun <reified T : Any, F : Field<T>> inline(shape: IntArray, field: F, noinline initializer: F.(IntArray) -> T): NDElement<T, F> {
|
||||
return NDField.inline(shape, field).produce(initializer)
|
||||
/**
|
||||
* Multiply all elements by cinstant
|
||||
*/
|
||||
override fun multiply(a: NDStructure<T>, k: Double): NDStructure<T> {
|
||||
checkShape(a)
|
||||
return produce { field.run { a[it] * k } }
|
||||
}
|
||||
|
||||
/**
|
||||
* Element-by-element multiplication
|
||||
*/
|
||||
override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> {
|
||||
checkShape(a)
|
||||
return produce { field.run { a[it] * b[it] } }
|
||||
}
|
||||
|
||||
/**
|
||||
* Element-by-element division
|
||||
*/
|
||||
override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDStructure<T> {
|
||||
checkShape(a)
|
||||
return produce { field.run { a[it] / b[it] } }
|
||||
}
|
||||
}
|
||||
|
||||
inline fun <T, F : Field<T>> NDElement<T, F>.transformIndexed(crossinline action: F.(IntArray, T) -> T): NDElement<T, F> = context.produce { action(it, get(*it)) }
|
||||
inline fun <T, F : Field<T>> NDElement<T, F>.transform(crossinline action: F.(T) -> T): NDElement<T, F> = context.produce { action(get(*it)) }
|
||||
|
||||
|
||||
/**
|
||||
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||
*/
|
||||
operator fun <T, F : Field<T>> Function1<T, T>.invoke(ndElement: NDElement<T, F>): NDElement<T, F> = ndElement.transform { value -> this@invoke(value) }
|
||||
|
||||
/* plus and minus */
|
||||
|
||||
/**
|
||||
* Summation operation for [NDElement] and single element
|
||||
*/
|
||||
operator fun <T, F : Field<T>> NDElement<T, F>.plus(arg: T): NDElement<T, F> = transform { value ->
|
||||
context.field.run { arg + value }
|
||||
}
|
||||
|
||||
/**
|
||||
* Subtraction operation between [NDElement] and single element
|
||||
*/
|
||||
operator fun <T, F : Field<T>> NDElement<T, F>.minus(arg: T): NDElement<T, F> = transform { value ->
|
||||
context.field.run { arg - value }
|
||||
}
|
||||
|
||||
/* prod and div */
|
||||
|
||||
/**
|
||||
* Product operation for [NDElement] and single element
|
||||
*/
|
||||
operator fun <T, F : Field<T>> NDElement<T, F>.times(arg: T): NDElement<T, F> = transform { value ->
|
||||
context.field.run { arg * value }
|
||||
}
|
||||
|
||||
/**
|
||||
* Division operation between [NDElement] and single element
|
||||
*/
|
||||
operator fun <T, F : Field<T>> NDElement<T, F>.div(arg: T): NDElement<T, F> = transform { value ->
|
||||
context.field.run { arg / value }
|
||||
}
|
||||
|
||||
|
||||
// /**
|
||||
// * Reverse sum operation
|
||||
// */
|
||||
// operator fun T.plus(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
|
||||
// field.run { this@plus + arg[index] }
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * Reverse minus operation
|
||||
// */
|
||||
// operator fun T.minus(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
|
||||
// field.run { this@minus - arg[index] }
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * Reverse product operation
|
||||
// */
|
||||
// operator fun T.times(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
|
||||
// field.run { this@times * arg[index] }
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * Reverse division operation
|
||||
// */
|
||||
// operator fun T.div(arg: NDStructure<T>): NDElement<T, F> = produce { index ->
|
||||
// field.run { this@div / arg[index] }
|
||||
// }
|
||||
|
||||
class GenericNDField<T : Any, F : Field<T>>(override val shape: IntArray, override val field: F) : NDField<T, F> {
|
||||
override fun produce(initializer: F.(IntArray) -> T): NDElement<T, F> = GenericNDElement(this, produceStructure(initializer))
|
||||
private inline fun produceStructure(crossinline initializer: F.(IntArray) -> T): NDStructure<T> = NdStructure(shape, ::boxingBuffer) { field.initializer(it) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Read-only [NDStructure] coupled to the context.
|
||||
*/
|
||||
class GenericNDElement<T, F : Field<T>>(override val context: NDField<T, F>, private val structure: NDStructure<T>) : NDElement<T, F>, NDStructure<T> by structure {
|
||||
override val self: NDElement<T, F> get() = this
|
||||
}
|
||||
|
@ -112,26 +112,24 @@ class DefaultStrides private constructor(override val shape: IntArray) : Strides
|
||||
}
|
||||
}
|
||||
|
||||
abstract class GenericNDStructure<T, B : Buffer<T>> : NDStructure<T> {
|
||||
abstract val buffer: B
|
||||
abstract val strides: Strides
|
||||
interface NDBuffer<T> : NDStructure<T> {
|
||||
val buffer: Buffer<T>
|
||||
val strides: Strides
|
||||
|
||||
override fun get(index: IntArray): T = buffer[strides.offset(index)]
|
||||
|
||||
override val shape: IntArray
|
||||
get() = strides.shape
|
||||
override val shape: IntArray get() = strides.shape
|
||||
|
||||
override fun elements() =
|
||||
strides.indices().map { it to this[it] }
|
||||
override fun elements() = strides.indices().map { it to this[it] }
|
||||
}
|
||||
|
||||
/**
|
||||
* Boxing generic [NDStructure]
|
||||
*/
|
||||
class BufferNDStructure<T>(
|
||||
data class BufferNDStructure<T>(
|
||||
override val strides: Strides,
|
||||
override val buffer: Buffer<T>
|
||||
) : GenericNDStructure<T, Buffer<T>>() {
|
||||
) : NDBuffer<T> {
|
||||
|
||||
init {
|
||||
if (strides.linearSize != buffer.size) {
|
||||
@ -158,7 +156,7 @@ class BufferNDStructure<T>(
|
||||
/**
|
||||
* Transform structure to a new structure using provided [BufferFactory] and optimizing if argument is [BufferNDStructure]
|
||||
*/
|
||||
inline fun <T, reified R : Any> NDStructure<T>.map(factory: BufferFactory<R> = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure<R> {
|
||||
inline fun <T, reified R : Any> NDStructure<T>.mapToBuffer(factory: BufferFactory<R> = ::inlineBuffer, crossinline transform: (T) -> R): BufferNDStructure<R> {
|
||||
return if (this is BufferNDStructure<T>) {
|
||||
BufferNDStructure(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
|
||||
} else {
|
||||
@ -172,8 +170,7 @@ inline fun <T, reified R : Any> NDStructure<T>.map(factory: BufferFactory<R> = :
|
||||
*
|
||||
* Strides should be reused if possible
|
||||
*/
|
||||
@Suppress("FunctionName")
|
||||
fun <T : Any> NdStructure(strides: Strides, bufferFactory: BufferFactory<T>, initializer: (IntArray) -> T) =
|
||||
fun <T> ndStructure(strides: Strides, bufferFactory: BufferFactory<T> = ::boxingBuffer, initializer: (IntArray) -> T) =
|
||||
BufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
/**
|
||||
@ -182,9 +179,8 @@ fun <T : Any> NdStructure(strides: Strides, bufferFactory: BufferFactory<T>, ini
|
||||
inline fun <reified T : Any> inlineNDStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||
BufferNDStructure(strides, inlineBuffer(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
@Suppress("FunctionName")
|
||||
fun <T : Any> NdStructure(shape: IntArray, bufferFactory: BufferFactory<T>, initializer: (IntArray) -> T) =
|
||||
NdStructure(DefaultStrides(shape), bufferFactory, initializer)
|
||||
fun <T> ndStructure(shape: IntArray, bufferFactory: BufferFactory<T> = ::boxingBuffer, initializer: (IntArray) -> T) =
|
||||
ndStructure(DefaultStrides(shape), bufferFactory, initializer)
|
||||
|
||||
inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
||||
inlineNDStructure(DefaultStrides(shape), initializer)
|
||||
@ -195,7 +191,7 @@ inline fun <reified T : Any> inlineNdStructure(shape: IntArray, crossinline init
|
||||
class MutableBufferNDStructure<T>(
|
||||
override val strides: Strides,
|
||||
override val buffer: MutableBuffer<T>
|
||||
) : GenericNDStructure<T, MutableBuffer<T>>(), MutableNDStructure<T> {
|
||||
) : NDBuffer<T>, MutableNDStructure<T> {
|
||||
|
||||
init {
|
||||
if (strides.linearSize != buffer.size) {
|
||||
@ -209,16 +205,14 @@ class MutableBufferNDStructure<T>(
|
||||
/**
|
||||
* The same as [inlineNDStructure], but mutable
|
||||
*/
|
||||
@Suppress("FunctionName")
|
||||
fun <T : Any> MutableNdStructure(strides: Strides, bufferFactory: MutableBufferFactory<T>, initializer: (IntArray) -> T) =
|
||||
fun <T : Any> mutableNdStructure(strides: Strides, bufferFactory: MutableBufferFactory<T> = ::boxingMutableBuffer, initializer: (IntArray) -> T) =
|
||||
MutableBufferNDStructure(strides, bufferFactory(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
inline fun <reified T : Any> inlineMutableNdStructure(strides: Strides, crossinline initializer: (IntArray) -> T) =
|
||||
MutableBufferNDStructure(strides, inlineMutableBuffer(strides.linearSize) { i -> initializer(strides.index(i)) })
|
||||
|
||||
@Suppress("FunctionName")
|
||||
fun <T : Any> MutableNdStructure(shape: IntArray, bufferFactory: MutableBufferFactory<T>, initializer: (IntArray) -> T) =
|
||||
MutableNdStructure(DefaultStrides(shape), bufferFactory, initializer)
|
||||
fun <T : Any> mutableNdStructure(shape: IntArray, bufferFactory: MutableBufferFactory<T> = ::boxingMutableBuffer, initializer: (IntArray) -> T) =
|
||||
mutableNdStructure(DefaultStrides(shape), bufferFactory, initializer)
|
||||
|
||||
inline fun <reified T : Any> inlineMutableNdStructure(shape: IntArray, crossinline initializer: (IntArray) -> T) =
|
||||
inlineMutableNdStructure(DefaultStrides(shape), initializer)
|
||||
|
@ -4,20 +4,17 @@ import scientifik.kmath.operations.DoubleField
|
||||
|
||||
typealias RealNDElement = BufferNDElement<Double, DoubleField>
|
||||
|
||||
class RealNDField(shape: IntArray) : BufferNDField<Double, DoubleField>(shape, DoubleField, DoubleBufferFactory), ExtendedNDField<Double, DoubleField> {
|
||||
class RealNDField(shape: IntArray) :
|
||||
BufferNDField<Double, DoubleField>(shape, DoubleField, DoubleBufferFactory),
|
||||
ExtendedNDField<Double, DoubleField, NDBuffer<Double>> {
|
||||
|
||||
/**
|
||||
* Inline map an NDStructure to
|
||||
*/
|
||||
private inline fun NDStructure<Double>.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement {
|
||||
return if (this is BufferNDElement<Double, *>) {
|
||||
private inline fun NDBuffer<Double>.mapInline(crossinline operation: DoubleField.(Double) -> Double): RealNDElement {
|
||||
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.operation(buffer[offset]) }
|
||||
BufferNDElement(this@RealNDField, DoubleBuffer(array))
|
||||
} else {
|
||||
produce { index -> DoubleField.operation(get(index)) }
|
||||
return BufferNDElement(this@RealNDField, DoubleBuffer(array))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
override inline fun produce(initializer: DoubleField.(IntArray) -> Double): RealNDElement {
|
||||
@ -25,23 +22,23 @@ class RealNDField(shape: IntArray) : BufferNDField<Double, DoubleField>(shape, D
|
||||
return BufferNDElement(this, DoubleBuffer(array))
|
||||
}
|
||||
|
||||
override fun power(arg: NDStructure<Double>, pow: Double) = arg.mapInline { power(it, pow) }
|
||||
override fun power(arg: NDBuffer<Double>, pow: Double) = arg.mapInline { power(it, pow) }
|
||||
|
||||
override fun exp(arg: NDStructure<Double>) = arg.mapInline { exp(it) }
|
||||
override fun exp(arg: NDBuffer<Double>) = arg.mapInline { exp(it) }
|
||||
|
||||
override fun ln(arg: NDStructure<Double>) = arg.mapInline { ln(it) }
|
||||
override fun ln(arg: NDBuffer<Double>) = arg.mapInline { ln(it) }
|
||||
|
||||
override fun sin(arg: NDStructure<Double>) = arg.mapInline { sin(it) }
|
||||
override fun sin(arg: NDBuffer<Double>) = arg.mapInline { sin(it) }
|
||||
|
||||
override fun cos(arg: NDStructure<Double>) = arg.mapInline { cos(it) }
|
||||
override fun cos(arg: NDBuffer<Double>) = arg.mapInline { cos(it) }
|
||||
|
||||
override fun NDStructure<Double>.times(k: Number) = mapInline { value -> value * k.toDouble() }
|
||||
override fun NDBuffer<Double>.times(k: Number) = mapInline { value -> value * k.toDouble() }
|
||||
|
||||
override fun NDStructure<Double>.div(k: Number) = mapInline { value -> value / k.toDouble() }
|
||||
override fun NDBuffer<Double>.div(k: Number) = mapInline { value -> value / k.toDouble() }
|
||||
|
||||
override fun Number.times(b: NDStructure<Double>) = b * this
|
||||
override fun Number.times(b: NDBuffer<Double>) = b * this
|
||||
|
||||
override fun Number.div(b: NDStructure<Double>) = b * (1.0 / this.toDouble())
|
||||
override fun Number.div(b: NDBuffer<Double>) = b * (1.0 / this.toDouble())
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -7,7 +7,7 @@ import kotlin.test.assertEquals
|
||||
class NDFieldTest {
|
||||
@Test
|
||||
fun testStrides() {
|
||||
val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() }
|
||||
val ndArray = NDElements.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() }
|
||||
assertEquals(ndArray[5, 5], 10.0)
|
||||
}
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package scientifik.kmath.structures
|
||||
|
||||
import scientifik.kmath.operations.Norm
|
||||
import scientifik.kmath.structures.NDElement.Companion.real2D
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.pow
|
||||
import kotlin.test.Test
|
||||
|
@ -5,7 +5,7 @@ import scientifik.kmath.operations.Field
|
||||
|
||||
class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: CoroutineScope = GlobalScope) : BufferNDField<T, F>(shape,field, ::boxingBuffer) {
|
||||
|
||||
override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
|
||||
override fun add(a: NDStructure<T>, b: NDStructure<T>): NDElements<T, F> {
|
||||
return LazyNDStructure(this) { index ->
|
||||
val aDeferred = a.deferred(index)
|
||||
val bDeferred = b.deferred(index)
|
||||
@ -13,11 +13,11 @@ class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: Corouti
|
||||
}
|
||||
}
|
||||
|
||||
override fun multiply(a: NDStructure<T>, k: Double): NDElement<T, F> {
|
||||
override fun multiply(a: NDStructure<T>, k: Double): NDElements<T, F> {
|
||||
return LazyNDStructure(this) { index -> a.await(index) * k }
|
||||
}
|
||||
|
||||
override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
|
||||
override fun multiply(a: NDStructure<T>, b: NDStructure<T>): NDElements<T, F> {
|
||||
return LazyNDStructure(this) { index ->
|
||||
val aDeferred = a.deferred(index)
|
||||
val bDeferred = b.deferred(index)
|
||||
@ -25,7 +25,7 @@ class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: Corouti
|
||||
}
|
||||
}
|
||||
|
||||
override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDElement<T, F> {
|
||||
override fun divide(a: NDStructure<T>, b: NDStructure<T>): NDElements<T, F> {
|
||||
return LazyNDStructure(this) { index ->
|
||||
val aDeferred = a.deferred(index)
|
||||
val bDeferred = b.deferred(index)
|
||||
@ -34,8 +34,8 @@ class LazyNDField<T, F : Field<T>>(shape: IntArray, field: F, val scope: Corouti
|
||||
}
|
||||
}
|
||||
|
||||
class LazyNDStructure<T, F : Field<T>>(override val context: LazyNDField<T, F>, val function: suspend F.(IntArray) -> T) : NDElement<T, F>, NDStructure<T> {
|
||||
override val self: NDElement<T, F> get() = this
|
||||
class LazyNDStructure<T, F : Field<T>>(override val context: LazyNDField<T, F>, val function: suspend F.(IntArray) -> T) : NDElements<T, F>, NDStructure<T> {
|
||||
override val self: NDElements<T, F> get() = this
|
||||
override val shape: IntArray get() = context.shape
|
||||
|
||||
private val cache = HashMap<IntArray, Deferred<T>>()
|
||||
@ -58,7 +58,7 @@ fun <T> NDStructure<T>.deferred(index: IntArray) = if (this is LazyNDStructure<T
|
||||
|
||||
suspend fun <T> NDStructure<T>.await(index: IntArray) = if (this is LazyNDStructure<T, *>) this.await(index) else get(index)
|
||||
|
||||
fun <T, F : Field<T>> NDElement<T, F>.lazy(scope: CoroutineScope = GlobalScope): LazyNDStructure<T, F> {
|
||||
fun <T, F : Field<T>> NDElements<T, F>.lazy(scope: CoroutineScope = GlobalScope): LazyNDStructure<T, F> {
|
||||
return if (this is LazyNDStructure<T, F>) {
|
||||
this
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user