Implement the ND4J module for scalars
This commit is contained in:
parent
f364060acf
commit
3df9892de5
@ -11,6 +11,7 @@ allprojects {
|
|||||||
repositories {
|
repositories {
|
||||||
jcenter()
|
jcenter()
|
||||||
maven("https://dl.bintray.com/kotlin/kotlinx")
|
maven("https://dl.bintray.com/kotlin/kotlinx")
|
||||||
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
group = "scientifik"
|
group = "scientifik"
|
||||||
|
@ -1,3 +1,9 @@
|
|||||||
plugins {
|
plugins { id("scientifik.jvm") }
|
||||||
id("scientifik.jvm")
|
|
||||||
|
dependencies {
|
||||||
|
api(project(":kmath-core"))
|
||||||
|
api("org.nd4j:nd4j-api:1.0.0-beta7")
|
||||||
|
testImplementation("org.deeplearning4j:deeplearning4j-core:1.0.0-beta7")
|
||||||
|
testImplementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
|
||||||
|
testImplementation("org.slf4j:slf4j-simple:1.7.30")
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,19 @@
|
|||||||
|
package scientifik.kmath.nd4j
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
|
import org.nd4j.linalg.api.shape.Shape
|
||||||
|
|
||||||
|
internal class INDArrayScalarsIterator(private val iterateOver: INDArray) : Iterator<Pair<IntArray, INDArray>> {
|
||||||
|
private var i: Int = 0
|
||||||
|
|
||||||
|
override fun hasNext(): Boolean = i < iterateOver.length()
|
||||||
|
|
||||||
|
override fun next(): Pair<IntArray, INDArray> {
|
||||||
|
val idx = if (iterateOver.ordering() == 'c')
|
||||||
|
Shape.ind2subC(iterateOver, i++.toLong())!!
|
||||||
|
else
|
||||||
|
Shape.ind2sub(iterateOver, i++.toLong())!!
|
||||||
|
|
||||||
|
return narrowToIntArray(idx) to iterateOver.getScalar(*idx)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,14 @@
|
|||||||
|
package scientifik.kmath.nd4j
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
|
import scientifik.kmath.structures.NDStructure
|
||||||
|
|
||||||
|
internal fun narrowToIntArray(la: LongArray): IntArray = IntArray(la.size) { la[it].toInt() }
|
||||||
|
|
||||||
|
data class ND4JStructure<T>(val ndArray: INDArray) : NDStructure<INDArray> {
|
||||||
|
override val shape: IntArray
|
||||||
|
get() = narrowToIntArray(ndArray.shape())
|
||||||
|
|
||||||
|
override fun get(index: IntArray): INDArray = ndArray.getScalar(*index)
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, INDArray>> = Sequence { INDArrayScalarsIterator(ndArray) }
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user