Change package name, simplify exposed API types, update build snippet, minor refactor

This commit is contained in:
Iaroslav 2020-09-21 20:53:31 +07:00
parent 202bc2e904
commit 2ee5d0f325
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
14 changed files with 402 additions and 387 deletions

View File

@ -26,9 +26,13 @@ dependencies {
implementation(project(":kmath-prob")) implementation(project(":kmath-prob"))
implementation(project(":kmath-viktor")) implementation(project(":kmath-viktor"))
implementation(project(":kmath-dimensions")) implementation(project(":kmath-dimensions"))
implementation(project(":kmath-nd4j"))
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6") implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20") implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
"benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath implementation("org.slf4j:slf4j-simple:1.7.30")
implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
"benchmarksImplementation"("org.jetbrains.kotlinx:kotlinx.benchmark.runtime-jvm:0.2.0-dev-8")
"benchmarksImplementation"(sourceSets.main.get().output + sourceSets.main.get().runtimeClasspath)
} }
// Configure benchmark // Configure benchmark

View File

@ -15,8 +15,9 @@ public class BoxingNDField<T, F : Field<T>>(
public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> = public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
bufferFactory(size, initializer) bufferFactory(size, initializer)
public override fun check(vararg elements: NDBuffer<T>) { public override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" } require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
return elements
} }
public override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> = public override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =

View File

