Implement kmath-nd4j: module that implements NDStructure for INDArray of ND4J #116

Merged
CommanderTvis merged 50 commits from nd4j into dev 2020-10-29 19:58:53 +03:00
5 changed files with 147 additions and 18 deletions
Showing only changes of commit 783087982f - Show all commits

View File

@ -0,0 +1,89 @@
package scientifik.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.factory.Nd4j
import scientifik.kmath.operations.*
import scientifik.kmath.structures.MutableNDStructure
import scientifik.kmath.structures.NDField
import scientifik.kmath.structures.NDRing
interface INDArrayRing<T, F, N> :
NDRing<T, F, N> where F : Ring<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
fun INDArray.wrap(): N
override val zero: N
get() = Nd4j.zeros(*shape).wrap()
override val one: N
get() = Nd4j.ones(*shape).wrap()
override fun produce(initializer: F.(IntArray) -> T): N {
val struct = Nd4j.create(*shape).wrap()
struct.elements().map(Pair<IntArray, T>::first).forEach { struct[it] = elementContext.initializer(it) }
return struct
}
override fun map(arg: N, transform: F.(T) -> T): N {
val new = Nd4j.create(*shape)
Nd4j.copy(arg.ndArray, new)
val newStruct = new.wrap()
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
return newStruct
}
override fun mapIndexed(arg: N, transform: F.(index: IntArray, T) -> T): N {
val new = Nd4j.create(*shape).wrap()
new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(idx, arg[idx]) }
return new
}
override fun combine(a: N, b: N, transform: F.(T, T) -> T): N {
val new = Nd4j.create(*shape).wrap()
new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(a[idx], b[idx]) }
return new
}
override fun add(a: N, b: N): N = a.ndArray.addi(b.ndArray).wrap()
override fun N.minus(b: N): N = ndArray.subi(b.ndArray).wrap()
override fun N.unaryMinus(): N = ndArray.negi().wrap()
override fun multiply(a: N, b: N): N = a.ndArray.muli(b.ndArray).wrap()
override fun multiply(a: N, k: Number): N = a.ndArray.muli(k).wrap()
override fun N.div(k: Number): N = ndArray.divi(k).wrap()
override fun N.minus(b: Number): N = ndArray.subi(b).wrap()
override fun N.plus(b: Number): N = ndArray.addi(b).wrap()
override fun N.times(k: Number): N = ndArray.muli(k).wrap()
}
interface INDArrayField<T, F, N> : NDField<T, F, N>,
INDArrayRing<T, F, N> where F : Field<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
override fun divide(a: N, b: N): N = a.ndArray.divi(b.ndArray).wrap()
}
class RealINDArrayField(override val shape: IntArray, override val elementContext: Field<Double> = RealField) :
INDArrayField<Double, Field<Double>, INDArrayRealStructure> {
override fun INDArray.wrap(): INDArrayRealStructure = asRealStructure()
override fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure = ndArray.divi(arg).wrap()
override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap()
override fun INDArrayRealStructure.div(k: Number): INDArrayRealStructure = ndArray.divi(k).wrap()
override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap()
override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap()
}
class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field<Float> = FloatField) :
INDArrayField<Float, Field<Float>, INDArrayFloatStructure> {
override fun INDArray.wrap(): INDArrayFloatStructure = asFloatStructure()
override fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure = ndArray.divi(arg).wrap()
override fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure = ndArray.addi(arg).wrap()
override fun INDArrayFloatStructure.div(k: Number): INDArrayFloatStructure = ndArray.divi(k).wrap()
override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap()
override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap()
}
class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Int> = IntRing) :
INDArrayRing<Int, Ring<Int>, INDArrayIntStructure> {
override fun INDArray.wrap(): INDArrayIntStructure = asIntStructure()
override fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure = ndArray.addi(arg).wrap()
override fun INDArrayIntStructure.div(k: Number): INDArrayIntStructure = ndArray.divi(k).wrap()
override fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure = ndArray.subi(arg).wrap()
override fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure = ndArray.muli(arg).wrap()
}

View File

