tensors testing
This commit is contained in:
parent
face60824d
commit
ea6cd01b89
@ -29,12 +29,51 @@ internal fun NoaFloat.testingCopying(device: Device = Device.CPU): Unit {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal fun NoaInt.testingViewWithNoCopy(device: Device = Device.CPU) {
|
||||||
|
val tensor = copyFromArray(intArrayOf(1, 2, 3, 4, 5, 6), shape = intArrayOf(6), device)
|
||||||
|
val viewTensor = tensor.view(intArrayOf(2, 3))
|
||||||
|
assertTrue(viewTensor.shape contentEquals intArrayOf(2, 3))
|
||||||
|
viewTensor[intArrayOf(0, 0)] = 10
|
||||||
|
assertEquals(tensor[intArrayOf(0)], 10)
|
||||||
|
}
|
||||||
|
|
||||||
class TestTensor {
|
class TestTensor {
|
||||||
@Test
|
@Test
|
||||||
fun testCopying(): Unit = NoaFloat {
|
fun testCopying() = NoaFloat {
|
||||||
withCuda { device ->
|
withCuda { device ->
|
||||||
testingCopying(device)
|
testingCopying(device)
|
||||||
}
|
}
|
||||||
}!!
|
}!!
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testRequiresGrad() = NoaFloat {
|
||||||
|
val tensor = randNormal(intArrayOf(3))
|
||||||
|
assertTrue(!tensor.requiresGrad)
|
||||||
|
tensor.requiresGrad = true
|
||||||
|
assertTrue(tensor.requiresGrad)
|
||||||
|
tensor.requiresGrad = false
|
||||||
|
assertTrue(!tensor.requiresGrad)
|
||||||
|
tensor.requiresGrad = true
|
||||||
|
val detachedTensor = tensor.detachFromGraph()
|
||||||
|
assertTrue(!detachedTensor.requiresGrad)
|
||||||
|
}!!
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testTypeMoving() = NoaFloat {
|
||||||
|
val tensorInt = copyFromArray(floatArrayOf(1f, 2f, 3f), intArrayOf(3)).asInt()
|
||||||
|
NoaInt {
|
||||||
|
val temporalTensor = copyFromArray(intArrayOf(4, 5, 6), intArrayOf(3))
|
||||||
|
tensorInt swap temporalTensor
|
||||||
|
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1, 2, 3))
|
||||||
|
}
|
||||||
|
assertTrue(tensorInt.asFloat().copyToArray() contentEquals floatArrayOf(4f, 5f, 6f))
|
||||||
|
}!!
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testViewWithNoCopy() = NoaInt {
|
||||||
|
withCuda { device ->
|
||||||
|
testingViewWithNoCopy(device)
|
||||||
|
}
|
||||||
|
}!!
|
||||||
|
|
||||||
}
|
}
|
@ -53,7 +53,7 @@ class TestUtils {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSetSeed(): Unit = NoaFloat {
|
fun testSetSeed() = NoaFloat {
|
||||||
withCuda { device ->
|
withCuda { device ->
|
||||||
testingSetSeed(device)
|
testingSetSeed(device)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user