@ -5,8 +5,10 @@ import kscience.kmath.operations.*
public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> { public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
public val strides: Strides public val strides: Strides
public override fun check(vararg elements: NDBuffer<T>): Unit = public override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
require(elements.all { it.strides == strides }) { ("Strides mismatch") } require(elements.all { it.strides == strides }) { ("Strides mismatch") }
return elements
}
/** /**
* Convert any [NDStructure] to buffered structure using strides from this context. * Convert any [NDStructure] to buffered structure using strides from this context.

View File

@ -11,7 +11,7 @@ import kscience.kmath.operations.Space
* @property expected the expected shape. * @property expected the expected shape.
* @property actual the actual shape. * @property actual the actual shape.
*/ */
public class ShapeMismatchException(val expected: IntArray, val actual: IntArray) : public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.") RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
/** /**

View File

@ -13,27 +13,34 @@ This subproject implements the following features:
> >
> ```gradle > ```gradle
> repositories { > repositories {
> mavenCentral()
> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' } > maven { url 'https://dl.bintray.com/mipt-npm/scientifik' }
> maven { url 'https://dl.bintray.com/mipt-npm/dev' } > maven { url 'https://dl.bintray.com/mipt-npm/dev' }
> } > }
> >
> dependencies { > dependencies {
> implementation 'scientifik:kmath-nd4j:0.1.4-dev-8' > implementation 'scientifik:kmath-nd4j:0.1.4-dev-8'
> implementation 'org.nd4j:nd4j-native-platform:1.0.0-beta7'
> } > }
> ``` > ```
> **Gradle Kotlin DSL:** > **Gradle Kotlin DSL:**
> >
> ```kotlin > ```kotlin
> repositories { > repositories {
> mavenCentral()
> maven("https://dl.bintray.com/mipt-npm/scientifik") > maven("https://dl.bintray.com/mipt-npm/scientifik")
> maven("https://dl.bintray.com/mipt-npm/dev") > maven("https://dl.bintray.com/mipt-npm/dev")
> } > }
> >
> dependencies { > dependencies {
> implementation("scientifik:kmath-nd4j:0.1.4-dev-8") > implementation("scientifik:kmath-nd4j:0.1.4-dev-8")
> implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
> } > }
> ``` > ```
> >
> This distribution also needs an implementation of ND4J API. The ND4J Native Platform is usually the fastest one, so
> it is included to the snippet.
>
## Examples ## Examples

View File

@ -1,5 +1,5 @@
plugins { plugins {
id("scientifik.jvm") id("ru.mipt.npm.jvm")
} }
dependencies { dependencies {

View File

@ -0,0 +1,284 @@
package kscience.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import kscience.kmath.operations.*
import kscience.kmath.structures.*
/**
* Represents [NDAlgebra] over [INDArrayAlgebra].
*
* @param T the type of ND-structure element.
* @param C the type of the element context.
*/
public interface INDArrayAlgebra<T, C> : NDAlgebra<T, C, INDArrayStructure<T>> {
/**
* Wraps [INDArray] to [N].
*/
public fun INDArray.wrap(): INDArrayStructure<T>
public override fun produce(initializer: C.(IntArray) -> T): INDArrayStructure<T> {
val struct = Nd4j.create(*shape)!!.wrap()
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) }
return struct
}
public override fun map(arg: INDArrayStructure<T>, transform: C.(T) -> T): INDArrayStructure<T> {
check(arg)
val newStruct = arg.ndArray.dup().wrap()
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
return newStruct
}
public override fun mapIndexed(
arg: INDArrayStructure<T>,
transform: C.(index: IntArray, T) -> T
): INDArrayStructure<T> {
check(arg)
val new = Nd4j.create(*shape).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) }
return new
}
public override fun combine(
a: INDArrayStructure<T>,
b: INDArrayStructure<T>,
transform: C.(T, T) -> T
): INDArrayStructure<T> {
check(a, b)
val new = Nd4j.create(*shape).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) }
return new
}
}
/**
* Represents [NDSpace] over [INDArrayStructure].
*
* @param T the type of the element contained in ND structure.
* @param S the type of space of structure elements.
*/
public interface INDArraySpace<T, S> : NDSpace<T, S, INDArrayStructure<T>>, INDArrayAlgebra<T, S> where S : Space<T> {
public override val zero: INDArrayStructure<T>
get() = Nd4j.zeros(*shape).wrap()
public override fun add(a: INDArrayStructure<T>, b: INDArrayStructure<T>): INDArrayStructure<T> {
check(a, b)
return a.ndArray.add(b.ndArray).wrap()
}
public override operator fun INDArrayStructure<T>.minus(b: INDArrayStructure<T>): INDArrayStructure<T> {
check(this, b)
return ndArray.sub(b.ndArray).wrap()
}
public override operator fun INDArrayStructure<T>.unaryMinus(): INDArrayStructure<T> {
check(this)
return ndArray.neg().wrap()
}
public override fun multiply(a: INDArrayStructure<T>, k: Number): INDArrayStructure<T> {
check(a)
return a.ndArray.mul(k).wrap()
}
public override operator fun INDArrayStructure<T>.div(k: Number): INDArrayStructure<T> {
check(this)
return ndArray.div(k).wrap()
}
public override operator fun INDArrayStructure<T>.times(k: Number): INDArrayStructure<T> {
check(this)
return ndArray.mul(k).wrap()
}
}
/**
* Represents [NDRing] over [INDArrayStructure].
*
* @param T the type of the element contained in ND structure.
* @param R the type of ring of structure elements.
*/
public interface INDArrayRing<T, R> : NDRing<T, R, INDArrayStructure<T>>, INDArraySpace<T, R> where R : Ring<T> {
public override val one: INDArrayStructure<T>
get() = Nd4j.ones(*shape).wrap()
public override fun multiply(a: INDArrayStructure<T>, b: INDArrayStructure<T>): INDArrayStructure<T> {
check(a, b)
return a.ndArray.mul(b.ndArray).wrap()
}
public override operator fun INDArrayStructure<T>.minus(b: Number): INDArrayStructure<T> {
check(this)
return ndArray.sub(b).wrap()
}
public override operator fun INDArrayStructure<T>.plus(b: Number): INDArrayStructure<T> {
check(this)
return ndArray.add(b).wrap()
}
public override operator fun Number.minus(b: INDArrayStructure<T>): INDArrayStructure<T> {
check(b)
return b.ndArray.rsub(this).wrap()
}
}
/**
* Represents [NDField] over [INDArrayStructure].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param F the type field of structure elements.
*/
public interface INDArrayField<T, F> : NDField<T, F, INDArrayStructure<T>>, INDArrayRing<T, F> where F : Field<T> {
public override fun divide(a: INDArrayStructure<T>, b: INDArrayStructure<T>): INDArrayStructure<T> {
check(a, b)
return a.ndArray.div(b.ndArray).wrap()
}
public override operator fun Number.div(b: INDArrayStructure<T>): INDArrayStructure<T> {
check(b)
return b.ndArray.rdiv(this).wrap()
}
}
/**
* Represents [NDField] over [INDArrayRealStructure].
*/
public class RealINDArrayField(public override val shape: IntArray) : INDArrayField<Double, RealField> {
public override val elementContext: RealField
get() = RealField
public override fun INDArray.wrap(): INDArrayStructure<Double> = check(asRealStructure())
public override operator fun INDArrayStructure<Double>.div(arg: Double): INDArrayStructure<Double> {
check(this)
return ndArray.div(arg).wrap()
}
public override operator fun INDArrayStructure<Double>.plus(arg: Double): INDArrayStructure<Double> {
check(this)
return ndArray.add(arg).wrap()
}
public override operator fun INDArrayStructure<Double>.minus(arg: Double): INDArrayStructure<Double> {
check(this)
return ndArray.sub(arg).wrap()
}
public override operator fun INDArrayStructure<Double>.times(arg: Double): INDArrayStructure<Double> {
check(this)
return ndArray.mul(arg).wrap()
}
public override operator fun Double.div(arg: INDArrayStructure<Double>): INDArrayStructure<Double> {
check(arg)
return arg.ndArray.rdiv(this).wrap()
}
public override operator fun Double.minus(arg: INDArrayStructure<Double>): INDArrayStructure<Double> {
check(arg)
return arg.ndArray.rsub(this).wrap()
}
}
/**
* Represents [NDField] over [INDArrayStructure] of [Float].
*/
public class FloatINDArrayField(public override val shape: IntArray) : INDArrayField<Float, FloatField> {
public override val elementContext: FloatField
get() = FloatField
public override fun INDArray.wrap(): INDArrayStructure<Float> = check(asFloatStructure())
public override operator fun INDArrayStructure<Float>.div(arg: Float): INDArrayStructure<Float> {
check(this)
return ndArray.div(arg).wrap()
}
public override operator fun INDArrayStructure<Float>.plus(arg: Float): INDArrayStructure<Float> {
check(this)
return ndArray.add(arg).wrap()
}
public override operator fun INDArrayStructure<Float>.minus(arg: Float): INDArrayStructure<Float> {
check(this)
return ndArray.sub(arg).wrap()
}
public override operator fun INDArrayStructure<Float>.times(arg: Float): INDArrayStructure<Float> {
check(this)
return ndArray.mul(arg).wrap()
}
public override operator fun Float.div(arg: INDArrayStructure<Float>): INDArrayStructure<Float> {
check(arg)
return arg.ndArray.rdiv(this).wrap()
}
public override operator fun Float.minus(arg: INDArrayStructure<Float>): INDArrayStructure<Float> {
check(arg)
return arg.ndArray.rsub(this).wrap()
}
}
/**
* Represents [NDRing] over [INDArrayIntStructure].
*/
public class IntINDArrayRing(public override val shape: IntArray) : INDArrayRing<Int, IntRing> {
public override val elementContext: IntRing
get() = IntRing
public override fun INDArray.wrap(): INDArrayStructure<Int> = check(asIntStructure())
public override operator fun INDArrayStructure<Int>.plus(arg: Int): INDArrayStructure<Int> {
check(this)
return ndArray.add(arg).wrap()
}
public override operator fun INDArrayStructure<Int>.minus(arg: Int): INDArrayStructure<Int> {
check(this)
return ndArray.sub(arg).wrap()
}
public override operator fun INDArrayStructure<Int>.times(arg: Int): INDArrayStructure<Int> {
check(this)
return ndArray.mul(arg).wrap()
}
public override operator fun Int.minus(arg: INDArrayStructure<Int>): INDArrayStructure<Int> {
check(arg)
return arg.ndArray.rsub(this).wrap()
}
}
/**
* Represents [NDRing] over [INDArrayStructure] of [Long].
*/
public class LongINDArrayRing(public override val shape: IntArray) : INDArrayRing<Long, LongRing> {
public override val elementContext: LongRing
get() = LongRing
public override fun INDArray.wrap(): INDArrayStructure<Long> = check(asLongStructure())
public override operator fun INDArrayStructure<Long>.plus(arg: Long): INDArrayStructure<Long> {
check(this)
return ndArray.add(arg).wrap()
}
public override operator fun INDArrayStructure<Long>.minus(arg: Long): INDArrayStructure<Long> {
check(this)
return ndArray.sub(arg).wrap()
}
public override operator fun INDArrayStructure<Long>.times(arg: Long): INDArrayStructure<Long> {
check(this)
return ndArray.mul(arg).wrap()
}
public override operator fun Long.minus(arg: INDArrayStructure<Long>): INDArrayStructure<Long> {
check(arg)
return arg.ndArray.rsub(this).wrap()
}
}