@ -20,14 +20,19 @@ internal sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArra
}
}
internal class INDArrayDoubleIterator(iterateOver: INDArray) : INDArrayIteratorBase<Double>(iterateOver) {
internal class INDArrayRealIterator(iterateOver: INDArray) : INDArrayIteratorBase<Double>(iterateOver) {
override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices)
}
internal fun INDArray.realIterator(): INDArrayRealIterator = INDArrayRealIterator(this)
internal class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase<Long>(iterateOver) {
override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices)
}
// TODO
//internal fun INDArray.longI
internal class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase<Int>(iterateOver) {
override fun getSingle(indices: LongArray) = iterateOver.getInt(*narrowToIntArray(indices))
}

View File

@ -14,30 +14,35 @@ interface INDArrayStructure<T> : NDStructure<T> {
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
}
inline class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>, MutableNDStructure<Int> {
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>, MutableNDStructure<Int> {
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = INDArrayIntIterator(ndArray)
override fun get(index: IntArray): Int = ndArray.getInt(*index)
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
override fun toString(): String = "INDArrayIntStructure(ndArray=$ndArray)"
}
inline class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this)
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = INDArrayLongIterator(ndArray)
override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index))
override fun toString(): String = "INDArrayLongStructure(ndArray=$ndArray)"
}
inline class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure<Double>, MutableNDStructure<Double> {
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = INDArrayDoubleIterator(ndArray)
fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this)
data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure<Double>,
MutableNDStructure<Double> {
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = INDArrayRealIterator(ndArray)
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
override fun toString(): String = "INDArrayDoubleStructure(ndArray=$ndArray)"
}
inline class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>, MutableNDStructure<Float> {
fun INDArray.asRealStructure(): INDArrayRealStructure = INDArrayRealStructure(this)
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>,
MutableNDStructure<Float> {
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = INDArrayFloatIterator(ndArray)
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
override fun toString(): String = "INDArrayFloatStructure(ndArray=$ndArray)"
}
fun INDArray.asFloatStructure(): INDArrayFloatStructure = INDArrayFloatStructure(this)

View File

@ -0,0 +1,30 @@
package scientifik.kmath.nd4j
import org.nd4j.linalg.factory.Nd4j
import scientifik.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class INDArrayAlgebraTest {
@Test
fun testProduce() {
val res = (RealINDArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } }
val expected = Nd4j.create(2, 2)!!.asRealStructure()
expected[intArrayOf(0, 0)] = 0.0
expected[intArrayOf(0, 1)] = 1.0
expected[intArrayOf(1, 0)] = 1.0
expected[intArrayOf(1, 1)] = 2.0
assertEquals(expected, res)
}
@Test
fun testMap() {
val res = (IntINDArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } }
val expected = Nd4j.create(2, 2)!!.asIntStructure()
expected[intArrayOf(0, 0)] = 3
expected[intArrayOf(0, 1)] = 3
expected[intArrayOf(1, 0)] = 3
expected[intArrayOf(1, 1)] = 3
assertEquals(expected, res)
}
}

View File

@ -10,7 +10,7 @@ internal class INDArrayStructureTest {
@Test
fun testElements() {
val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
val struct = INDArrayDoubleStructure(nd)
val struct = INDArrayRealStructure(nd)
val res = struct.elements().map(Pair<IntArray, Double>::second).toList()
assertEquals(listOf(1.0, 2.0, 3.0), res)
}
@ -25,15 +25,15 @@ internal class INDArrayStructureTest {
@Test
fun testEquals() {
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
val struct1 = INDArrayDoubleStructure(nd1)
val struct1 = INDArrayRealStructure(nd1)
assertEquals(struct1, struct1)
assertNotEquals(struct1, null as INDArrayDoubleStructure?)
assertNotEquals(struct1, null as INDArrayRealStructure?)
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
val struct2 = INDArrayDoubleStructure(nd2)
val struct2 = INDArrayRealStructure(nd2)
assertEquals(struct1, struct2)
assertEquals(struct2, struct1)
val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
val struct3 = INDArrayDoubleStructure(nd3)
val struct3 = INDArrayRealStructure(nd3)
assertEquals(struct2, struct3)
assertEquals(struct1, struct3)
}
@ -41,9 +41,9 @@ internal class INDArrayStructureTest {
@Test
fun testHashCode() {
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
val struct1 = INDArrayDoubleStructure(nd1)
val struct1 = INDArrayRealStructure(nd1)
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
val struct2 = INDArrayDoubleStructure(nd2)
val struct2 = INDArrayRealStructure(nd2)
assertEquals(struct1.hashCode(), struct2.hashCode())
}