Implement kmath-nd4j: module that implements NDStructure for INDArray of ND4J #116

Merged
CommanderTvis merged 50 commits from nd4j into dev 2020-10-29 19:58:53 +03:00
2 changed files with 2 additions and 2 deletions
Showing only changes of commit b6bf741dbe - Show all commits

View File

@ -10,7 +10,7 @@ interface INDArrayStructureBase<T> : NDStructure<T> {
get() = narrowToIntArray(ndArray.shape()) get() = narrowToIntArray(ndArray.shape())
fun elementsIterator(): Iterator<Pair<IntArray, T>> fun elementsIterator(): Iterator<Pair<IntArray, T>>
override fun elements(): Sequence<Pair<IntArray, T>> = Sequence { elementsIterator() } override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
} }
data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructureBase<Int> { data class INDArrayIntStructure(override val ndArray: INDArray) : INDArrayStructureBase<Int> {

View File

@ -9,7 +9,7 @@ internal class INDArrayStructureTest {
fun testElements() { fun testElements() {
val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!! val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
val struct = INDArrayDoubleStructure(nd) val struct = INDArrayDoubleStructure(nd)
val res = struct.elements().map { it.second }.toList() val res = struct.elements().map(Pair<IntArray, Double>::second).toList()
assertEquals(listOf(1.0, 2.0, 3.0), res) assertEquals(listOf(1.0, 2.0, 3.0), res)
} }