View File

@ -1,4 +1,4 @@
package scientifik.kmath.nd4j package kscience.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.shape.Shape import org.nd4j.linalg.api.shape.Shape

View File

@ -0,0 +1,68 @@
package kscience.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import kscience.kmath.structures.MutableNDStructure
import kscience.kmath.structures.NDStructure
/**
* Represents a [NDStructure] wrapping an [INDArray] object.
*
* @param T the type of items.
*/
public sealed class INDArrayStructure<T> : MutableNDStructure<T> {
/**
* The wrapped [INDArray].
*/
public abstract val ndArray: INDArray
public override val shape: IntArray
get() = ndArray.shape().toIntArray()
internal abstract fun elementsIterator(): Iterator<Pair<IntArray, T>>
internal fun indicesIterator(): Iterator<IntArray> = ndArray.indicesIterator()
public override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
}
private data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = ndArray.intIterator()
override fun get(index: IntArray): Int = ndArray.getInt(*index)
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
}
/**
* Wraps this [INDArray] to [INDArrayStructure].
*/
public fun INDArray.asIntStructure(): INDArrayStructure<Int> = INDArrayIntStructure(this)
private data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = ndArray.longIterator()
override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray())
override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) }
}
/**
* Wraps this [INDArray] to [INDArrayStructure].
*/
public fun INDArray.asLongStructure(): INDArrayStructure<Long> = INDArrayLongStructure(this)
private data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure<Double>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = ndArray.realIterator()
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
}
/**
* Wraps this [INDArray] to [INDArrayStructure].
*/
public fun INDArray.asRealStructure(): INDArrayStructure<Double> = INDArrayRealStructure(this)
private data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = ndArray.floatIterator()
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
}
/**
* Wraps this [INDArray] to [INDArrayStructure].
*/
public fun INDArray.asFloatStructure(): INDArrayStructure<Float> = INDArrayFloatStructure(this)

