forked from kscience/kmath
Update changelog, document kmath-nd4j, refactor iterators, correct algebra mistakes, separate INDArrayStructureRing to Space, Ring and Algebra
This commit is contained in:
parent
a6d47515eb
commit
7157878485
@ -18,6 +18,7 @@
|
||||
- Blocking chains in `kmath-coroutines`
|
||||
- Full hyperbolic functions support and default implementations within `ExtendedField`
|
||||
- Norm support for `Complex`
|
||||
- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`.
|
||||
|
||||
### Changed
|
||||
- BigInteger and BigDecimal algebra: JBigDecimalField has companion object with default math context; minor optimizations
|
||||
|
@ -74,9 +74,9 @@ interface SpaceElement<T, I : SpaceElement<T, I, S>, S : Space<T>> : MathElement
|
||||
/**
|
||||
* The element of [Ring].
|
||||
*
|
||||
* @param T the type of space operation results.
|
||||
* @param T the type of ring operation results.
|
||||
* @param I self type of the element. Needed for static type checking.
|
||||
* @param R the type of space.
|
||||
* @param R the type of ring.
|
||||
*/
|
||||
interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T, I, R> {
|
||||
/**
|
||||
@ -91,7 +91,7 @@ interface RingElement<T, I : RingElement<T, I, R>, R : Ring<T>> : SpaceElement<T
|
||||
/**
|
||||
* The element of [Field].
|
||||
*
|
||||
* @param T the type of space operation results.
|
||||
* @param T the type of field operation results.
|
||||
* @param I self type of the element. Needed for static type checking.
|
||||
* @param F the type of field.
|
||||
*/
|
||||
|
@ -43,7 +43,6 @@ class BufferedNDFieldElement<T, F : Field<T>>(
|
||||
override val context: BufferedNDField<T, F>,
|
||||
override val buffer: Buffer<T>
|
||||
) : BufferedNDElement<T, F>(), FieldElement<NDBuffer<T>, BufferedNDFieldElement<T, F>, BufferedNDField<T, F>> {
|
||||
|
||||
override fun unwrap(): NDBuffer<T> = this
|
||||
|
||||
override fun NDBuffer<T>.wrap(): BufferedNDFieldElement<T, F> {
|
||||
@ -56,8 +55,9 @@ class BufferedNDFieldElement<T, F : Field<T>>(
|
||||
/**
|
||||
* Element by element application of any operation on elements to the whole array. Just like in numpy.
|
||||
*/
|
||||
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(ndElement: BufferedNDElement<T, F>): MathElement<out BufferedNDAlgebra<T, F>> =
|
||||
ndElement.context.run { map(ndElement) { invoke(it) }.toElement() }
|
||||
operator fun <T : Any, F : Field<T>> Function1<T, T>.invoke(
|
||||
ndElement: BufferedNDElement<T, F>
|
||||
): MathElement<out BufferedNDAlgebra<T, F>> = ndElement.context.run { map(ndElement) { invoke(it) }.toElement() }
|
||||
|
||||
/* plus and minus */
|
||||
|
||||
|
@ -5,56 +5,78 @@ import scientifik.kmath.operations.Field
|
||||
import scientifik.kmath.operations.Ring
|
||||
import scientifik.kmath.operations.Space
|
||||
|
||||
|
||||
/**
|
||||
* An exception is thrown when the expected ans actual shape of NDArray differs
|
||||
* An exception is thrown when the expected ans actual shape of NDArray differs.
|
||||
*
|
||||
* @property expected the expected shape.
|
||||
* @property actual the actual shape.
|
||||
*/
|
||||
class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : RuntimeException()
|
||||
|
||||
class ShapeMismatchException(val expected: IntArray, val actual: IntArray) :
|
||||
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
|
||||
|
||||
/**
|
||||
* The base interface for all nd-algebra implementations
|
||||
* @param T the type of nd-structure element
|
||||
* @param C the type of the element context
|
||||
* @param N the type of the structure
|
||||
* The base interface for all ND-algebra implementations.
|
||||
*
|
||||
* @param T the type of ND-structure element.
|
||||
* @param C the type of the element context.
|
||||
* @param N the type of the structure.
|
||||
*/
|
||||
interface NDAlgebra<T, C, N : NDStructure<T>> {
|
||||
/**
|
||||
* The shape of ND-structures this algebra operates on.
|
||||
*/
|
||||
val shape: IntArray
|
||||
|
||||
/**
|
||||
* The algebra over elements of ND structure.
|
||||
*/
|
||||
val elementContext: C
|
||||
|
||||
/**
|
||||
* Produce a new [N] structure using given initializer function
|
||||
* Produces a new [N] structure using given initializer function.
|
||||
*/
|
||||
fun produce(initializer: C.(IntArray) -> T): N
|
||||
|
||||
/**
|
||||
* Map elements from one structure to another one
|
||||
* Maps elements from one structure to another one by applying [transform] to them.
|
||||
*/
|
||||
fun map(arg: N, transform: C.(T) -> T): N
|
||||
|
||||
/**
|
||||
* Map indexed elements
|
||||
* Maps elements from one structure to another one by applying [transform] to them alongside with their indices.
|
||||
*/
|
||||
fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N
|
||||
|
||||
/**
|
||||
* Combine two structures into one
|
||||
* Combines two structures into one.
|
||||
*/
|
||||
fun combine(a: N, b: N, transform: C.(T, T) -> T): N
|
||||
|
||||
/**
|
||||
* Check if given elements are consistent with this context
|
||||
* Checks if given element is consistent with this context.
|
||||
*
|
||||
* @param element the structure to check.
|
||||
* @return the valid structure.
|
||||
*/
|
||||
fun check(vararg elements: N) {
|
||||
elements.forEach {
|
||||
if (!shape.contentEquals(it.shape)) {
|
||||
throw ShapeMismatchException(shape, it.shape)
|
||||
}
|
||||
}
|
||||
fun check(element: N): N {
|
||||
if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape)
|
||||
return element
|
||||
}
|
||||
|
||||
/**
|
||||
* element-by-element invoke a function working on [T] on a [NDStructure]
|
||||
* Checks if given elements are consistent with this context.
|
||||
*
|
||||
* @param elements the structures to check.
|
||||
* @return the array of valid structures.
|
||||
*/
|
||||
fun check(vararg elements: N): Array<out N> = elements
|
||||
.map(NDStructure<T>::shape)
|
||||
.singleOrNull { !shape.contentEquals(it) }
|
||||
?.let { throw ShapeMismatchException(shape, it) }
|
||||
?: elements
|
||||
|
||||
/**
|
||||
* Element-wise invocation of function working on [T] on a [NDStructure].
|
||||
*/
|
||||
operator fun Function1<T, T>.invoke(structure: N): N = map(structure) { value -> this@invoke(value) }
|
||||
|
||||
@ -62,43 +84,107 @@ interface NDAlgebra<T, C, N : NDStructure<T>> {
|
||||
}
|
||||
|
||||
/**
|
||||
* An nd-space over element space
|
||||
* Space of [NDStructure].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param N the type of ND structure.
|
||||
* @param S the type of space of structure elements.
|
||||
*/
|
||||
interface NDSpace<T, S : Space<T>, N : NDStructure<T>> : Space<N>, NDAlgebra<T, S, N> {
|
||||
/**
|
||||
* Element-by-element addition
|
||||
* Element-wise addition.
|
||||
*
|
||||
* @param a the addend.
|
||||
* @param b the augend.
|
||||
* @return the sum.
|
||||
*/
|
||||
override fun add(a: N, b: N): N = combine(a, b) { aValue, bValue -> add(aValue, bValue) }
|
||||
|
||||
/**
|
||||
* Multiply all elements by constant
|
||||
* Element-wise multiplication by scalar.
|
||||
*
|
||||
* @param a the multiplicand.
|
||||
* @param k the multiplier.
|
||||
* @return the product.
|
||||
*/
|
||||
override fun multiply(a: N, k: Number): N = map(a) { multiply(it, k) }
|
||||
|
||||
// TODO move to extensions after KEEP-176
|
||||
|
||||
/**
|
||||
* Adds an ND structure to an element of it.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param arg the augend.
|
||||
* @return the sum.
|
||||
*/
|
||||
operator fun N.plus(arg: T): N = map(this) { value -> add(arg, value) }
|
||||
|
||||
/**
|
||||
* Subtracts an element from ND structure of it.
|
||||
*
|
||||
* @receiver the dividend.
|
||||
* @param arg the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
operator fun N.minus(arg: T): N = map(this) { value -> add(arg, -value) }
|
||||
|
||||
/**
|
||||
* Adds an element to ND structure of it.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param arg the augend.
|
||||
* @return the sum.
|
||||
*/
|
||||
operator fun T.plus(arg: N): N = map(arg) { value -> add(this@plus, value) }
|
||||
|
||||
/**
|
||||
* Subtracts an ND structure from an element of it.
|
||||
*
|
||||
* @receiver the dividend.
|
||||
* @param arg the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
operator fun T.minus(arg: N): N = map(arg) { value -> add(-this@minus, value) }
|
||||
|
||||
companion object
|
||||
}
|
||||
|
||||
/**
|
||||
* An nd-ring over element ring
|
||||
* Ring of [NDStructure].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param N the type of ND structure.
|
||||
* @param R the type of ring of structure elements.
|
||||
*/
|
||||
interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N> {
|
||||
|
||||
/**
|
||||
* Element-by-element multiplication
|
||||
* Element-wise multiplication.
|
||||
*
|
||||
* @param a the multiplicand.
|
||||
* @param b the multiplier.
|
||||
* @return the product.
|
||||
*/
|
||||
override fun multiply(a: N, b: N): N = combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
||||
|
||||
//TODO move to extensions after KEEP-176
|
||||
|
||||
/**
|
||||
* Multiplies an ND structure by an element of it.
|
||||
*
|
||||
* @receiver the multiplicand.
|
||||
* @param arg the multiplier.
|
||||
* @return the product.
|
||||
*/
|
||||
operator fun N.times(arg: T): N = map(this) { value -> multiply(arg, value) }
|
||||
|
||||
/**
|
||||
* Multiplies an element by a ND structure of it.
|
||||
*
|
||||
* @receiver the multiplicand.
|
||||
* @param arg the multiplier.
|
||||
* @return the product.
|
||||
*/
|
||||
operator fun T.times(arg: N): N = map(arg) { value -> multiply(this@times, value) }
|
||||
|
||||
companion object
|
||||
@ -109,31 +195,47 @@ interface NDRing<T, R : Ring<T>, N : NDStructure<T>> : Ring<N>, NDSpace<T, R, N>
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param N the type of ND structure.
|
||||
* @param F field of structure elements.
|
||||
* @param F the type field of structure elements.
|
||||
*/
|
||||
interface NDField<T, F : Field<T>, N : NDStructure<T>> : Field<N>, NDRing<T, F, N> {
|
||||
|
||||
/**
|
||||
* Element-by-element division
|
||||
* Element-wise division.
|
||||
*
|
||||
* @param a the dividend.
|
||||
* @param b the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
override fun divide(a: N, b: N): N = combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
||||
|
||||
//TODO move to extensions after KEEP-176
|
||||
/**
|
||||
* Divides an ND structure by an element of it.
|
||||
*
|
||||
* @receiver the dividend.
|
||||
* @param arg the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
operator fun N.div(arg: T): N = map(this) { value -> divide(arg, value) }
|
||||
|
||||
/**
|
||||
* Divides an element by an ND structure of it.
|
||||
*
|
||||
* @receiver the dividend.
|
||||
* @param arg the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
operator fun T.div(arg: N): N = map(arg) { divide(it, this@div) }
|
||||
|
||||
companion object {
|
||||
|
||||
private val realNDFieldCache = HashMap<IntArray, RealNDField>()
|
||||
private val realNDFieldCache: MutableMap<IntArray, RealNDField> = hashMapOf()
|
||||
|
||||
/**
|
||||
* Create a nd-field for [Double] values or pull it from cache if it was created previously
|
||||
* Create a nd-field for [Double] values or pull it from cache if it was created previously.
|
||||
*/
|
||||
fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
|
||||
|
||||
/**
|
||||
* Create a nd-field with boxing generic buffer
|
||||
* Create a ND field with boxing generic buffer.
|
||||
*/
|
||||
fun <T : Any, F : Field<T>> boxing(
|
||||
field: F,
|
||||
|
@ -3,8 +3,8 @@
|
||||
This subproject implements the following features:
|
||||
|
||||
- NDStructure wrapper for INDArray.
|
||||
- Optimized NDRing implementation for INDArray storing Ints.
|
||||
- Optimized NDField implementation for INDArray storing Floats and Doubles.
|
||||
- Optimized NDRing implementations for INDArray storing Ints and Longs.
|
||||
- Optimized NDField implementations for INDArray storing Floats and Doubles.
|
||||
|
||||
> #### Artifact:
|
||||
> This module is distributed in the artifact `scientifik:kmath-nd4j:0.1.4-dev-8`.
|
||||
@ -44,7 +44,7 @@ import org.nd4j.linalg.factory.*
|
||||
import scientifik.kmath.nd4j.*
|
||||
import scientifik.kmath.structures.*
|
||||
|
||||
val array = Nd4j.ones(2, 2)!!.asRealStructure()
|
||||
val array = Nd4j.ones(2, 2).asRealStructure()
|
||||
println(array[0, 0]) // 1.0
|
||||
array[intArrayOf(0, 0)] = 24.0
|
||||
println(array[0, 0]) // 24.0
|
||||
@ -58,7 +58,7 @@ import scientifik.kmath.nd4j.*
|
||||
import scientifik.kmath.operations.*
|
||||
|
||||
val field = RealINDArrayField(intArrayOf(2, 2))
|
||||
val array = Nd4j.rand(2, 2)!!.asRealStructure()
|
||||
val array = Nd4j.rand(2, 2).asRealStructure()
|
||||
|
||||
val res = field {
|
||||
(25.0 / array + 20) * 4
|
||||
|
@ -3,89 +3,271 @@ package scientifik.kmath.nd4j
|
||||
import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import scientifik.kmath.operations.*
|
||||
import scientifik.kmath.structures.MutableNDStructure
|
||||
import scientifik.kmath.structures.NDField
|
||||
import scientifik.kmath.structures.NDRing
|
||||
|
||||
interface INDArrayRing<T, R, N> :
|
||||
NDRing<T, R, N> where R : Ring<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
||||
override val zero: N
|
||||
get() = Nd4j.zeros(*shape).wrap()
|
||||
|
||||
override val one: N
|
||||
get() = Nd4j.ones(*shape).wrap()
|
||||
import scientifik.kmath.structures.*
|
||||
|
||||
/**
|
||||
* Represents [NDAlgebra] over [INDArrayAlgebra].
|
||||
*
|
||||
* @param T the type of ND-structure element.
|
||||
* @param C the type of the element context.
|
||||
* @param N the type of the structure.
|
||||
*/
|
||||
interface INDArrayAlgebra<T, C, N> : NDAlgebra<T, C, N> where N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
||||
/**
|
||||
* Wraps [INDArray] to [N].
|
||||
*/
|
||||
fun INDArray.wrap(): N
|
||||
|
||||
override fun produce(initializer: R.(IntArray) -> T): N {
|
||||
override fun produce(initializer: C.(IntArray) -> T): N {
|
||||
val struct = Nd4j.create(*shape)!!.wrap()
|
||||
struct.elements().map(Pair<IntArray, T>::first).forEach { struct[it] = elementContext.initializer(it) }
|
||||
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) }
|
||||
return struct
|
||||
}
|
||||
|
||||
override fun map(arg: N, transform: R.(T) -> T): N {
|
||||
override fun map(arg: N, transform: C.(T) -> T): N {
|
||||
check(arg)
|
||||
val newStruct = arg.ndArray.dup().wrap()
|
||||
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
|
||||
return newStruct
|
||||
}
|
||||
|
||||
override fun mapIndexed(arg: N, transform: R.(index: IntArray, T) -> T): N {
|
||||
override fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N {
|
||||
check(arg)
|
||||
val new = Nd4j.create(*shape).wrap()
|
||||
new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(idx, arg[idx]) }
|
||||
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) }
|
||||
return new
|
||||
}
|
||||
|
||||
override fun combine(a: N, b: N, transform: R.(T, T) -> T): N {
|
||||
override fun combine(a: N, b: N, transform: C.(T, T) -> T): N {
|
||||
check(a, b)
|
||||
val new = Nd4j.create(*shape).wrap()
|
||||
new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(a[idx], b[idx]) }
|
||||
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) }
|
||||
return new
|
||||
}
|
||||
|
||||
override fun add(a: N, b: N): N = a.ndArray.addi(b.ndArray).wrap()
|
||||
override fun N.minus(b: N): N = ndArray.subi(b.ndArray).wrap()
|
||||
override fun N.unaryMinus(): N = ndArray.negi().wrap()
|
||||
override fun multiply(a: N, b: N): N = a.ndArray.muli(b.ndArray).wrap()
|
||||
override fun multiply(a: N, k: Number): N = a.ndArray.muli(k).wrap()
|
||||
override fun N.div(k: Number): N = ndArray.divi(k).wrap()
|
||||
override fun N.minus(b: Number): N = ndArray.subi(b).wrap()
|
||||
override fun N.plus(b: Number): N = ndArray.addi(b).wrap()
|
||||
override fun N.times(k: Number): N = ndArray.muli(k).wrap()
|
||||
override fun Number.minus(b: N): N = b.ndArray.rsubi(this).wrap()
|
||||
}
|
||||
|
||||
interface INDArrayField<T, F, N> : NDField<T, F, N>,
|
||||
INDArrayRing<T, F, N> where F : Field<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
||||
override fun divide(a: N, b: N): N = a.ndArray.divi(b.ndArray).wrap()
|
||||
override fun Number.div(b: N): N = b.ndArray.rdivi(this).wrap()
|
||||
/**
|
||||
* Represents [NDSpace] over [INDArrayStructure].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param N the type of ND structure.
|
||||
* @param S the type of space of structure elements.
|
||||
*/
|
||||
interface INDArraySpace<T, S, N> : NDSpace<T, S, N>, INDArrayAlgebra<T, S, N>
|
||||
where S : Space<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
||||
|
||||
override val zero: N
|
||||
get() = Nd4j.zeros(*shape).wrap()
|
||||
|
||||
override fun add(a: N, b: N): N {
|
||||
check(a, b)
|
||||
return a.ndArray.add(b.ndArray).wrap()
|
||||
}
|
||||
|
||||
override operator fun N.minus(b: N): N {
|
||||
check(this, b)
|
||||
return ndArray.sub(b.ndArray).wrap()
|
||||
}
|
||||
|
||||
override operator fun N.unaryMinus(): N {
|
||||
check(this)
|
||||
return ndArray.neg().wrap()
|
||||
}
|
||||
|
||||
override fun multiply(a: N, k: Number): N {
|
||||
check(a)
|
||||
return a.ndArray.mul(k).wrap()
|
||||
}
|
||||
|
||||
override operator fun N.div(k: Number): N {
|
||||
check(this)
|
||||
return ndArray.div(k).wrap()
|
||||
}
|
||||
|
||||
override operator fun N.times(k: Number): N {
|
||||
check(this)
|
||||
return ndArray.mul(k).wrap()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDRing] over [INDArrayStructure].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param N the type of ND structure.
|
||||
* @param R the type of ring of structure elements.
|
||||
*/
|
||||
interface INDArrayRing<T, R, N> : NDRing<T, R, N>, INDArraySpace<T, R, N>
|
||||
where R : Ring<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
||||
|
||||
override val one: N
|
||||
get() = Nd4j.ones(*shape).wrap()
|
||||
|
||||
override fun multiply(a: N, b: N): N {
|
||||
check(a, b)
|
||||
return a.ndArray.mul(b.ndArray).wrap()
|
||||
}
|
||||
|
||||
override operator fun N.minus(b: Number): N {
|
||||
check(this)
|
||||
return ndArray.sub(b).wrap()
|
||||
}
|
||||
|
||||
override operator fun N.plus(b: Number): N {
|
||||
check(this)
|
||||
return ndArray.add(b).wrap()
|
||||
}
|
||||
|
||||
override operator fun Number.minus(b: N): N {
|
||||
check(b)
|
||||
return b.ndArray.rsub(this).wrap()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDField] over [INDArrayStructure].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param N the type of ND structure.
|
||||
* @param F the type field of structure elements.
|
||||
*/
|
||||
interface INDArrayField<T, F, N> : NDField<T, F, N>, INDArrayRing<T, F, N>
|
||||
where F : Field<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
||||
override fun divide(a: N, b: N): N {
|
||||
check(a, b)
|
||||
return a.ndArray.div(b.ndArray).wrap()
|
||||
}
|
||||
|
||||
override operator fun Number.div(b: N): N {
|
||||
check(b)
|
||||
return b.ndArray.rdiv(this).wrap()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDField] over [INDArrayRealStructure].
|
||||
*/
|
||||
class RealINDArrayField(override val shape: IntArray, override val elementContext: Field<Double> = RealField) :
|
||||
INDArrayField<Double, Field<Double>, INDArrayRealStructure> {
|
||||
override fun INDArray.wrap(): INDArrayRealStructure = asRealStructure()
|
||||
override fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure = ndArray.divi(arg).wrap()
|
||||
override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap()
|
||||
override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap()
|
||||
override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap()
|
||||
override fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rdivi(this).wrap()
|
||||
override fun Double.minus(arg: INDArrayRealStructure): INDArrayRealStructure = arg.ndArray.rsubi(this).wrap()
|
||||
override fun INDArray.wrap(): INDArrayRealStructure = check(asRealStructure())
|
||||
override operator fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure {
|
||||
check(this)
|
||||
return ndArray.div(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure {
|
||||
check(this)
|
||||
return ndArray.add(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure {
|
||||
check(this)
|
||||
return ndArray.sub(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure {
|
||||
check(this)
|
||||
return ndArray.mul(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure {
|
||||
check(arg)
|
||||
return arg.ndArray.rdiv(this).wrap()
|
||||
}
|
||||
|
||||
override operator fun Double.minus(arg: INDArrayRealStructure): INDArrayRealStructure {
|
||||
check(arg)
|
||||
return arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDField] over [INDArrayFloatStructure].
|
||||
*/
|
||||
class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field<Float> = FloatField) :
|
||||
INDArrayField<Float, Field<Float>, INDArrayFloatStructure> {
|
||||
override fun INDArray.wrap(): INDArrayFloatStructure = asFloatStructure()
|
||||
override fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure = ndArray.divi(arg).wrap()
|
||||
override fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure = ndArray.addi(arg).wrap()
|
||||
override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap()
|
||||
override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap()
|
||||
override fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rdivi(this).wrap()
|
||||
override fun Float.minus(arg: INDArrayFloatStructure): INDArrayFloatStructure = arg.ndArray.rsubi(this).wrap()
|
||||
override fun INDArray.wrap(): INDArrayFloatStructure = check(asFloatStructure())
|
||||
override operator fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure {
|
||||
check(this)
|
||||
return ndArray.div(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure {
|
||||
check(this)
|
||||
return ndArray.add(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure {
|
||||
check(this)
|
||||
return ndArray.sub(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure {
|
||||
check(this)
|
||||
return ndArray.mul(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure {
|
||||
check(arg)
|
||||
return arg.ndArray.rdiv(this).wrap()
|
||||
}
|
||||
|
||||
override operator fun Float.minus(arg: INDArrayFloatStructure): INDArrayFloatStructure {
|
||||
check(arg)
|
||||
return arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDRing] over [INDArrayIntStructure].
|
||||
*/
|
||||
class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Int> = IntRing) :
|
||||
INDArrayRing<Int, Ring<Int>, INDArrayIntStructure> {
|
||||
override fun INDArray.wrap(): INDArrayIntStructure = asIntStructure()
|
||||
override fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure = ndArray.addi(arg).wrap()
|
||||
override fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure = ndArray.subi(arg).wrap()
|
||||
override fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure = ndArray.muli(arg).wrap()
|
||||
override fun Int.minus(arg: INDArrayIntStructure): INDArrayIntStructure = arg.ndArray.rsubi(this).wrap()
|
||||
override fun INDArray.wrap(): INDArrayIntStructure = check(asIntStructure())
|
||||
override operator fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure {
|
||||
check(this)
|
||||
return ndArray.add(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure {
|
||||
check(this)
|
||||
return ndArray.sub(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure {
|
||||
check(this)
|
||||
return ndArray.mul(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun Int.minus(arg: INDArrayIntStructure): INDArrayIntStructure {
|
||||
check(arg)
|
||||
return arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [NDRing] over [INDArrayLongStructure].
|
||||
*/
|
||||
class LongINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Long> = LongRing) :
|
||||
INDArrayRing<Long, Ring<Long>, INDArrayLongStructure> {
|
||||
override fun INDArray.wrap(): INDArrayLongStructure = check(asLongStructure())
|
||||
override operator fun INDArrayLongStructure.plus(arg: Long): INDArrayLongStructure {
|
||||
check(this)
|
||||
return ndArray.add(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun INDArrayLongStructure.minus(arg: Long): INDArrayLongStructure {
|
||||
check(this)
|
||||
return ndArray.sub(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun INDArrayLongStructure.times(arg: Long): INDArrayLongStructure {
|
||||
check(this)
|
||||
return ndArray.mul(arg).wrap()
|
||||
}
|
||||
|
||||
override operator fun Long.minus(arg: INDArrayLongStructure): INDArrayLongStructure {
|
||||
check(arg)
|
||||
return arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
}
|
||||
|
@ -3,7 +3,24 @@ package scientifik.kmath.nd4j
|
||||
import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import org.nd4j.linalg.api.shape.Shape
|
||||
|
||||
internal sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
|
||||
private class INDArrayIndicesIterator(private val iterateOver: INDArray) : Iterator<IntArray> {
|
||||
private var i: Int = 0
|
||||
|
||||
override fun hasNext(): Boolean = i < iterateOver.length()
|
||||
|
||||
override fun next(): IntArray {
|
||||
val la = if (iterateOver.ordering() == 'c')
|
||||
Shape.ind2subC(iterateOver, i++.toLong())!!
|
||||
else
|
||||
Shape.ind2sub(iterateOver, i++.toLong())!!
|
||||
|
||||
return la.toIntArray()
|
||||
}
|
||||
}
|
||||
|
||||
internal fun INDArray.indicesIterator(): Iterator<IntArray> = INDArrayIndicesIterator(this)
|
||||
|
||||
private sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
|
||||
private var i: Int = 0
|
||||
|
||||
final override fun hasNext(): Boolean = i < iterateOver.length()
|
||||
@ -20,26 +37,26 @@ internal sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArra
|
||||
}
|
||||
}
|
||||
|
||||
internal class INDArrayRealIterator(iterateOver: INDArray) : INDArrayIteratorBase<Double>(iterateOver) {
|
||||
private class INDArrayRealIterator(iterateOver: INDArray) : INDArrayIteratorBase<Double>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices)
|
||||
}
|
||||
|
||||
internal fun INDArray.realIterator(): INDArrayRealIterator = INDArrayRealIterator(this)
|
||||
internal fun INDArray.realIterator(): Iterator<Pair<IntArray, Double>> = INDArrayRealIterator(this)
|
||||
|
||||
internal class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase<Long>(iterateOver) {
|
||||
private class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase<Long>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices)
|
||||
}
|
||||
|
||||
internal fun INDArray.longIterator(): INDArrayLongIterator = INDArrayLongIterator(this)
|
||||
internal fun INDArray.longIterator(): Iterator<Pair<IntArray, Long>> = INDArrayLongIterator(this)
|
||||
|
||||
internal class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase<Int>(iterateOver) {
|
||||
private class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase<Int>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray) = iterateOver.getInt(*indices.toIntArray())
|
||||
}
|
||||
|
||||
internal fun INDArray.intIterator(): INDArrayIntIterator = INDArrayIntIterator(this)
|
||||
internal fun INDArray.intIterator(): Iterator<Pair<IntArray, Int>> = INDArrayIntIterator(this)
|
||||
|
||||
internal class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase<Float>(iterateOver) {
|
||||
private class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase<Float>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray) = iterateOver.getFloat(*indices)
|
||||
}
|
||||
|
||||
internal fun INDArray.floatIterator() = INDArrayFloatIterator(this)
|
||||
internal fun INDArray.floatIterator(): Iterator<Pair<IntArray, Float>> = INDArrayFloatIterator(this)
|
||||
|
@ -4,45 +4,77 @@ import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import scientifik.kmath.structures.MutableNDStructure
|
||||
import scientifik.kmath.structures.NDStructure
|
||||
|
||||
interface INDArrayStructure<T> : NDStructure<T> {
|
||||
val ndArray: INDArray
|
||||
/**
|
||||
* Represents a [NDStructure] wrapping an [INDArray] object.
|
||||
*
|
||||
* @param T the type of items.
|
||||
*/
|
||||
sealed class INDArrayStructure<T> : MutableNDStructure<T> {
|
||||
/**
|
||||
* The wrapped [INDArray].
|
||||
*/
|
||||
abstract val ndArray: INDArray
|
||||
|
||||
override val shape: IntArray
|
||||
get() = ndArray.shape().toIntArray()
|
||||
|
||||
fun elementsIterator(): Iterator<Pair<IntArray, T>>
|
||||
internal abstract fun elementsIterator(): Iterator<Pair<IntArray, T>>
|
||||
internal fun indicesIterator(): Iterator<IntArray> = ndArray.indicesIterator()
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
||||
}
|
||||
|
||||
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>, MutableNDStructure<Int> {
|
||||
/**
|
||||
* Represents a [NDStructure] over [INDArray] elements of which are accessed as ints.
|
||||
*/
|
||||
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>() {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = ndArray.intIterator()
|
||||
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
||||
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [INDArray] to [INDArrayIntStructure].
|
||||
*/
|
||||
fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this)
|
||||
|
||||
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
|
||||
/**
|
||||
* Represents a [NDStructure] over [INDArray] elements of which are accessed as longs.
|
||||
*/
|
||||
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long>() {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = ndArray.longIterator()
|
||||
override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray())
|
||||
override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [INDArray] to [INDArrayLongStructure].
|
||||
*/
|
||||
fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this)
|
||||
|
||||
data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure<Double>,
|
||||
MutableNDStructure<Double> {
|
||||
/**
|
||||
* Represents a [NDStructure] over [INDArray] elements of which are accessed as reals.
|
||||
*/
|
||||
data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure<Double>() {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = ndArray.realIterator()
|
||||
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
||||
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [INDArray] to [INDArrayRealStructure].
|
||||
*/
|
||||
fun INDArray.asRealStructure(): INDArrayRealStructure = INDArrayRealStructure(this)
|
||||
|
||||
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>,
|
||||
MutableNDStructure<Float> {
|
||||
/**
|
||||
* Represents a [NDStructure] over [INDArray] elements of which are accessed as floats.
|
||||
*/
|
||||
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>() {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = ndArray.floatIterator()
|
||||
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
||||
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [INDArray] to [INDArrayFloatStructure].
|
||||
*/
|
||||
fun INDArray.asFloatStructure(): INDArrayFloatStructure = INDArrayFloatStructure(this)
|
||||
|
Loading…
Reference in New Issue
Block a user