Implement kmath-nd4j: module that implements NDStructure for INDArray of ND4J #116
@ -0,0 +1,38 @@
|
|||||||
|
package scientifik.kmath.nd4j
|
||||||
|
|
||||||
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
internal class INDArrayStructureTest {
|
||||||
|
@Test
|
||||||
|
fun testElements() {
|
||||||
|
val nd = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||||
|
val struct = INDArrayDoubleStructure(nd)
|
||||||
|
val res = struct.elements().map { it.second }.toList()
|
||||||
|
assertEquals(listOf(1.0, 2.0, 3.0), res)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testShape() {
|
||||||
|
val nd = Nd4j.rand(10, 2, 3, 6)!!
|
||||||
|
val struct = INDArrayIntStructure(nd)
|
||||||
|
assertEquals(intArrayOf(10, 2, 3, 6).toList(), struct.shape.toList())
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testEquals() {
|
||||||
|
val nd1 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||||
|
val struct1 = INDArrayDoubleStructure(nd1)
|
||||||
|
val nd2 = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0))!!
|
||||||
|
val struct2 = INDArrayDoubleStructure(nd2)
|
||||||
|
assertEquals(struct1, struct2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testDimension() {
|
||||||
|
val nd = Nd4j.rand(8, 16, 3, 7, 1)!!
|
||||||
|
val struct = INDArrayIntStructure(nd)
|
||||||
|
assertEquals(5, struct.dimension)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user