View File

@ -1,4 +1,4 @@
package scientifik.kmath.nd4j package kscience.kmath.nd4j
internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() } internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() }
internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() } internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() }

View File

@ -1,273 +0,0 @@
package scientifik.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import scientifik.kmath.operations.*
import scientifik.kmath.structures.*
/**
* Represents [NDAlgebra] over [INDArrayAlgebra].
*
* @param T the type of ND-structure element.
* @param C the type of the element context.
* @param N the type of the structure.
*/
interface INDArrayAlgebra<T, C, N> : NDAlgebra<T, C, N> where N : INDArrayStructure<T>, N : MutableNDStructure<T> {
/**
* Wraps [INDArray] to [N].
*/
fun INDArray.wrap(): N
override fun produce(initializer: C.(IntArray) -> T): N {
val struct = Nd4j.create(*shape)!!.wrap()
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) }
return struct
}
override fun map(arg: N, transform: C.(T) -> T): N {
check(arg)
val newStruct = arg.ndArray.dup().wrap()
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
return newStruct
}
override fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N {
check(arg)
val new = Nd4j.create(*shape).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) }
return new
}
override fun combine(a: N, b: N, transform: C.(T, T) -> T): N {
check(a, b)
val new = Nd4j.create(*shape).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) }
return new
}
}
/**
* Represents [NDSpace] over [INDArrayStructure].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param S the type of space of structure elements.
*/
interface INDArraySpace<T, S, N> : NDSpace<T, S, N>, INDArrayAlgebra<T, S, N>
where S : Space<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
override val zero: N
get() = Nd4j.zeros(*shape).wrap()
override fun add(a: N, b: N): N {
check(a, b)
return a.ndArray.add(b.ndArray).wrap()
}
override operator fun N.minus(b: N): N {
check(this, b)
return ndArray.sub(b.ndArray).wrap()
}
override operator fun N.unaryMinus(): N {
check(this)
return ndArray.neg().wrap()
}
override fun multiply(a: N, k: Number): N {
check(a)
return a.ndArray.mul(k).wrap()
}
override operator fun N.div(k: Number): N {
check(this)
return ndArray.div(k).wrap()
}
override operator fun N.times(k: Number): N {
check(this)
return ndArray.mul(k).wrap()
}
}
/**
* Represents [NDRing] over [INDArrayStructure].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param R the type of ring of structure elements.
*/
interface INDArrayRing<T, R, N> : NDRing<T, R, N>, INDArraySpace<T, R, N>
where R : Ring<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
override val one: N
get() = Nd4j.ones(*shape).wrap()
override fun multiply(a: N, b: N): N {
check(a, b)
return a.ndArray.mul(b.ndArray).wrap()
}
override operator fun N.minus(b: Number): N {
check(this)
return ndArray.sub(b).wrap()
}
override operator fun N.plus(b: Number): N {
check(this)
return ndArray.add(b).wrap()
}
override operator fun Number.minus(b: N): N {
check(b)
return b.ndArray.rsub(this).wrap()
}
}
/**
* Represents [NDField] over [INDArrayStructure].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param F the type field of structure elements.
*/
interface INDArrayField<T, F, N> : NDField<T, F, N>, INDArrayRing<T, F, N>
where F : Field<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
override fun divide(a: N, b: N): N {
check(a, b)
return a.ndArray.div(b.ndArray).wrap()
}
override operator fun Number.div(b: N): N {
check(b)
return b.ndArray.rdiv(this).wrap()
}
}
/**
* Represents [NDField] over [INDArrayRealStructure].
*/
class RealINDArrayField(override val shape: IntArray, override val elementContext: Field<Double> = RealField) :
INDArrayField<Double, Field<Double>, INDArrayRealStructure> {
override fun INDArray.wrap(): INDArrayRealStructure = check(asRealStructure())
override operator fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure {
check(this)
return ndArray.div(arg).wrap()
}
override operator fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure {
check(this)
return ndArray.add(arg).wrap()
}
override operator fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure {
check(this)
return ndArray.sub(arg).wrap()
}
override operator fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure {
check(this)
return ndArray.mul(arg).wrap()
}
override operator fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure {
check(arg)
return arg.ndArray.rdiv(this).wrap()
}
override operator fun Double.minus(arg: INDArrayRealStructure): INDArrayRealStructure {
check(arg)
return arg.ndArray.rsub(this).wrap()
}
}
/**
* Represents [NDField] over [INDArrayFloatStructure].
*/
class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field<Float> = FloatField) :
INDArrayField<Float, Field<Float>, INDArrayFloatStructure> {
override fun INDArray.wrap(): INDArrayFloatStructure = check(asFloatStructure())
override operator fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure {
check(this)
return ndArray.div(arg).wrap()
}
override operator fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure {
check(this)
return ndArray.add(arg).wrap()
}
override operator fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure {
check(this)
return ndArray.sub(arg).wrap()
}
override operator fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure {
check(this)
return ndArray.mul(arg).wrap()
}
override operator fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure {
check(arg)
return arg.ndArray.rdiv(this).wrap()
}
override operator fun Float.minus(arg: INDArrayFloatStructure): INDArrayFloatStructure {
check(arg)
return arg.ndArray.rsub(this).wrap()
}
}
/**
* Represents [NDRing] over [INDArrayIntStructure].
*/
class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Int> = IntRing) :
INDArrayRing<Int, Ring<Int>, INDArrayIntStructure> {
override fun INDArray.wrap(): INDArrayIntStructure = check(asIntStructure())
override operator fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure {
check(this)
return ndArray.add(arg).wrap()
}
override operator fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure {
check(this)
return ndArray.sub(arg).wrap()
}
override operator fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure {
check(this)
return ndArray.mul(arg).wrap()
}
override operator fun Int.minus(arg: INDArrayIntStructure): INDArrayIntStructure {
check(arg)
return arg.ndArray.rsub(this).wrap()
}
}
/**
* Represents [NDRing] over [INDArrayLongStructure].
*/
class LongINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Long> = LongRing) :
INDArrayRing<Long, Ring<Long>, INDArrayLongStructure> {
override fun INDArray.wrap(): INDArrayLongStructure = check(asLongStructure())
override operator fun INDArrayLongStructure.plus(arg: Long): INDArrayLongStructure {
check(this)
return ndArray.add(arg).wrap()
}
override operator fun INDArrayLongStructure.minus(arg: Long): INDArrayLongStructure {
check(this)
return ndArray.sub(arg).wrap()
}
override operator fun INDArrayLongStructure.times(arg: Long): INDArrayLongStructure {
check(this)
return ndArray.mul(arg).wrap()
}
override operator fun Long.minus(arg: INDArrayLongStructure): INDArrayLongStructure {
check(arg)
return arg.ndArray.rsub(this).wrap()
}
}

