package scientifik.kmath.nd4j

import org.nd4j.linalg.api.ndarray.INDArray
import scientifik.kmath.structures.MutableNDStructure
import scientifik.kmath.structures.NDStructure

/**
 * An [NDStructure] view over an ND4J [INDArray].
 *
 * Implementations supply the backing [ndArray] and an [elementsIterator]; the
 * [shape] and [elements] members are derived from those two.
 *
 * @param T the element type exposed by this structure.
 */
interface INDArrayStructure<T> : NDStructure<T> {
    /** The backing ND4J array. */
    val ndArray: INDArray

    /** The shape reported by [ndArray], converted to an [IntArray] on every access. */
    override val shape: IntArray
        get() = ndArray.shape().toIntArray()

    /**
     * Creates a new iterator over `(IntArray, T)` pairs of this structure.
     * The [IntArray] component is presumably the element's index — confirm against implementations.
     */
    fun elementsIterator(): Iterator<Pair<IntArray, T>>

    /** Exposes [elementsIterator] as a [Sequence]; a fresh iterator is requested for each traversal. */
    override fun elements(): Sequence<Pair<IntArray, T>> = Sequence { elementsIterator() }
}

/**
 * An [INDArrayStructure] that reads and writes elements of [ndArray] as [Int] values.
 *
 * The array is held by reference, so [set] modifies the array passed to the constructor.
 *
 * @property ndArray the backing ND4J array.
 */
data class INDArrayIntStructure(
    override val ndArray: INDArray
) : INDArrayStructure<Int>, MutableNDStructure<Int> {
    override fun elementsIterator(): Iterator<Pair<IntArray, Int>> = ndArray.intIterator()

    /** Reads the element at [index] via `INDArray.getInt`. */
    override fun get(index: IntArray): Int = ndArray.getInt(*index)

    /** Writes [value] at [index] via `INDArray.putScalar`. */
    override fun set(index: IntArray, value: Int) {
        // putScalar returns a value that is not needed here; it is intentionally discarded.
        ndArray.putScalar(index, value)
    }
}

/**
 * Wraps this [INDArray] in an [INDArrayIntStructure]. The receiver itself is stored,
 * not a copy.
 */
fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(ndArray = this)

/**
 * An [INDArrayStructure] that reads elements of [ndArray] as [Long] values.
 *
 * Unlike the other structures in this file, this one does not implement
 * [MutableNDStructure] and therefore offers no `set`.
 *
 * @property ndArray the backing ND4J array.
 */
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
    override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = ndArray.longIterator()

    /** Reads the element at [index] via `INDArray.getLong`, after converting the index to long values. */
    override fun get(index: IntArray): Long {
        val longIndex = index.toLongArray()
        return ndArray.getLong(*longIndex)
    }
}

/**
 * Wraps this [INDArray] in an [INDArrayLongStructure]. The receiver itself is stored,
 * not a copy.
 */
fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(ndArray = this)

/**
 * An [INDArrayStructure] that reads and writes elements of [ndArray] as [Double] values.
 *
 * The array is held by reference, so [set] modifies the array passed to the constructor.
 *
 * @property ndArray the backing ND4J array.
 */
data class INDArrayRealStructure(
    override val ndArray: INDArray
) : INDArrayStructure<Double>, MutableNDStructure<Double> {
    override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = ndArray.realIterator()

    /** Reads the element at [index] via `INDArray.getDouble`. */
    override fun get(index: IntArray): Double = ndArray.getDouble(*index)

    /** Writes [value] at [index] via `INDArray.putScalar`. */
    override fun set(index: IntArray, value: Double) {
        // putScalar returns a value that is not needed here; it is intentionally discarded.
        ndArray.putScalar(index, value)
    }
}

/**
 * Wraps this [INDArray] in an [INDArrayRealStructure]. The receiver itself is stored,
 * not a copy.
 */
fun INDArray.asRealStructure(): INDArrayRealStructure = INDArrayRealStructure(ndArray = this)

/**
 * An [INDArrayStructure] that reads and writes elements of [ndArray] as [Float] values.
 *
 * The array is held by reference, so [set] modifies the array passed to the constructor.
 *
 * @property ndArray the backing ND4J array.
 */
data class INDArrayFloatStructure(
    override val ndArray: INDArray
) : INDArrayStructure<Float>, MutableNDStructure<Float> {
    override fun elementsIterator(): Iterator<Pair<IntArray, Float>> = ndArray.floatIterator()

    /** Reads the element at [index] via `INDArray.getFloat`. */
    override fun get(index: IntArray): Float = ndArray.getFloat(*index)

    /** Writes [value] at [index] via `INDArray.putScalar`. */
    override fun set(index: IntArray, value: Float) {
        // putScalar returns a value that is not needed here; it is intentionally discarded.
        ndArray.putScalar(index, value)
    }
}

/**
 * Wraps this [INDArray] in an [INDArrayFloatStructure]. The receiver itself is stored,
 * not a copy.
 */
fun INDArray.asFloatStructure(): INDArrayFloatStructure = INDArrayFloatStructure(ndArray = this)