Refactor array functions
This commit is contained in:
parent
23b2ba9950
commit
d87dd3e717
@ -16,7 +16,7 @@ internal sealed class INDArrayIteratorBase<T>(protected val iterateOver: INDArra
|
|||||||
else
|
else
|
||||||
Shape.ind2sub(iterateOver, i++.toLong())!!
|
Shape.ind2sub(iterateOver, i++.toLong())!!
|
||||||
|
|
||||||
return narrowToIntArray(la) to getSingle(la)
|
return la.toIntArray() to getSingle(la)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -30,13 +30,16 @@ internal class INDArrayLongIterator(iterateOver: INDArray) : INDArrayIteratorBas
|
|||||||
override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices)
|
override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO
|
internal fun INDArray.longIterator(): INDArrayLongIterator = INDArrayLongIterator(this)
|
||||||
//internal fun INDArray.longI
|
|
||||||
|
|
||||||
internal class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase<Int>(iterateOver) {
|
internal class INDArrayIntIterator(iterateOver: INDArray) : INDArrayIteratorBase<Int>(iterateOver) {
|
||||||
override fun getSingle(indices: LongArray) = iterateOver.getInt(*narrowToIntArray(indices))
|
override fun getSingle(indices: LongArray) = iterateOver.getInt(*indices.toIntArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal fun INDArray.intIterator(): INDArrayIntIterator = INDArrayIntIterator(this)
|
||||||
|
|
||||||
internal class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase<Float>(iterateOver) {
|
internal class INDArrayFloatIterator(iterateOver: INDArray) : INDArrayIteratorBase<Float>(iterateOver) {
|
||||||
override fun getSingle(indices: LongArray) = iterateOver.getFloat(*indices)
|
override fun getSingle(indices: LongArray) = iterateOver.getFloat(*indices)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal fun INDArray.floatIterator() = INDArrayFloatIterator(this)
|
||||||
|
@ -8,7 +8,7 @@ interface INDArrayStructure<T> : NDStructure<T> {
|
|||||||
val ndArray: INDArray
|
val ndArray: INDArray
|
||||||
|
|
||||||
override val shape: IntArray
|
override val shape: IntArray
|
||||||
get() = narrowToIntArray(ndArray.shape())
|
get() = ndArray.shape().toIntArray()
|
||||||
|
|
||||||
fun elementsIterator(): Iterator<Pair<IntArray, T>>
|
fun elementsIterator(): Iterator<Pair<IntArray, T>>
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
||||||
@ -24,7 +24,7 @@ fun INDArray.asIntStructure(): INDArrayIntStructure = INDArrayIntStructure(this)
|
|||||||
|
|
||||||
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
|
data class INDArrayLongStructure(override val ndArray: INDArray) : INDArrayStructure<Long> {
|
||||||
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = INDArrayLongIterator(ndArray)
|
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = INDArrayLongIterator(ndArray)
|
||||||
override fun get(index: IntArray): Long = ndArray.getLong(*widenToLongArray(index))
|
override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this)
|
fun INDArray.asLongStructure(): INDArrayLongStructure = INDArrayLongStructure(this)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package scientifik.kmath.nd4j
|
package scientifik.kmath.nd4j
|
||||||
|
|
||||||
internal fun widenToLongArray(ia: IntArray): LongArray = LongArray(ia.size) { ia[it].toLong() }
|
internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() }
|
||||||
internal fun narrowToIntArray(la: LongArray): IntArray = IntArray(la.size) { la[it].toInt() }
|
internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() }
|
||||||
|
@ -27,4 +27,15 @@ internal class INDArrayAlgebraTest {
|
|||||||
expected[intArrayOf(1, 1)] = 3
|
expected[intArrayOf(1, 1)] = 3
|
||||||
assertEquals(expected, res)
|
assertEquals(expected, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAdd() {
|
||||||
|
val res = (IntINDArrayRing(intArrayOf(2, 2))) { one + 25 }
|
||||||
|
val expected = Nd4j.create(2, 2)!!.asIntStructure()
|
||||||
|
expected[intArrayOf(0, 0)] = 26
|
||||||
|
expected[intArrayOf(0, 1)] = 26
|
||||||
|
expected[intArrayOf(1, 0)] = 26
|
||||||
|
expected[intArrayOf(1, 1)] = 26
|
||||||
|
assertEquals(expected, res)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user