View File

@ -1,80 +0,0 @@
package scientifik.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import scientifik.kmath.structures.MutableNDStructure
import scientifik.kmath.structures.NDStructure
/**
* Represents a [NDStructure] wrapping an [INDArray] object.
*
* @param T the type of items.
*/
sealed class INDArrayStructure<T> : MutableNDStructure<T> {
/**
* The wrapped [INDArray].
*/
abstract val ndArray: INDArray
override val shape: IntArray
get() = ndArray.shape().toIntArray()
internal abstract fun elementsIterator(): Iterator<Pair<IntArray, T>>
internal fun indicesIterator(): Iterator<IntArray> = ndArray.indicesIterator()
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
}
/**
* Represents a [NDStructure] over [INDArray] elements of which are accessed as ints.
*/
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = ndArray.intIterator()
override fun get(index: IntArray): Int = ndArray.getInt(*index)
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
}
/**
* Wraps this [INDArray] to [INDArrayIntStructure].
*/
fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this)
/**
* Represents a [NDStructure] over [INDArray] elements of which are accessed as longs.
*/
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = ndArray.longIterator()
override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray())
override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) }
}
/**
* Wraps this [INDArray] to [INDArrayLongStructure].
*/
fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this)
/**
* Represents a [NDStructure] over [INDArray] elements of which are accessed as reals.
*/
data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure<Double>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = ndArray.realIterator()
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
}
/**
* Wraps this [INDArray] to [INDArrayRealStructure].
*/
fun INDArray.asRealStructure(): INDArrayRealStructure = INDArrayRealStructure(this)
/**
* Represents a [NDStructure] over [INDArray] elements of which are accessed as floats.
*/
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = ndArray.floatIterator()
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
}
/**
* Wraps this [INDArray] to [INDArrayFloatStructure].
*/
fun INDArray.asFloatStructure(): INDArrayFloatStructure = INDArrayFloatStructure(this)

