Update changelog, document kmath-nd4j, refactor iterators, correct algebra mistakes, separate INDArrayStructureRing to Space, Ring and Algebra

This commit is contained in:
Iaroslav 2020-08-15 18:35:16 +07:00
parent a6d47515eb
commit 7157878485
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
8 changed files with 448 additions and 114 deletions

View File

@ -18,6 +18,7 @@
- Blocking chains in `kmath-coroutines`
- Full hyperbolic functions support and default implementations within `ExtendedField`
- Norm support for `Complex`
- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`.
### Changed
- BigInteger and BigDecimal algebra: JBigDecimalField has companion object with default math context; minor optimizations

View File

@ -74,9 +74,9 @@ interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement
/**
* The element of [Ring].
*
* @param T the type of space operation results.
* @param T the type of ring operation results.
* @param I self type of the element. Needed for static type checking.
* @param R the type of space.
* @param R the type of ring.
*/
interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
/**
@ -91,7 +91,7 @@ interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T
/**
* The element of [Field].
*
* @param T the type of space operation results.
* @param T the type of field operation results.
* @param I self type of the element. Needed for static type checking.
* @param F the type of field.
*/

View File

@ -43,7 +43,6 @@ class BufferedNDFieldElement<T, F : Field<T>>(
override val context: BufferedNDField<T, F>,
override val buffer: Buffer<T>
) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> {
override fun unwrap(): NDBuffer<T> = this
override fun NDBuffer<T>.wrap(): BufferedNDFieldElement<T, F> {
@ -56,8 +55,9 @@ class BufferedNDFieldElement<T, F : Field<T>>(
/**
* Element by element application of any operation on elements to the whole array. Just like in numpy.
*/
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedNDElement<T, F>): MathElement<out BufferedNDAlgebra<T, F>> =
ndElement.context.run { map(ndElement) { invoke(it) }.toElement() }
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(
ndElement: BufferedNDElement<T, F>
): MathElement<out BufferedNDAlgebra<T, F>> = ndElement.context.run { map(ndElement) { invoke(it) }.toElement() }
/* plus and minus */

View File

@ -5,56 +5,78 @@ import scientifik.kmath.operations.Field
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space
/**
* An exception is thrown when the expected ans actual shape of NDArray differs
* An exception is thrown when the expected ans actual shape of NDArray differs.
*
* @property expected the expected shape.
* @property actual the actual shape.
*/
class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : RuntimeException()
class ShapeMismatchException(val expected: IntArray, val actual: IntArray) :
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
/**
* The base interface for all nd-algebra implementations
* @param T the type of nd-structure element
* @param C the type of the element context
* @param N the type of the structure
* The base interface for all ND-algebra implementations.
*
* @param T the type of ND-structure element.
* @param C the type of the element context.
* @param N the type of the structure.
*/
interface NDAlgebra<T, C, N : NDStructure<T>> {
/**
* The shape of ND-structures this algebra operates on.
*/
val shape: IntArray
/**
* The algebra over elements of ND structure.
*/
val elementContext: C
/**
* Produce a new [N] structure using given initializer function
* Produces a new [N] structure using given initializer function.
*/
fun produce(initializer: C.(IntArray) -> T): N
/**
* Map elements from one structure to another one
* Maps elements from one structure to another one by applying [transform] to them.
*/
fun map(arg: N, transform: C.(T) -> T): N
/**
* Map indexed elements
* Maps elements from one structure to another one by applying [transform] to them alongside with their indices.
*/
fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N
/**
* Combine two structures into one
* Combines two structures into one.
*/
fun combine(a: N, b: N, transform: C.(T, T) -> T): N
/**
* Check if given elements are consistent with this context
* Checks if given element is consistent with this context.
*
* @param element the structure to check.
* @return the valid structure.
*/
fun check(vararg elements: N) {
elements.forEach {
if (!shape.contentEquals(it.shape)) {
throw ShapeMismatchException(shape, it.shape)
}
}
fun check(element: N): N {
if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape)
return element
}
/**
* element-by-element invoke a function working on [T] on a [NDStructure]
* Checks if given elements are consistent with this context.
*
* @param elements the structures to check.
* @return the array of valid structures.
*/
fun check(vararg elements: N): Array<out N> = elements
.map(NDStructure<T>::shape)
.singleOrNull { !shape.contentEquals(it) }
?.let { throw ShapeMismatchException(shape, it) }
?: elements
/**
* Element-wise invocation of function working on [T] on a [NDStructure].
*/
operator fun Function1<T, T>.invoke(structure: N): N = map(structure) { value -> this@invoke(value) }
@ -62,43 +84,107 @@ interface NDAlgebra<T, C, N : NDStructure<T>> {
}
/**
* An nd-space over element space
* Space of [NDStructure].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param S the type of space of structure elements.
*/
interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T, S, N> {
/**
* Element-by-element addition
* Element-wise addition.
*
* @param a the addend.
* @param b the augend.
* @return the sum.
*/
override fun add(a: N, b: N): N = combine(a, b) { aValue, bValue -> add(aValue, bValue) }
/**
* Multiply all elements by constant
* Element-wise multiplication by scalar.
*
* @param a the multiplicand.
* @param k the multiplier.
* @return the product.
*/
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
//TODO move to extensions after KEEP-176
// TODO move to extensions after KEEP-176
/**
* Adds an ND structure to an element of it.
*
* @receiver the addend.
* @param arg the augend.
* @return the sum.
*/
operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) }
/**
* Subtracts an element from ND structure of it.
*
* @receiver the dividend.
* @param arg the divisor.
* @return the quotient.
*/
operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) }
/**
* Adds an element to ND structure of it.
*
* @receiver the addend.
* @param arg the augend.
* @return the sum.
*/
operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) }
/**
* Subtracts an ND structure from an element of it.
*
* @receiver the dividend.
* @param arg the divisor.
* @return the quotient.
*/
operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) }
companion object
}
/**
* An nd-ring over element ring
* Ring of [NDStructure].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param R the type of ring of structure elements.
*/
interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N> {
/**
* Element-by-element multiplication
* Element-wise multiplication.
*
* @param a the multiplicand.
* @param b the multiplier.
* @return the product.
*/
override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
//TODO move to extensions after KEEP-176
/**
* Multiplies an ND structure by an element of it.
*
* @receiver the multiplicand.
* @param arg the multiplier.
* @return the product.
*/
operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) }
/**
* Multiplies an element by a ND structure of it.
*
* @receiver the multiplicand.
* @param arg the multiplier.
* @return the product.
*/
operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) }
companion object
@ -109,31 +195,47 @@ interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N>
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param F field of structure elements.
* @param F the type field of structure elements.
*/
interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F, N> {
/**
* Element-by-element division
* Element-wise division.
*
* @param a the dividend.
* @param b the divisor.
* @return the quotient.
*/
override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
//TODO move to extensions after KEEP-176
/**
* Divides an ND structure by an element of it.
*
* @receiver the dividend.
* @param arg the divisor.
* @return the quotient.
*/
operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) }
/**
* Divides an element by an ND structure of it.
*
* @receiver the dividend.
* @param arg the divisor.
* @return the quotient.
*/
operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) }
companion object {
private val realNDFieldCache = HashMap<IntArray, RealNDField>()
private val realNDFieldCache: MutableMap<IntArray, RealNDField> = hashMapOf()
/**
* Create a nd-field for [Double] values or pull it from cache if it was created previously
* Create a nd-field for [Double] values or pull it from cache if it was created previously.
*/
fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
/**
* Create a nd-field with boxing generic buffer
* Create a ND field with boxing generic buffer.
*/
fun <T : Any, F : Field<T>> boxing(
field: F,

View File

@ -3,8 +3,8 @@
This subproject implements the following features:
- NDStructure wrapper for INDArray.
- Optimized NDRing implementation for INDArray storing Ints.
- Optimized NDField implementation for INDArray storing Floats and Doubles.
- Optimized NDRing implementations for INDArray storing Ints and Longs.
- Optimized NDField implementations for INDArray storing Floats and Doubles.
> #### Artifact:
> This module is distributed in the artifact `scientifik:kmath-nd4j:0.1.4-dev-8`.
@ -44,7 +44,7 @@ import org.nd4j.linalg.factory.*
import scientifik.kmath.nd4j.*
import scientifik.kmath.structures.*
val array = Nd4j.ones(2, 2)!!.asRealStructure()
val array = Nd4j.ones(2, 2).asRealStructure()
println(array[0, 0]) // 1.0
array[intArrayOf(0, 0)] = 24.0
println(array[0, 0]) // 24.0
@ -58,7 +58,7 @@ import scientifik.kmath.nd4j.*
import scientifik.kmath.operations.*
val field = RealINDArrayField(intArrayOf(2, 2))
val array = Nd4j.rand(2, 2)!!.asRealStructure()
val array = Nd4j.rand(2, 2).asRealStructure()
val res = field {
(25.0 / array + 20) * 4

View File

@ -3,89 +3,271 @@ package scientifik.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import scientifik.kmath.operations.*
import scientifik.kmath.structures.MutableNDStructure
import scientifik.kmath.structures.NDField
import scientifik.kmath.structures.NDRing
interface INDArrayRing<T, R, N> :
NDRing<T, R, N> where R : Ring<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
override val zero: N
get() = Nd4j.zeros(*shape).wrap()
override val one: N
get() = Nd4j.ones(*shape).wrap()
import scientifik.kmath.structures.*
/**
* Represents [NDAlgebra] over [INDArrayAlgebra].
*
* @param T the type of ND-structure element.
* @param C the type of the element context.
* @param N the type of the structure.
*/
interface INDArrayAlgebra<T, C, N> : NDAlgebra<T, C, N> where N : INDArrayStructure<T>, N : MutableNDStructure<T> {
/**
* Wraps [INDArray] to [N].
*/
fun INDArray.wrap(): N
override fun produce(initializer: R.(IntArray) -> T): N {
override fun produce(initializer: C.(IntArray) -> T): N {
val struct = Nd4j.create(*shape)!!.wrap()
struct.elements().map(Pair<IntArray, T>::first).forEach { struct[it] = elementContext.initializer(it) }
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) }
return struct
}
override fun map(arg: N, transform: R.(T) -> T): N {
override fun map(arg: N, transform: C.(T) -> T): N {
check(arg)
val newStruct = arg.ndArray.dup().wrap()
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
return newStruct
}
override fun mapIndexed(arg: N, transform: R.(index: IntArray, T) -> T): N {
override fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N {
check(arg)
val new = Nd4j.create(*shape).wrap()
new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(idx, arg[idx]) }
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) }
return new
}
override fun combine(a: N, b: N, transform: R.(T, T) -> T): N {
override fun combine(a: N, b: N, transform: C.(T, T) -> T): N {
check(a, b)
val new = Nd4j.create(*shape).wrap()
new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(a[idx], b[idx]) }
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) }
return new
}
override fun add(a: N, b: N): N = a.ndArray.addi(b.ndArray).wrap()
override fun N.minus(b: N): N = ndArray.subi(b.ndArray).wrap()
override fun N.unaryMinus(): N = ndArray.negi().wrap()
override fun multiply(a: N, b: N): N = a.ndArray.muli(b.ndArray).wrap()
override fun multiply(a: N, k: Number): N = a.ndArray.muli(k).wrap()
override fun N.div(k: Number): N = ndArray.divi(k).wrap()
override fun N.minus(b: Number): N = ndArray.subi(b).wrap()
override fun N.plus(b: Number): N = ndArray.addi(b).wrap()
override fun N.times(k: Number): N = ndArray.muli(k).wrap()
override fun Number.minus(b: N): N = b.ndArray.rsubi(this).wrap()
}
interface INDArrayField<T, F, N> : NDField<T, F, N>,
INDArrayRing<T, F, N> where F : Field<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
override fun divide(a: N, b: N): N = a.ndArray.divi(b.ndArray).wrap()
override fun Number.div(b: N): N = b.ndArray.rdivi(this).wrap()
/**
* Represents [NDSpace] over [INDArrayStructure].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param S the type of space of structure elements.
*/
interface INDArraySpace<T, S, N> : NDSpace<T, S, N>, INDArrayAlgebra<T, S, N>
where S : Space<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
override val zero: N
get() = Nd4j.zeros(*shape).wrap()
override fun add(a: N, b: N): N {
check(a, b)
return a.ndArray.add(b.ndArray).wrap()
}
override operator fun N.minus(b: N): N {
check(this, b)
return ndArray.sub(b.ndArray).wrap()
}
override operator fun N.unaryMinus(): N {
check(this)
return ndArray.neg().wrap()
}
override fun multiply(a: N, k: Number): N {
check(a)
return a.ndArray.mul(k).wrap()
}
override operator fun N.div(k: Number): N {
check(this)
return ndArray.div(k).wrap()
}
override operator fun N.times(k: Number): N {
check(this)
return ndArray.mul(k).wrap()
}
}
/**
* Represents [NDRing] over [INDArrayStructure].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param R the type of ring of structure elements.
*/
interface INDArrayRing<T, R, N> : NDRing<T, R, N>, INDArraySpace<T, R, N>
where R : Ring<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
override val one: N
get() = Nd4j.ones(*shape).wrap()
override fun multiply(a: N, b: N): N {
check(a, b)
return a.ndArray.mul(b.ndArray).wrap()
}
override operator fun N.minus(b: Number): N {
check(this)
return ndArray.sub(b).wrap()
}
override operator fun N.plus(b: Number): N {
check(this)
return ndArray.add(b).wrap()
}
override operator fun Number.minus(b: N): N {
check(b)
return b.ndArray.rsub(this).wrap()
}
}
/**
* Represents [NDField] over [INDArrayStructure].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param F the type field of structure elements.
*/
interface INDArrayField<T, F, N> : NDField<T, F, N>, INDArrayRing<T, F, N>
where F : Field<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
override fun divide(a: N, b: N): N {
check(a, b)
return a.ndArray.div(b.ndArray).wrap()
}
override operator fun Number.div(b: N): N {
check(b)
return b.ndArray.rdiv(this).wrap()
}
}
/**
* Represents [NDField] over [INDArrayRealStructure].
*/
class RealINDArrayField(override val shape: IntArray, override val elementContext: Field<Double> = RealField) :
INDArrayField<Double, Field<Double>, INDArrayRealStructure> {
override fun INDArray.wrap(): INDArrayRealStructure = asRealStructure()
override fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure = ndArray.divi(arg).wrap()
override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap()
override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap()
override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap()
override fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rdivi(this).wrap()
override fun Double.minus(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rsubi(this).wrap()
override fun INDArray.wrap(): INDArrayRealStructure = check(asRealStructure())
override operator fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure {
check(this)
return ndArray.div(arg).wrap()
}
override operator fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure {
check(this)
return ndArray.add(arg).wrap()
}
override operator fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure {
check(this)
return ndArray.sub(arg).wrap()
}
override operator fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure {
check(this)
return ndArray.mul(arg).wrap()
}
override operator fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure {
check(arg)
return arg.ndArray.rdiv(this).wrap()
}
override operator fun Double.minus(arg: INDArrayRealStructure): INDArrayRealStructure {
check(arg)
return arg.ndArray.rsub(this).wrap()
}
}
/**
* Represents [NDField] over [INDArrayFloatStructure].
*/
class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field<Float> = FloatField) :
INDArrayField<Float, Field<Float>, INDArrayFloatStructure> {
override fun INDArray.wrap(): INDArrayFloatStructure = asFloatStructure()
override fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure = ndArray.divi(arg).wrap()
override fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure = ndArray.addi(arg).wrap()
override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap()
override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap()
override fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rdivi(this).wrap()
override fun Float.minus(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rsubi(this).wrap()
override fun INDArray.wrap(): INDArrayFloatStructure = check(asFloatStructure())
override operator fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure {
check(this)
return ndArray.div(arg).wrap()
}
override operator fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure {
check(this)
return ndArray.add(arg).wrap()
}
override operator fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure {
check(this)
return ndArray.sub(arg).wrap()
}
override operator fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure {
check(this)
return ndArray.mul(arg).wrap()
}
override operator fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure {
check(arg)
return arg.ndArray.rdiv(this).wrap()
}
override operator fun Float.minus(arg: INDArrayFloatStructure): INDArrayFloatStructure {
check(arg)
return arg.ndArray.rsub(this).wrap()
}
}
/**
* Represents [NDRing] over [INDArrayIntStructure].
*/
class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Int> = IntRing) :
INDArrayRing<Int, Ring<Int>, INDArrayIntStructure> {
override fun INDArray.wrap(): INDArrayIntStructure = asIntStructure()
override fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure = ndArray.addi(arg).wrap()
override fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure = ndArray.subi(arg).wrap()
override fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure = ndArray.muli(arg).wrap()
override fun Int.minus(arg: INDArrayIntStructure): INDArrayIntStructure = arg.ndArray.rsubi(this).wrap()
override fun INDArray.wrap(): INDArrayIntStructure = check(asIntStructure())
override operator fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure {
check(this)
return ndArray.add(arg).wrap()
}
override operator fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure {
check(this)
return ndArray.sub(arg).wrap()
}
override operator fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure {
check(this)
return ndArray.mul(arg).wrap()
}
override operator fun Int.minus(arg: INDArrayIntStructure): INDArrayIntStructure {
check(arg)
return arg.ndArray.rsub(this).wrap()
}
}
/**
* Represents [NDRing] over [INDArrayLongStructure].
*/
class LongINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Long> = LongRing) :
INDArrayRing<Long, Ring<Long>, INDArrayLongStructure> {
override fun INDArray.wrap(): INDArrayLongStructure = check(asLongStructure())
override operator fun INDArrayLongStructure.plus(arg: Long): INDArrayLongStructure {
check(this)
return ndArray.add(arg).wrap()
}
override operator fun INDArrayLongStructure.minus(arg: Long): INDArrayLongStructure {
check(this)
return ndArray.sub(arg).wrap()
}
override operator fun INDArrayLongStructure.times(arg: Long): INDArrayLongStructure {
check(this)
return ndArray.mul(arg).wrap()
}
override operator fun Long.minus(arg: INDArrayLongStructure): INDArrayLongStructure {
check(arg)
return arg.ndArray.rsub(this).wrap()
}
}

View File

@ -3,7 +3,24 @@ package scientifik.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.shape.Shape
internal sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
private class INDArrayIndicesIterator(private val iterateOver: INDArray) : Iterator<IntArray> {
private var i: Int = 0
override fun hasNext(): Boolean = i < iterateOver.length()
override fun next(): IntArray {
val la = if (iterateOver.ordering() == 'c')
Shape.ind2subC(iterateOver, i++.toLong())!!
else
Shape.ind2sub(iterateOver, i++.toLong())!!
return la.toIntArray()
}
}
internal fun INDArray.indicesIterator(): Iterator<IntArray> = INDArrayIndicesIterator(this)
private sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
private var i: Int = 0
final override fun hasNext(): Boolean = i < iterateOver.length()
@ -20,26 +37,26 @@ internal sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArra
}
}
internal class INDArrayRealIterator(iterateOver: INDArray) : INDArrayIteratorBase<Double>(iterateOver) {
private class INDArrayRealIterator(iterateOver: INDArray) : INDArrayIteratorBase<Double>(iterateOver) {
override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices)
}
internal fun INDArray.realIterator(): INDArrayRealIterator = INDArrayRealIterator(this)
internal fun INDArray.realIterator(): Iterator<Pair<IntArray, Double>> = INDArrayRealIterator(this)
internal class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase<Long>(iterateOver) {
private class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase<Long>(iterateOver) {
override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices)
}
internal fun INDArray.longIterator(): INDArrayLongIterator = INDArrayLongIterator(this)
internal fun INDArray.longIterator(): Iterator<Pair<IntArray, Long>> = INDArrayLongIterator(this)
internal class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase<Int>(iterateOver) {
private class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase<Int>(iterateOver) {
override fun getSingle(indices: LongArray) = iterateOver.getInt(*indices.toIntArray())
}
internal fun INDArray.intIterator(): INDArrayIntIterator = INDArrayIntIterator(this)
internal fun INDArray.intIterator(): Iterator<Pair<IntArray, Int>> = INDArrayIntIterator(this)
internal class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase<Float>(iterateOver) {
private class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase<Float>(iterateOver) {
override fun getSingle(indices: LongArray) = iterateOver.getFloat(*indices)
}
internal fun INDArray.floatIterator() = INDArrayFloatIterator(this)
internal fun INDArray.floatIterator(): Iterator<Pair<IntArray, Float>> = INDArrayFloatIterator(this)

View File

@ -4,45 +4,77 @@ import org.nd4j.linalg.api.ndarray.INDArray
import scientifik.kmath.structures.MutableNDStructure
import scientifik.kmath.structures.NDStructure
interface INDArrayStructure<T> : NDStructure<T> {
val ndArray: INDArray
/**
* Represents a [NDStructure] wrapping an [INDArray] object.
*
* @param T the type of items.
*/
sealed class INDArrayStructure<T> : MutableNDStructure<T> {
/**
* The wrapped [INDArray].
*/
abstract val ndArray: INDArray
override val shape: IntArray
get() = ndArray.shape().toIntArray()
fun elementsIterator(): Iterator<Pair<IntArray, T>>
internal abstract fun elementsIterator(): Iterator<Pair<IntArray, T>>
internal fun indicesIterator(): Iterator<IntArray> = ndArray.indicesIterator()
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
}
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>, MutableNDStructure<Int> {
/**
* Represents a [NDStructure] over [INDArray] elements of which are accessed as ints.
*/
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = ndArray.intIterator()
override fun get(index: IntArray): Int = ndArray.getInt(*index)
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
}
/**
* Wraps this [INDArray] to [INDArrayIntStructure].
*/
fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this)
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
/**
* Represents a [NDStructure] over [INDArray] elements of which are accessed as longs.
*/
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = ndArray.longIterator()
override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray())
override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) }
}
/**
* Wraps this [INDArray] to [INDArrayLongStructure].
*/
fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this)
data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure<Double>,
MutableNDStructure<Double> {
/**
* Represents a [NDStructure] over [INDArray] elements of which are accessed as reals.
*/
data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure<Double>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = ndArray.realIterator()
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
}
/**
* Wraps this [INDArray] to [INDArrayRealStructure].
*/
fun INDArray.asRealStructure(): INDArrayRealStructure = INDArrayRealStructure(this)
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>,
MutableNDStructure<Float> {
/**
* Represents a [NDStructure] over [INDArray] elements of which are accessed as floats.
*/
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = ndArray.floatIterator()
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
}
/**
* Wraps this [INDArray] to [INDArrayFloatStructure].
*/
fun INDArray.asFloatStructure(): INDArrayFloatStructure = INDArrayFloatStructure(this)