forked from kscience/kmath
Change package name, simplify exposed API types, update build snippet, minor refactor
This commit is contained in:
parent
202bc2e904
commit
2ee5d0f325
@ -26,9 +26,13 @@ dependencies {
|
|||||||
implementation(project(":kmath-prob"))
|
implementation(project(":kmath-prob"))
|
||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
implementation(project(":kmath-dimensions"))
|
implementation(project(":kmath-dimensions"))
|
||||||
|
implementation(project(":kmath-nd4j"))
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
|
implementation("org.jetbrains.kotlinx:kotlinx-io-jvm:0.2.0-npm-dev-6")
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
|
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
|
||||||
"benchmarksCompile"(sourceSets.main.get().output + sourceSets.main.get().compileClasspath) //sourceSets.main.output + sourceSets.main.runtimeClasspath
|
implementation("org.slf4j:slf4j-simple:1.7.30")
|
||||||
|
implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
|
||||||
|
"benchmarksImplementation"("org.jetbrains.kotlinx:kotlinx.benchmark.runtime-jvm:0.2.0-dev-8")
|
||||||
|
"benchmarksImplementation"(sourceSets.main.get().output + sourceSets.main.get().runtimeClasspath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure benchmark
|
// Configure benchmark
|
||||||
|
@ -15,8 +15,9 @@ public class BoxingNDField<T, F : Field<T>>(
|
|||||||
public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
public fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||||
bufferFactory(size, initializer)
|
bufferFactory(size, initializer)
|
||||||
|
|
||||||
public override fun check(vararg elements: NDBuffer<T>) {
|
public override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
|
||||||
check(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
|
require(elements.all { it.strides == strides }) { "Element strides are not the same as context strides" }
|
||||||
|
return elements
|
||||||
}
|
}
|
||||||
|
|
||||||
public override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
|
public override fun produce(initializer: F.(IntArray) -> T): BufferedNDFieldElement<T, F> =
|
||||||
|
@ -5,8 +5,10 @@ import kscience.kmath.operations.*
|
|||||||
public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
public interface BufferedNDAlgebra<T, C> : NDAlgebra<T, C, NDBuffer<T>> {
|
||||||
public val strides: Strides
|
public val strides: Strides
|
||||||
|
|
||||||
public override fun check(vararg elements: NDBuffer<T>): Unit =
|
public override fun check(vararg elements: NDBuffer<T>): Array<out NDBuffer<T>> {
|
||||||
require(elements.all { it.strides == strides }) { ("Strides mismatch") }
|
require(elements.all { it.strides == strides }) { ("Strides mismatch") }
|
||||||
|
return elements
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert any [NDStructure] to buffered structure using strides from this context.
|
* Convert any [NDStructure] to buffered structure using strides from this context.
|
||||||
|
@ -11,7 +11,7 @@ import kscience.kmath.operations.Space
|
|||||||
* @property expected the expected shape.
|
* @property expected the expected shape.
|
||||||
* @property actual the actual shape.
|
* @property actual the actual shape.
|
||||||
*/
|
*/
|
||||||
public class ShapeMismatchException(val expected: IntArray, val actual: IntArray) :
|
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
|
||||||
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
|
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -13,26 +13,33 @@ This subproject implements the following features:
|
|||||||
>
|
>
|
||||||
> ```gradle
|
> ```gradle
|
||||||
> repositories {
|
> repositories {
|
||||||
|
> mavenCentral()
|
||||||
> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' }
|
> maven { url 'https://dl.bintray.com/mipt-npm/scientifik' }
|
||||||
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
> maven { url 'https://dl.bintray.com/mipt-npm/dev' }
|
||||||
> }
|
> }
|
||||||
>
|
>
|
||||||
> dependencies {
|
> dependencies {
|
||||||
> implementation 'scientifik:kmath-nd4j:0.1.4-dev-8'
|
> implementation 'scientifik:kmath-nd4j:0.1.4-dev-8'
|
||||||
|
> implementation 'org.nd4j:nd4j-native-platform:1.0.0-beta7'
|
||||||
> }
|
> }
|
||||||
> ```
|
> ```
|
||||||
> **Gradle Kotlin DSL:**
|
> **Gradle Kotlin DSL:**
|
||||||
>
|
>
|
||||||
> ```kotlin
|
> ```kotlin
|
||||||
> repositories {
|
> repositories {
|
||||||
|
> mavenCentral()
|
||||||
> maven("https://dl.bintray.com/mipt-npm/scientifik")
|
> maven("https://dl.bintray.com/mipt-npm/scientifik")
|
||||||
> maven("https://dl.bintray.com/mipt-npm/dev")
|
> maven("https://dl.bintray.com/mipt-npm/dev")
|
||||||
> }
|
> }
|
||||||
>
|
>
|
||||||
> dependencies {
|
> dependencies {
|
||||||
> implementation("scientifik:kmath-nd4j:0.1.4-dev-8")
|
> implementation("scientifik:kmath-nd4j:0.1.4-dev-8")
|
||||||
|
> implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
|
||||||
> }
|
> }
|
||||||
> ```
|
> ```
|
||||||
|
>
|
||||||
|
> This distribution also needs an implementation of ND4J API. The ND4J Native Platform is usually the fastest one, so
|
||||||
|
> it is included to the snippet.
|
||||||
>
|
>
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
plugins {
|
plugins {
|
||||||
id("scientifik.jvm")
|
id("ru.mipt.npm.jvm")
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
|
@ -0,0 +1,284 @@
|
|||||||
|
package kscience.kmath.nd4j
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
|
import kscience.kmath.operations.*
|
||||||
|
import kscience.kmath.structures.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDAlgebra] over [INDArrayAlgebra].
|
||||||
|
*
|
||||||
|
* @param T the type of ND-structure element.
|
||||||
|
* @param C the type of the element context.
|
||||||
|
*/
|
||||||
|
public interface INDArrayAlgebra<T, C> : NDAlgebra<T, C, INDArrayStructure<T>> {
|
||||||
|
/**
|
||||||
|
* Wraps [INDArray] to [N].
|
||||||
|
*/
|
||||||
|
public fun INDArray.wrap(): INDArrayStructure<T>
|
||||||
|
|
||||||
|
public override fun produce(initializer: C.(IntArray) -> T): INDArrayStructure<T> {
|
||||||
|
val struct = Nd4j.create(*shape)!!.wrap()
|
||||||
|
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) }
|
||||||
|
return struct
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun map(arg: INDArrayStructure<T>, transform: C.(T) -> T): INDArrayStructure<T> {
|
||||||
|
check(arg)
|
||||||
|
val newStruct = arg.ndArray.dup().wrap()
|
||||||
|
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
|
||||||
|
return newStruct
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun mapIndexed(
|
||||||
|
arg: INDArrayStructure<T>,
|
||||||
|
transform: C.(index: IntArray, T) -> T
|
||||||
|
): INDArrayStructure<T> {
|
||||||
|
check(arg)
|
||||||
|
val new = Nd4j.create(*shape).wrap()
|
||||||
|
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) }
|
||||||
|
return new
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun combine(
|
||||||
|
a: INDArrayStructure<T>,
|
||||||
|
b: INDArrayStructure<T>,
|
||||||
|
transform: C.(T, T) -> T
|
||||||
|
): INDArrayStructure<T> {
|
||||||
|
check(a, b)
|
||||||
|
val new = Nd4j.create(*shape).wrap()
|
||||||
|
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) }
|
||||||
|
return new
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDSpace] over [INDArrayStructure].
|
||||||
|
*
|
||||||
|
* @param T the type of the element contained in ND structure.
|
||||||
|
* @param S the type of space of structure elements.
|
||||||
|
*/
|
||||||
|
public interface INDArraySpace<T, S> : NDSpace<T, S, INDArrayStructure<T>>, INDArrayAlgebra<T, S> where S : Space<T> {
|
||||||
|
public override val zero: INDArrayStructure<T>
|
||||||
|
get() = Nd4j.zeros(*shape).wrap()
|
||||||
|
|
||||||
|
public override fun add(a: INDArrayStructure<T>, b: INDArrayStructure<T>): INDArrayStructure<T> {
|
||||||
|
check(a, b)
|
||||||
|
return a.ndArray.add(b.ndArray).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<T>.minus(b: INDArrayStructure<T>): INDArrayStructure<T> {
|
||||||
|
check(this, b)
|
||||||
|
return ndArray.sub(b.ndArray).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<T>.unaryMinus(): INDArrayStructure<T> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.neg().wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override fun multiply(a: INDArrayStructure<T>, k: Number): INDArrayStructure<T> {
|
||||||
|
check(a)
|
||||||
|
return a.ndArray.mul(k).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<T>.div(k: Number): INDArrayStructure<T> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.div(k).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<T>.times(k: Number): INDArrayStructure<T> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.mul(k).wrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDRing] over [INDArrayStructure].
|
||||||
|
*
|
||||||
|
* @param T the type of the element contained in ND structure.
|
||||||
|
* @param R the type of ring of structure elements.
|
||||||
|
*/
|
||||||
|
public interface INDArrayRing<T, R> : NDRing<T, R, INDArrayStructure<T>>, INDArraySpace<T, R> where R : Ring<T> {
|
||||||
|
public override val one: INDArrayStructure<T>
|
||||||
|
get() = Nd4j.ones(*shape).wrap()
|
||||||
|
|
||||||
|
public override fun multiply(a: INDArrayStructure<T>, b: INDArrayStructure<T>): INDArrayStructure<T> {
|
||||||
|
check(a, b)
|
||||||
|
return a.ndArray.mul(b.ndArray).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<T>.minus(b: Number): INDArrayStructure<T> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.sub(b).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<T>.plus(b: Number): INDArrayStructure<T> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.add(b).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Number.minus(b: INDArrayStructure<T>): INDArrayStructure<T> {
|
||||||
|
check(b)
|
||||||
|
return b.ndArray.rsub(this).wrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDField] over [INDArrayStructure].
|
||||||
|
*
|
||||||
|
* @param T the type of the element contained in ND structure.
|
||||||
|
* @param N the type of ND structure.
|
||||||
|
* @param F the type field of structure elements.
|
||||||
|
*/
|
||||||
|
public interface INDArrayField<T, F> : NDField<T, F, INDArrayStructure<T>>, INDArrayRing<T, F> where F : Field<T> {
|
||||||
|
public override fun divide(a: INDArrayStructure<T>, b: INDArrayStructure<T>): INDArrayStructure<T> {
|
||||||
|
check(a, b)
|
||||||
|
return a.ndArray.div(b.ndArray).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Number.div(b: INDArrayStructure<T>): INDArrayStructure<T> {
|
||||||
|
check(b)
|
||||||
|
return b.ndArray.rdiv(this).wrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDField] over [INDArrayRealStructure].
|
||||||
|
*/
|
||||||
|
public class RealINDArrayField(public override val shape: IntArray) : INDArrayField<Double, RealField> {
|
||||||
|
public override val elementContext: RealField
|
||||||
|
get() = RealField
|
||||||
|
|
||||||
|
public override fun INDArray.wrap(): INDArrayStructure<Double> = check(asRealStructure())
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Double>.div(arg: Double): INDArrayStructure<Double> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.div(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Double>.plus(arg: Double): INDArrayStructure<Double> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.add(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Double>.minus(arg: Double): INDArrayStructure<Double> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.sub(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Double>.times(arg: Double): INDArrayStructure<Double> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.mul(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Double.div(arg: INDArrayStructure<Double>): INDArrayStructure<Double> {
|
||||||
|
check(arg)
|
||||||
|
return arg.ndArray.rdiv(this).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Double.minus(arg: INDArrayStructure<Double>): INDArrayStructure<Double> {
|
||||||
|
check(arg)
|
||||||
|
return arg.ndArray.rsub(this).wrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDField] over [INDArrayStructure] of [Float].
|
||||||
|
*/
|
||||||
|
public class FloatINDArrayField(public override val shape: IntArray) : INDArrayField<Float, FloatField> {
|
||||||
|
public override val elementContext: FloatField
|
||||||
|
get() = FloatField
|
||||||
|
|
||||||
|
public override fun INDArray.wrap(): INDArrayStructure<Float> = check(asFloatStructure())
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Float>.div(arg: Float): INDArrayStructure<Float> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.div(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Float>.plus(arg: Float): INDArrayStructure<Float> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.add(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Float>.minus(arg: Float): INDArrayStructure<Float> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.sub(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Float>.times(arg: Float): INDArrayStructure<Float> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.mul(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Float.div(arg: INDArrayStructure<Float>): INDArrayStructure<Float> {
|
||||||
|
check(arg)
|
||||||
|
return arg.ndArray.rdiv(this).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Float.minus(arg: INDArrayStructure<Float>): INDArrayStructure<Float> {
|
||||||
|
check(arg)
|
||||||
|
return arg.ndArray.rsub(this).wrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDRing] over [INDArrayIntStructure].
|
||||||
|
*/
|
||||||
|
public class IntINDArrayRing(public override val shape: IntArray) : INDArrayRing<Int, IntRing> {
|
||||||
|
public override val elementContext: IntRing
|
||||||
|
get() = IntRing
|
||||||
|
|
||||||
|
public override fun INDArray.wrap(): INDArrayStructure<Int> = check(asIntStructure())
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Int>.plus(arg: Int): INDArrayStructure<Int> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.add(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Int>.minus(arg: Int): INDArrayStructure<Int> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.sub(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Int>.times(arg: Int): INDArrayStructure<Int> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.mul(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Int.minus(arg: INDArrayStructure<Int>): INDArrayStructure<Int> {
|
||||||
|
check(arg)
|
||||||
|
return arg.ndArray.rsub(this).wrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents [NDRing] over [INDArrayStructure] of [Long].
|
||||||
|
*/
|
||||||
|
public class LongINDArrayRing(public override val shape: IntArray) : INDArrayRing<Long, LongRing> {
|
||||||
|
public override val elementContext: LongRing
|
||||||
|
get() = LongRing
|
||||||
|
|
||||||
|
public override fun INDArray.wrap(): INDArrayStructure<Long> = check(asLongStructure())
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Long>.plus(arg: Long): INDArrayStructure<Long> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.add(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Long>.minus(arg: Long): INDArrayStructure<Long> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.sub(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun INDArrayStructure<Long>.times(arg: Long): INDArrayStructure<Long> {
|
||||||
|
check(this)
|
||||||
|
return ndArray.mul(arg).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public override operator fun Long.minus(arg: INDArrayStructure<Long>): INDArrayStructure<Long> {
|
||||||
|
check(arg)
|
||||||
|
return arg.ndArray.rsub(this).wrap()
|
||||||
|
}
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package scientifik.kmath.nd4j
|
package kscience.kmath.nd4j
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import org.nd4j.linalg.api.shape.Shape
|
import org.nd4j.linalg.api.shape.Shape
|
||||||
@ -18,7 +18,7 @@ private class INDArrayIndicesIterator(private val iterateOver: INDArray) : Itera
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun INDArray.indicesIterator(): Iterator<IntArray> = INDArrayIndicesIterator(this)
|
internal fun INDArray.indicesIterator(): Iterator<IntArray> = INDArrayIndicesIterator(this)
|
||||||
|
|
||||||
private sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
|
private sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArray) : Iterator<Pair<IntArray, T>> {
|
||||||
private var i: Int = 0
|
private var i: Int = 0
|
@ -0,0 +1,68 @@
|
|||||||
|
package kscience.kmath.nd4j
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
|
import kscience.kmath.structures.MutableNDStructure
|
||||||
|
import kscience.kmath.structures.NDStructure
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a [NDStructure] wrapping an [INDArray] object.
|
||||||
|
*
|
||||||
|
* @param T the type of items.
|
||||||
|
*/
|
||||||
|
public sealed class INDArrayStructure<T> : MutableNDStructure<T> {
|
||||||
|
/**
|
||||||
|
* The wrapped [INDArray].
|
||||||
|
*/
|
||||||
|
public abstract val ndArray: INDArray
|
||||||
|
|
||||||
|
public override val shape: IntArray
|
||||||
|
get() = ndArray.shape().toIntArray()
|
||||||
|
|
||||||
|
internal abstract fun elementsIterator(): Iterator<Pair<IntArray, T>>
|
||||||
|
internal fun indicesIterator(): Iterator<IntArray> = ndArray.indicesIterator()
|
||||||
|
public override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
||||||
|
}
|
||||||
|
|
||||||
|
private data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>() {
|
||||||
|
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = ndArray.intIterator()
|
||||||
|
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
||||||
|
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps this [INDArray] to [INDArrayStructure].
|
||||||
|
*/
|
||||||
|
public fun INDArray.asIntStructure(): INDArrayStructure<Int> = INDArrayIntStructure(this)
|
||||||
|
|
||||||
|
private data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long>() {
|
||||||
|
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = ndArray.longIterator()
|
||||||
|
override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray())
|
||||||
|
override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps this [INDArray] to [INDArrayStructure].
|
||||||
|
*/
|
||||||
|
public fun INDArray.asLongStructure(): INDArrayStructure<Long> = INDArrayLongStructure(this)
|
||||||
|
|
||||||
|
private data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure<Double>() {
|
||||||
|
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = ndArray.realIterator()
|
||||||
|
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
||||||
|
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps this [INDArray] to [INDArrayStructure].
|
||||||
|
*/
|
||||||
|
public fun INDArray.asRealStructure(): INDArrayStructure<Double> = INDArrayRealStructure(this)
|
||||||
|
|
||||||
|
private data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>() {
|
||||||
|
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = ndArray.floatIterator()
|
||||||
|
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
||||||
|
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps this [INDArray] to [INDArrayStructure].
|
||||||
|
*/
|
||||||
|
public fun INDArray.asFloatStructure(): INDArrayStructure<Float> = INDArrayFloatStructure(this)
|
@ -1,4 +1,4 @@
|
|||||||
package scientifik.kmath.nd4j
|
package kscience.kmath.nd4j
|
||||||
|
|
||||||
internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() }
|
internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() }
|
||||||
internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() }
|
internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() }
|
@ -1,273 +0,0 @@
|
|||||||
package scientifik.kmath.nd4j
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
|
||||||
import scientifik.kmath.operations.*
|
|
||||||
import scientifik.kmath.structures.*
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents [NDAlgebra] over [INDArrayAlgebra].
|
|
||||||
*
|
|
||||||
* @param T the type of ND-structure element.
|
|
||||||
* @param C the type of the element context.
|
|
||||||
* @param N the type of the structure.
|
|
||||||
*/
|
|
||||||
interface INDArrayAlgebra<T, C, N> : NDAlgebra<T, C, N> where N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
|
||||||
/**
|
|
||||||
* Wraps [INDArray] to [N].
|
|
||||||
*/
|
|
||||||
fun INDArray.wrap(): N
|
|
||||||
|
|
||||||
override fun produce(initializer: C.(IntArray) -> T): N {
|
|
||||||
val struct = Nd4j.create(*shape)!!.wrap()
|
|
||||||
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) }
|
|
||||||
return struct
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun map(arg: N, transform: C.(T) -> T): N {
|
|
||||||
check(arg)
|
|
||||||
val newStruct = arg.ndArray.dup().wrap()
|
|
||||||
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
|
|
||||||
return newStruct
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun mapIndexed(arg: N, transform: C.(index: IntArray, T) -> T): N {
|
|
||||||
check(arg)
|
|
||||||
val new = Nd4j.create(*shape).wrap()
|
|
||||||
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, arg[idx]) }
|
|
||||||
return new
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun combine(a: N, b: N, transform: C.(T, T) -> T): N {
|
|
||||||
check(a, b)
|
|
||||||
val new = Nd4j.create(*shape).wrap()
|
|
||||||
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) }
|
|
||||||
return new
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents [NDSpace] over [INDArrayStructure].
|
|
||||||
*
|
|
||||||
* @param T the type of the element contained in ND structure.
|
|
||||||
* @param N the type of ND structure.
|
|
||||||
* @param S the type of space of structure elements.
|
|
||||||
*/
|
|
||||||
interface INDArraySpace<T, S, N> : NDSpace<T, S, N>, INDArrayAlgebra<T, S, N>
|
|
||||||
where S : Space<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
|
||||||
|
|
||||||
override val zero: N
|
|
||||||
get() = Nd4j.zeros(*shape).wrap()
|
|
||||||
|
|
||||||
override fun add(a: N, b: N): N {
|
|
||||||
check(a, b)
|
|
||||||
return a.ndArray.add(b.ndArray).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun N.minus(b: N): N {
|
|
||||||
check(this, b)
|
|
||||||
return ndArray.sub(b.ndArray).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun N.unaryMinus(): N {
|
|
||||||
check(this)
|
|
||||||
return ndArray.neg().wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun multiply(a: N, k: Number): N {
|
|
||||||
check(a)
|
|
||||||
return a.ndArray.mul(k).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun N.div(k: Number): N {
|
|
||||||
check(this)
|
|
||||||
return ndArray.div(k).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun N.times(k: Number): N {
|
|
||||||
check(this)
|
|
||||||
return ndArray.mul(k).wrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents [NDRing] over [INDArrayStructure].
|
|
||||||
*
|
|
||||||
* @param T the type of the element contained in ND structure.
|
|
||||||
* @param N the type of ND structure.
|
|
||||||
* @param R the type of ring of structure elements.
|
|
||||||
*/
|
|
||||||
interface INDArrayRing<T, R, N> : NDRing<T, R, N>, INDArraySpace<T, R, N>
|
|
||||||
where R : Ring<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
|
||||||
|
|
||||||
override val one: N
|
|
||||||
get() = Nd4j.ones(*shape).wrap()
|
|
||||||
|
|
||||||
override fun multiply(a: N, b: N): N {
|
|
||||||
check(a, b)
|
|
||||||
return a.ndArray.mul(b.ndArray).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun N.minus(b: Number): N {
|
|
||||||
check(this)
|
|
||||||
return ndArray.sub(b).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun N.plus(b: Number): N {
|
|
||||||
check(this)
|
|
||||||
return ndArray.add(b).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Number.minus(b: N): N {
|
|
||||||
check(b)
|
|
||||||
return b.ndArray.rsub(this).wrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents [NDField] over [INDArrayStructure].
|
|
||||||
*
|
|
||||||
* @param T the type of the element contained in ND structure.
|
|
||||||
* @param N the type of ND structure.
|
|
||||||
* @param F the type field of structure elements.
|
|
||||||
*/
|
|
||||||
interface INDArrayField<T, F, N> : NDField<T, F, N>, INDArrayRing<T, F, N>
|
|
||||||
where F : Field<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
|
||||||
override fun divide(a: N, b: N): N {
|
|
||||||
check(a, b)
|
|
||||||
return a.ndArray.div(b.ndArray).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Number.div(b: N): N {
|
|
||||||
check(b)
|
|
||||||
return b.ndArray.rdiv(this).wrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents [NDField] over [INDArrayRealStructure].
|
|
||||||
*/
|
|
||||||
class RealINDArrayField(override val shape: IntArray, override val elementContext: Field<Double> = RealField) :
|
|
||||||
INDArrayField<Double, Field<Double>, INDArrayRealStructure> {
|
|
||||||
override fun INDArray.wrap(): INDArrayRealStructure = check(asRealStructure())
|
|
||||||
override operator fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.div(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.add(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.sub(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.mul(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Double.div(arg: INDArrayRealStructure): INDArrayRealStructure {
|
|
||||||
check(arg)
|
|
||||||
return arg.ndArray.rdiv(this).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Double.minus(arg: INDArrayRealStructure): INDArrayRealStructure {
|
|
||||||
check(arg)
|
|
||||||
return arg.ndArray.rsub(this).wrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents [NDField] over [INDArrayFloatStructure].
|
|
||||||
*/
|
|
||||||
class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field<Float> = FloatField) :
|
|
||||||
INDArrayField<Float, Field<Float>, INDArrayFloatStructure> {
|
|
||||||
override fun INDArray.wrap(): INDArrayFloatStructure = check(asFloatStructure())
|
|
||||||
override operator fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.div(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.add(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.sub(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.mul(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Float.div(arg: INDArrayFloatStructure): INDArrayFloatStructure {
|
|
||||||
check(arg)
|
|
||||||
return arg.ndArray.rdiv(this).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Float.minus(arg: INDArrayFloatStructure): INDArrayFloatStructure {
|
|
||||||
check(arg)
|
|
||||||
return arg.ndArray.rsub(this).wrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents [NDRing] over [INDArrayIntStructure].
|
|
||||||
*/
|
|
||||||
class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Int> = IntRing) :
|
|
||||||
INDArrayRing<Int, Ring<Int>, INDArrayIntStructure> {
|
|
||||||
override fun INDArray.wrap(): INDArrayIntStructure = check(asIntStructure())
|
|
||||||
override operator fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.add(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.sub(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.mul(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Int.minus(arg: INDArrayIntStructure): INDArrayIntStructure {
|
|
||||||
check(arg)
|
|
||||||
return arg.ndArray.rsub(this).wrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents [NDRing] over [INDArrayLongStructure].
|
|
||||||
*/
|
|
||||||
class LongINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Long> = LongRing) :
|
|
||||||
INDArrayRing<Long, Ring<Long>, INDArrayLongStructure> {
|
|
||||||
override fun INDArray.wrap(): INDArrayLongStructure = check(asLongStructure())
|
|
||||||
override operator fun INDArrayLongStructure.plus(arg: Long): INDArrayLongStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.add(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun INDArrayLongStructure.minus(arg: Long): INDArrayLongStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.sub(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun INDArrayLongStructure.times(arg: Long): INDArrayLongStructure {
|
|
||||||
check(this)
|
|
||||||
return ndArray.mul(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Long.minus(arg: INDArrayLongStructure): INDArrayLongStructure {
|
|
||||||
check(arg)
|
|
||||||
return arg.ndArray.rsub(this).wrap()
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,80 +0,0 @@
|
|||||||
package scientifik.kmath.nd4j
|
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
|
||||||
import scientifik.kmath.structures.MutableNDStructure
|
|
||||||
import scientifik.kmath.structures.NDStructure
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents a [NDStructure] wrapping an [INDArray] object.
|
|
||||||
*
|
|
||||||
* @param T the type of items.
|
|
||||||
*/
|
|
||||||
sealed class INDArrayStructure<T> : MutableNDStructure<T> {
|
|
||||||
/**
|
|
||||||
* The wrapped [INDArray].
|
|
||||||
*/
|
|
||||||
abstract val ndArray: INDArray
|
|
||||||
|
|
||||||
override val shape: IntArray
|
|
||||||
get() = ndArray.shape().toIntArray()
|
|
||||||
|
|
||||||
internal abstract fun elementsIterator(): Iterator<Pair<IntArray, T>>
|
|
||||||
internal fun indicesIterator(): Iterator<IntArray> = ndArray.indicesIterator()
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents a [NDStructure] over [INDArray] elements of which are accessed as ints.
|
|
||||||
*/
|
|
||||||
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>() {
|
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = ndArray.intIterator()
|
|
||||||
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
|
||||||
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Wraps this [INDArray] to [INDArrayIntStructure].
|
|
||||||
*/
|
|
||||||
fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents a [NDStructure] over [INDArray] elements of which are accessed as longs.
|
|
||||||
*/
|
|
||||||
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long>() {
|
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = ndArray.longIterator()
|
|
||||||
override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray())
|
|
||||||
override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Wraps this [INDArray] to [INDArrayLongStructure].
|
|
||||||
*/
|
|
||||||
fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents a [NDStructure] over [INDArray] elements of which are accessed as reals.
|
|
||||||
*/
|
|
||||||
data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure<Double>() {
|
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = ndArray.realIterator()
|
|
||||||
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
|
||||||
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Wraps this [INDArray] to [INDArrayRealStructure].
|
|
||||||
*/
|
|
||||||
fun INDArray.asRealStructure(): INDArrayRealStructure = INDArrayRealStructure(this)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents a [NDStructure] over [INDArray] elements of which are accessed as floats.
|
|
||||||
*/
|
|
||||||
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>() {
|
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = ndArray.floatIterator()
|
|
||||||
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
|
||||||
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Wraps this [INDArray] to [INDArrayFloatStructure].
|
|
||||||
*/
|
|
||||||
fun INDArray.asFloatStructure(): INDArrayFloatStructure = INDArrayFloatStructure(this)
|
|
@ -1,15 +1,16 @@
|
|||||||
package scientifik.kmath.nd4j
|
package kscience.kmath.nd4j
|
||||||
|
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import scientifik.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.fail
|
||||||
|
|
||||||
internal class INDArrayAlgebraTest {
|
internal class INDArrayAlgebraTest {
|
||||||
@Test
|
@Test
|
||||||
fun testProduce() {
|
fun testProduce() {
|
||||||
val res = (RealINDArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } }
|
val res = (RealINDArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } }
|
||||||
val expected = Nd4j.create(2, 2)!!.asRealStructure()
|
val expected = (Nd4j.create(2, 2) ?: fail()).asRealStructure()
|
||||||
expected[intArrayOf(0, 0)] = 0.0
|
expected[intArrayOf(0, 0)] = 0.0
|
||||||
expected[intArrayOf(0, 1)] = 1.0
|
expected[intArrayOf(0, 1)] = 1.0
|
||||||
expected[intArrayOf(1, 0)] = 1.0
|
expected[intArrayOf(1, 0)] = 1.0
|
||||||
@ -20,7 +21,7 @@ internal class INDArrayAlgebraTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testMap() {
|
fun testMap() {
|
||||||
val res = (IntINDArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } }
|
val res = (IntINDArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } }
|
||||||
val expected = Nd4j.create(2, 2)!!.asIntStructure()
|
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
|
||||||
expected[intArrayOf(0, 0)] = 3
|
expected[intArrayOf(0, 0)] = 3
|
||||||
expected[intArrayOf(0, 1)] = 3
|
expected[intArrayOf(0, 1)] = 3
|
||||||
expected[intArrayOf(1, 0)] = 3
|
expected[intArrayOf(1, 0)] = 3
|
||||||
@ -31,7 +32,7 @@ internal class INDArrayAlgebraTest {
|
|||||||
@Test
|
@Test
|
||||||
fun testAdd() {
|
fun testAdd() {
|
||||||
val res = (IntINDArrayRing(intArrayOf(2, 2))) { one + 25 }
|
val res = (IntINDArrayRing(intArrayOf(2, 2))) { one + 25 }
|
||||||
val expected = Nd4j.create(2, 2)!!.asIntStructure()
|
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
|
||||||
expected[intArrayOf(0, 0)] = 26
|
expected[intArrayOf(0, 0)] = 26
|
||||||
expected[intArrayOf(0, 1)] = 26
|
expected[intArrayOf(0, 1)] = 26
|
||||||
expected[intArrayOf(1, 0)] = 26
|
expected[intArrayOf(1, 0)] = 26
|
@ -1,70 +1,71 @@
|
|||||||
package scientifik.kmath.nd4j
|
package kscience.kmath.nd4j
|
||||||
|
|
||||||
|
import kscience.kmath.structures.get
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import scientifik.kmath.structures.get
|
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertNotEquals
|
import kotlin.test.assertNotEquals
|
||||||
|
import kotlin.test.fail
|
||||||
|
|
||||||
internal class INDArrayStructureTest {
|
internal class INDArrayStructureTest {
|
||||||
@Test
|
@Test
|
||||||
fun testElements() {
|
fun testElements() {
|
||||||
val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||||
val struct = INDArrayRealStructure(nd)
|
val struct = nd.asRealStructure()
|
||||||
val res = struct.elements().map(Pair<IntArray, Double>::second).toList()
|
val res = struct.elements().map(Pair<IntArray, Double>::second).toList()
|
||||||
assertEquals(listOf(1.0, 2.0, 3.0), res)
|
assertEquals(listOf(1.0, 2.0, 3.0), res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testShape() {
|
fun testShape() {
|
||||||
val nd = Nd4j.rand(10, 2, 3, 6)!!
|
val nd = Nd4j.rand(10, 2, 3, 6) ?: fail()
|
||||||
val struct = INDArrayLongStructure(nd)
|
val struct = nd.asRealStructure()
|
||||||
assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList())
|
assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList())
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testEquals() {
|
fun testEquals() {
|
||||||
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail()
|
||||||
val struct1 = INDArrayRealStructure(nd1)
|
val struct1 = nd1.asRealStructure()
|
||||||
assertEquals(struct1, struct1)
|
assertEquals(struct1, struct1)
|
||||||
assertNotEquals(struct1, null as INDArrayRealStructure?)
|
assertNotEquals(struct1 as Any?, null)
|
||||||
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail()
|
||||||
val struct2 = INDArrayRealStructure(nd2)
|
val struct2 = nd2.asRealStructure()
|
||||||
assertEquals(struct1, struct2)
|
assertEquals(struct1, struct2)
|
||||||
assertEquals(struct2, struct1)
|
assertEquals(struct2, struct1)
|
||||||
val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0)) ?: fail()
|
||||||
val struct3 = INDArrayRealStructure(nd3)
|
val struct3 = nd3.asRealStructure()
|
||||||
assertEquals(struct2, struct3)
|
assertEquals(struct2, struct3)
|
||||||
assertEquals(struct1, struct3)
|
assertEquals(struct1, struct3)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testHashCode() {
|
fun testHashCode() {
|
||||||
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))?:fail()
|
||||||
val struct1 = INDArrayRealStructure(nd1)
|
val struct1 = nd1.asRealStructure()
|
||||||
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))?:fail()
|
||||||
val struct2 = INDArrayRealStructure(nd2)
|
val struct2 = nd2.asRealStructure()
|
||||||
assertEquals(struct1.hashCode(), struct2.hashCode())
|
assertEquals(struct1.hashCode(), struct2.hashCode())
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testDimension() {
|
fun testDimension() {
|
||||||
val nd = Nd4j.rand(8, 16, 3, 7, 1)!!
|
val nd = Nd4j.rand(8, 16, 3, 7, 1)!!
|
||||||
val struct = INDArrayFloatStructure(nd)
|
val struct = nd.asFloatStructure()
|
||||||
assertEquals(5, struct.dimension)
|
assertEquals(5, struct.dimension)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testGet() {
|
fun testGet() {
|
||||||
val nd = Nd4j.rand(10, 2, 3, 6)!!
|
val nd = Nd4j.rand(10, 2, 3, 6)?:fail()
|
||||||
val struct = INDArrayIntStructure(nd)
|
val struct = nd.asIntStructure()
|
||||||
assertEquals(nd.getInt(0, 0, 0, 0), struct[0, 0, 0, 0])
|
assertEquals(nd.getInt(0, 0, 0, 0), struct[0, 0, 0, 0])
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSet() {
|
fun testSet() {
|
||||||
val nd = Nd4j.rand(17, 12, 4, 8)!!
|
val nd = Nd4j.rand(17, 12, 4, 8)!!
|
||||||
val struct = INDArrayIntStructure(nd)
|
val struct = nd.asLongStructure()
|
||||||
struct[intArrayOf(1, 2, 3, 4)] = 777
|
struct[intArrayOf(1, 2, 3, 4)] = 777
|
||||||
assertEquals(777, struct[1, 2, 3, 4])
|
assertEquals(777, struct[1, 2, 3, 4])
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user