View File

@ -1,15 +1,16 @@
package scientifik.kmath.nd4j package kscience.kmath.nd4j
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
import scientifik.kmath.operations.invoke import kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.fail
internal class INDArrayAlgebraTest { internal class INDArrayAlgebraTest {
@Test @Test
fun testProduce() { fun testProduce() {
val res = (RealINDArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } } val res = (RealINDArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } }
val expected = Nd4j.create(2, 2)!!.asRealStructure() val expected = (Nd4j.create(2, 2) ?: fail()).asRealStructure()
expected[intArrayOf(0, 0)] = 0.0 expected[intArrayOf(0, 0)] = 0.0
expected[intArrayOf(0, 1)] = 1.0 expected[intArrayOf(0, 1)] = 1.0
expected[intArrayOf(1, 0)] = 1.0 expected[intArrayOf(1, 0)] = 1.0
@ -20,7 +21,7 @@ internal class INDArrayAlgebraTest {
@Test @Test
fun testMap() { fun testMap() {
val res = (IntINDArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } } val res = (IntINDArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } }
val expected = Nd4j.create(2, 2)!!.asIntStructure() val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
expected[intArrayOf(0, 0)] = 3 expected[intArrayOf(0, 0)] = 3
expected[intArrayOf(0, 1)] = 3 expected[intArrayOf(0, 1)] = 3
expected[intArrayOf(1, 0)] = 3 expected[intArrayOf(1, 0)] = 3
@ -31,7 +32,7 @@ internal class INDArrayAlgebraTest {
@Test @Test
fun testAdd() { fun testAdd() {
val res = (IntINDArrayRing(intArrayOf(2, 2))) { one + 25 } val res = (IntINDArrayRing(intArrayOf(2, 2))) { one + 25 }
val expected = Nd4j.create(2, 2)!!.asIntStructure() val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
expected[intArrayOf(0, 0)] = 26 expected[intArrayOf(0, 0)] = 26
expected[intArrayOf(0, 1)] = 26 expected[intArrayOf(0, 1)] = 26
expected[intArrayOf(1, 0)] = 26 expected[intArrayOf(1, 0)] = 26

View File

@ -1,70 +1,71 @@
package scientifik.kmath.nd4j package kscience.kmath.nd4j
import kscience.kmath.structures.get
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
import scientifik.kmath.structures.get
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertNotEquals import kotlin.test.assertNotEquals
import kotlin.test.fail
internal class INDArrayStructureTest { internal class INDArrayStructureTest {
@Test @Test
fun testElements() { fun testElements() {
val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
val struct = INDArrayRealStructure(nd) val struct = nd.asRealStructure()
val res = struct.elements().map(Pair<IntArray, Double>::second).toList() val res = struct.elements().map(Pair<IntArray, Double>::second).toList()
assertEquals(listOf(1.0, 2.0, 3.0), res) assertEquals(listOf(1.0, 2.0, 3.0), res)
} }
@Test @Test
fun testShape() { fun testShape() {
val nd = Nd4j.rand(10, 2, 3, 6)!! val nd = Nd4j.rand(10, 2, 3, 6) ?: fail()
val struct = INDArrayLongStructure(nd) val struct = nd.asRealStructure()
assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList()) assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList())
} }
@Test @Test
fun testEquals() { fun testEquals() {
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail()
val struct1 = INDArrayRealStructure(nd1) val struct1 = nd1.asRealStructure()
assertEquals(struct1, struct1) assertEquals(struct1, struct1)
assertNotEquals(struct1, null as INDArrayRealStructure?) assertNotEquals(struct1 as Any?, null)
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail()
val struct2 = INDArrayRealStructure(nd2) val struct2 = nd2.asRealStructure()
assertEquals(struct1, struct2) assertEquals(struct1, struct2)
assertEquals(struct2, struct1) assertEquals(struct2, struct1)
val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail()
val struct3 = INDArrayRealStructure(nd3) val struct3 = nd3.asRealStructure()
assertEquals(struct2, struct3) assertEquals(struct2, struct3)
assertEquals(struct1, struct3) assertEquals(struct1, struct3)
} }
@Test @Test
fun testHashCode() { fun testHashCode() {
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))?:fail()
val struct1 = INDArrayRealStructure(nd1) val struct1 = nd1.asRealStructure()
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))?:fail()
val struct2 = INDArrayRealStructure(nd2) val struct2 = nd2.asRealStructure()
assertEquals(struct1.hashCode(), struct2.hashCode()) assertEquals(struct1.hashCode(), struct2.hashCode())
} }
@Test @Test
fun testDimension() { fun testDimension() {
val nd = Nd4j.rand(8, 16, 3, 7, 1)!! val nd = Nd4j.rand(8, 16, 3, 7, 1)!!
val struct = INDArrayFloatStructure(nd) val struct = nd.asFloatStructure()
assertEquals(5, struct.dimension) assertEquals(5, struct.dimension)
} }
@Test @Test
fun testGet() { fun testGet() {
val nd = Nd4j.rand(10, 2, 3, 6)!! val nd = Nd4j.rand(10, 2, 3, 6)?:fail()
val struct = INDArrayIntStructure(nd) val struct = nd.asIntStructure()
assertEquals(nd.getInt(0, 0, 0, 0), struct[0, 0, 0, 0]) assertEquals(nd.getInt(0, 0, 0, 0), struct[0, 0, 0, 0])
} }
@Test @Test
fun testSet() { fun testSet() {
val nd = Nd4j.rand(17, 12, 4, 8)!! val nd = Nd4j.rand(17, 12, 4, 8)!!
val struct = INDArrayIntStructure(nd) val struct = nd.asLongStructure()
struct[intArrayOf(1, 2, 3, 4)] = 777 struct[intArrayOf(1, 2, 3, 4)] = 777
assertEquals(777, struct[1, 2, 3, 4]) assertEquals(777, struct[1, 2, 3, 4])
} }