Rollback making Structures inline, implement Algebras for NDArrayStructure
This commit is contained in:
parent
eb9d40fd2a
commit
783087982f
@ -0,0 +1,89 @@
|
||||
package scientifik.kmath.nd4j
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import scientifik.kmath.operations.*
|
||||
import scientifik.kmath.structures.MutableNDStructure
|
||||
import scientifik.kmath.structures.NDField
|
||||
import scientifik.kmath.structures.NDRing
|
||||
|
||||
interface INDArrayRing<T, F, N> :
|
||||
NDRing<T, F, N> where F : Ring<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
||||
fun INDArray.wrap(): N
|
||||
|
||||
override val zero: N
|
||||
get() = Nd4j.zeros(*shape).wrap()
|
||||
|
||||
override val one: N
|
||||
get() = Nd4j.ones(*shape).wrap()
|
||||
|
||||
override fun produce(initializer: F.(IntArray) -> T): N {
|
||||
val struct = Nd4j.create(*shape).wrap()
|
||||
struct.elements().map(Pair<IntArray, T>::first).forEach { struct[it] = elementContext.initializer(it) }
|
||||
return struct
|
||||
}
|
||||
|
||||
override fun map(arg: N, transform: F.(T) -> T): N {
|
||||
val new = Nd4j.create(*shape)
|
||||
Nd4j.copy(arg.ndArray, new)
|
||||
val newStruct = new.wrap()
|
||||
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
|
||||
return newStruct
|
||||
}
|
||||
|
||||
override fun mapIndexed(arg: N, transform: F.(index: IntArray, T) -> T): N {
|
||||
val new = Nd4j.create(*shape).wrap()
|
||||
new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(idx, arg[idx]) }
|
||||
return new
|
||||
}
|
||||
|
||||
override fun combine(a: N, b: N, transform: F.(T, T) -> T): N {
|
||||
val new = Nd4j.create(*shape).wrap()
|
||||
new.elements().forEach { (idx, _) -> new[idx] = elementContext.transform(a[idx], b[idx]) }
|
||||
return new
|
||||
}
|
||||
|
||||
override fun add(a: N, b: N): N = a.ndArray.addi(b.ndArray).wrap()
|
||||
override fun N.minus(b: N): N = ndArray.subi(b.ndArray).wrap()
|
||||
override fun N.unaryMinus(): N = ndArray.negi().wrap()
|
||||
override fun multiply(a: N, b: N): N = a.ndArray.muli(b.ndArray).wrap()
|
||||
override fun multiply(a: N, k: Number): N = a.ndArray.muli(k).wrap()
|
||||
override fun N.div(k: Number): N = ndArray.divi(k).wrap()
|
||||
override fun N.minus(b: Number): N = ndArray.subi(b).wrap()
|
||||
override fun N.plus(b: Number): N = ndArray.addi(b).wrap()
|
||||
override fun N.times(k: Number): N = ndArray.muli(k).wrap()
|
||||
}
|
||||
|
||||
interface INDArrayField<T, F, N> : NDField<T, F, N>,
|
||||
INDArrayRing<T, F, N> where F : Field<T>, N : INDArrayStructure<T>, N : MutableNDStructure<T> {
|
||||
override fun divide(a: N, b: N): N = a.ndArray.divi(b.ndArray).wrap()
|
||||
}
|
||||
|
||||
class RealINDArrayField(override val shape: IntArray, override val elementContext: Field<Double> = RealField) :
|
||||
INDArrayField<Double, Field<Double>, INDArrayRealStructure> {
|
||||
override fun INDArray.wrap(): INDArrayRealStructure = asRealStructure()
|
||||
override fun INDArrayRealStructure.div(arg: Double): INDArrayRealStructure = ndArray.divi(arg).wrap()
|
||||
override fun INDArrayRealStructure.plus(arg: Double): INDArrayRealStructure = ndArray.addi(arg).wrap()
|
||||
override fun INDArrayRealStructure.div(k: Number): INDArrayRealStructure = ndArray.divi(k).wrap()
|
||||
override fun INDArrayRealStructure.minus(arg: Double): INDArrayRealStructure = ndArray.subi(arg).wrap()
|
||||
override fun INDArrayRealStructure.times(arg: Double): INDArrayRealStructure = ndArray.muli(arg).wrap()
|
||||
}
|
||||
|
||||
class FloatINDArrayField(override val shape: IntArray, override val elementContext: Field<Float> = FloatField) :
|
||||
INDArrayField<Float, Field<Float>, INDArrayFloatStructure> {
|
||||
override fun INDArray.wrap(): INDArrayFloatStructure = asFloatStructure()
|
||||
override fun INDArrayFloatStructure.div(arg: Float): INDArrayFloatStructure = ndArray.divi(arg).wrap()
|
||||
override fun INDArrayFloatStructure.plus(arg: Float): INDArrayFloatStructure = ndArray.addi(arg).wrap()
|
||||
override fun INDArrayFloatStructure.div(k: Number): INDArrayFloatStructure = ndArray.divi(k).wrap()
|
||||
override fun INDArrayFloatStructure.minus(arg: Float): INDArrayFloatStructure = ndArray.subi(arg).wrap()
|
||||
override fun INDArrayFloatStructure.times(arg: Float): INDArrayFloatStructure = ndArray.muli(arg).wrap()
|
||||
}
|
||||
|
||||
class IntINDArrayRing(override val shape: IntArray, override val elementContext: Ring<Int> = IntRing) :
|
||||
INDArrayRing<Int, Ring<Int>, INDArrayIntStructure> {
|
||||
override fun INDArray.wrap(): INDArrayIntStructure = asIntStructure()
|
||||
override fun INDArrayIntStructure.plus(arg: Int): INDArrayIntStructure = ndArray.addi(arg).wrap()
|
||||
override fun INDArrayIntStructure.div(k: Number): INDArrayIntStructure = ndArray.divi(k).wrap()
|
||||
override fun INDArrayIntStructure.minus(arg: Int): INDArrayIntStructure = ndArray.subi(arg).wrap()
|
||||
override fun INDArrayIntStructure.times(arg: Int): INDArrayIntStructure = ndArray.muli(arg).wrap()
|
||||
}
|
@ -20,14 +20,19 @@ internal sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArra
|
||||
}
|
||||
}
|
||||
|
||||
internal class INDArrayDoubleIterator(iterateOver: INDArray) : INDArrayIteratorBase<Double>(iterateOver) {
|
||||
internal class INDArrayRealIterator(iterateOver: INDArray) : INDArrayIteratorBase<Double>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray): Double = iterateOver.getDouble(*indices)
|
||||
}
|
||||
|
||||
internal fun INDArray.realIterator(): INDArrayRealIterator = INDArrayRealIterator(this)
|
||||
|
||||
internal class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBase<Long>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices)
|
||||
}
|
||||
|
||||
// TODO
|
||||
//internal fun INDArray.longI
|
||||
|
||||
internal class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase<Int>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray) = iterateOver.getInt(*narrowToIntArray(indices))
|
||||
}
|
||||
|
@ -14,30 +14,35 @@ interface INDArrayStructure<T> : NDStructure<T> {
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
||||
}
|
||||
|
||||
inline class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>, MutableNDStructure<Int> {
|
||||
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructure<Int>, MutableNDStructure<Int> {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = INDArrayIntIterator(ndArray)
|
||||
override fun get(index: IntArray): Int = ndArray.getInt(*index)
|
||||
override fun set(index: IntArray, value: Int): Unit = run { ndArray.putScalar(index, value) }
|
||||
override fun toString(): String = "INDArrayIntStructure(ndArray=$ndArray)"
|
||||
}
|
||||
|
||||
inline class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
|
||||
fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this)
|
||||
|
||||
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = INDArrayLongIterator(ndArray)
|
||||
override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index))
|
||||
override fun toString(): String = "INDArrayLongStructure(ndArray=$ndArray)"
|
||||
|
||||
}
|
||||
|
||||
inline class INDArrayDoubleStructure(override val ndArray: INDArray) : INDArrayStructure<Double>, MutableNDStructure<Double> {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = INDArrayDoubleIterator(ndArray)
|
||||
fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this)
|
||||
|
||||
data class INDArrayRealStructure(override val ndArray: INDArray) : INDArrayStructure<Double>,
|
||||
MutableNDStructure<Double> {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = INDArrayRealIterator(ndArray)
|
||||
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
||||
override fun set(index: IntArray, value: Double): Unit = run { ndArray.putScalar(index, value) }
|
||||
override fun toString(): String = "INDArrayDoubleStructure(ndArray=$ndArray)"
|
||||
}
|
||||
|
||||
inline class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>, MutableNDStructure<Float> {
|
||||
fun INDArray.asRealStructure(): INDArrayRealStructure = INDArrayRealStructure(this)
|
||||
|
||||
data class INDArrayFloatStructure(override val ndArray: INDArray) : INDArrayStructure<Float>,
|
||||
MutableNDStructure<Float> {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = INDArrayFloatIterator(ndArray)
|
||||
override fun get(index: IntArray): Float = ndArray.getFloat(*index)
|
||||
override fun set(index: IntArray, value: Float): Unit = run { ndArray.putScalar(index, value) }
|
||||
override fun toString(): String = "INDArrayFloatStructure(ndArray=$ndArray)"
|
||||
}
|
||||
|
||||
fun INDArray.asFloatStructure(): INDArrayFloatStructure = INDArrayFloatStructure(this)
|
||||
|
@ -0,0 +1,30 @@
|
||||
package scientifik.kmath.nd4j
|
||||
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import scientifik.kmath.operations.invoke
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class INDArrayAlgebraTest {
|
||||
@Test
|
||||
fun testProduce() {
|
||||
val res = (RealINDArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } }
|
||||
val expected = Nd4j.create(2, 2)!!.asRealStructure()
|
||||
expected[intArrayOf(0, 0)] = 0.0
|
||||
expected[intArrayOf(0, 1)] = 1.0
|
||||
expected[intArrayOf(1, 0)] = 1.0
|
||||
expected[intArrayOf(1, 1)] = 2.0
|
||||
assertEquals(expected, res)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testMap() {
|
||||
val res = (IntINDArrayRing(intArrayOf(2, 2))) { map(one) { it + it * 2 } }
|
||||
val expected = Nd4j.create(2, 2)!!.asIntStructure()
|
||||
expected[intArrayOf(0, 0)] = 3
|
||||
expected[intArrayOf(0, 1)] = 3
|
||||
expected[intArrayOf(1, 0)] = 3
|
||||
expected[intArrayOf(1, 1)] = 3
|
||||
assertEquals(expected, res)
|
||||
}
|
||||
}
|
@ -10,7 +10,7 @@ internal class INDArrayStructureTest {
|
||||
@Test
|
||||
fun testElements() {
|
||||
val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||
val struct = INDArrayDoubleStructure(nd)
|
||||
val struct = INDArrayRealStructure(nd)
|
||||
val res = struct.elements().map(Pair<IntArray, Double>::second).toList()
|
||||
assertEquals(listOf(1.0, 2.0, 3.0), res)
|
||||
}
|
||||
@ -25,15 +25,15 @@ internal class INDArrayStructureTest {
|
||||
@Test
|
||||
fun testEquals() {
|
||||
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||
val struct1 = INDArrayDoubleStructure(nd1)
|
||||
val struct1 = INDArrayRealStructure(nd1)
|
||||
assertEquals(struct1, struct1)
|
||||
assertNotEquals(struct1, null as INDArrayDoubleStructure?)
|
||||
assertNotEquals(struct1, null as INDArrayRealStructure?)
|
||||
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||
val struct2 = INDArrayDoubleStructure(nd2)
|
||||
val struct2 = INDArrayRealStructure(nd2)
|
||||
assertEquals(struct1, struct2)
|
||||
assertEquals(struct2, struct1)
|
||||
val nd3 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||
val struct3 = INDArrayDoubleStructure(nd3)
|
||||
val struct3 = INDArrayRealStructure(nd3)
|
||||
assertEquals(struct2, struct3)
|
||||
assertEquals(struct1, struct3)
|
||||
}
|
||||
@ -41,9 +41,9 @@ internal class INDArrayStructureTest {
|
||||
@Test
|
||||
fun testHashCode() {
|
||||
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||
val struct1 = INDArrayDoubleStructure(nd1)
|
||||
val struct1 = INDArrayRealStructure(nd1)
|
||||
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||
val struct2 = INDArrayDoubleStructure(nd2)
|
||||
val struct2 = INDArrayRealStructure(nd2)
|
||||
assertEquals(struct1.hashCode(), struct2.hashCode())
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user