package scientifik.kmath.nd4j

import org.nd4j.linalg.api.ndarray.INDArray
import scientifik.kmath.structures.MutableNDStructure
import scientifik.kmath.structures.NDStructure

/**
 * An [NDStructure] whose contents are held in an ND4J [INDArray].
 *
 * @param T the element type exposed to callers.
 */
interface INDArrayStructure<T> : NDStructure<T> {
    /** The wrapped ND4J array that backs this structure. */
    val ndArray: INDArray

    /** The shape reported by [ndArray], passed through `narrowToIntArray`. */
    override val shape: IntArray
        get() {
            val rawShape = ndArray.shape()
            return narrowToIntArray(rawShape)
        }

    /** Supplies an iterator over `(index, value)` pairs; invoked each time [elements] is iterated. */
    fun elementsIterator(): Iterator<Pair<IntArray, T>>

    /** Returns a [Sequence] that obtains its iterator from [elementsIterator]. */
    override fun elements(): Sequence<Pair<IntArray, T>> = Sequence { elementsIterator() }
}

/**
 * Mutable [Int]-typed view over an [INDArray].
 *
 * @property ndArray the wrapped ND4J array.
 */
inline class INDArrayIntStructure(override val ndArray: INDArray) :
    INDArrayStructure<Int>,
    MutableNDStructure<Int> {

    /** Creates an `INDArrayIntIterator` over [ndArray]. */
    override fun elementsIterator(): Iterator<Pair<IntArray, Int>> {
        return INDArrayIntIterator(ndArray)
    }

    /** Reads the element at [index] through `INDArray.getInt`. */
    override fun get(index: IntArray): Int {
        return ndArray.getInt(*index)
    }

    /** Stores [value] at [index] through `INDArray.putScalar`; the call's result is discarded. */
    override fun set(index: IntArray, value: Int) {
        ndArray.putScalar(index, value)
    }

    override fun toString(): String {
        return "INDArrayIntStructure(ndArray=$ndArray)"
    }
}

/**
 * [Long]-typed view over an [INDArray].
 *
 * This wrapper implements only [INDArrayStructure]; it declares no `set` operation.
 *
 * @property ndArray the wrapped ND4J array.
 */
inline class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {

    /** Creates an `INDArrayLongIterator` over [ndArray]. */
    override fun elementsIterator(): Iterator<Pair<IntArray, Long>> {
        return INDArrayLongIterator(ndArray)
    }

    /** Reads the element at [index] through `INDArray.getLong`, after passing the index to `widenToLongArray`. */
    override fun get(index: IntArray): Long {
        val widenedIndex = widenToLongArray(index)
        return ndArray.getLong(*widenedIndex)
    }

    override fun toString(): String {
        return "INDArrayLongStructure(ndArray=$ndArray)"
    }
}

/**
 * Mutable [Double]-typed view over an [INDArray].
 *
 * @property ndArray the wrapped ND4J array.
 */
inline class INDArrayDoubleStructure(override val ndArray: INDArray) :
    INDArrayStructure<Double>,
    MutableNDStructure<Double> {

    /** Creates an `INDArrayDoubleIterator` over [ndArray]. */
    override fun elementsIterator(): Iterator<Pair<IntArray, Double>> {
        return INDArrayDoubleIterator(ndArray)
    }

    /** Reads the element at [index] through `INDArray.getDouble`. */
    override fun get(index: IntArray): Double {
        return ndArray.getDouble(*index)
    }

    /** Stores [value] at [index] through `INDArray.putScalar`; the call's result is discarded. */
    override fun set(index: IntArray, value: Double) {
        ndArray.putScalar(index, value)
    }

    override fun toString(): String {
        return "INDArrayDoubleStructure(ndArray=$ndArray)"
    }
}

/**
 * Mutable [Float]-typed view over an [INDArray].
 *
 * @property ndArray the wrapped ND4J array.
 */
inline class INDArrayFloatStructure(override val ndArray: INDArray) :
    INDArrayStructure<Float>,
    MutableNDStructure<Float> {

    /** Creates an `INDArrayFloatIterator` over [ndArray]. */
    override fun elementsIterator(): Iterator<Pair<IntArray, Float>> {
        return INDArrayFloatIterator(ndArray)
    }

    /** Reads the element at [index] through `INDArray.getFloat`. */
    override fun get(index: IntArray): Float {
        return ndArray.getFloat(*index)
    }

    /** Stores [value] at [index] through `INDArray.putScalar`; the call's result is discarded. */
    override fun set(index: IntArray, value: Float) {
        ndArray.putScalar(index, value)
    }

    override fun toString(): String {
        return "INDArrayFloatStructure(ndArray=$ndArray)"
    }
}