Dev #194
@ -10,7 +10,7 @@ interface INDArrayStructureBase<T> : NDStructure<T> {
|
||||
get() = narrowToIntArray(ndArray.shape())
|
||||
|
||||
fun elementsIterator(): Iterator<Pair<IntArray, T>>
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence { elementsIterator() }
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
||||
}
|
||||
|
||||
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructureBase<Int> {
|
||||
|
@ -9,7 +9,7 @@ internal class INDArrayStructureTest {
|
||||
fun testElements() {
|
||||
val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||
val struct = INDArrayDoubleStructure(nd)
|
||||
val res = struct.elements().map { it.second }.toList()
|
||||
val res = struct.elements().map(Pair<IntArray, Double>::second).toList()
|
||||
assertEquals(listOf(1.0, 2.0, 3.0), res)
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user