tensors testing
This commit is contained in:
parent
face60824d
commit
ea6cd01b89
@ -29,12 +29,51 @@ internal fun NoaFloat.testingCopying(device: Device = Device.CPU): Unit {
|
||||
}
|
||||
}
|
||||
|
||||
internal fun NoaInt.testingViewWithNoCopy(device: Device = Device.CPU) {
|
||||
val tensor = copyFromArray(intArrayOf(1, 2, 3, 4, 5, 6), shape = intArrayOf(6), device)
|
||||
val viewTensor = tensor.view(intArrayOf(2, 3))
|
||||
assertTrue(viewTensor.shape contentEquals intArrayOf(2, 3))
|
||||
viewTensor[intArrayOf(0, 0)] = 10
|
||||
assertEquals(tensor[intArrayOf(0)], 10)
|
||||
}
|
||||
|
||||
class TestTensor {
|
||||
@Test
|
||||
fun testCopying(): Unit = NoaFloat {
|
||||
fun testCopying() = NoaFloat {
|
||||
withCuda { device ->
|
||||
testingCopying(device)
|
||||
}
|
||||
}!!
|
||||
|
||||
@Test
|
||||
fun testRequiresGrad() = NoaFloat {
|
||||
val tensor = randNormal(intArrayOf(3))
|
||||
assertTrue(!tensor.requiresGrad)
|
||||
tensor.requiresGrad = true
|
||||
assertTrue(tensor.requiresGrad)
|
||||
tensor.requiresGrad = false
|
||||
assertTrue(!tensor.requiresGrad)
|
||||
tensor.requiresGrad = true
|
||||
val detachedTensor = tensor.detachFromGraph()
|
||||
assertTrue(!detachedTensor.requiresGrad)
|
||||
}!!
|
||||
|
||||
@Test
|
||||
fun testTypeMoving() = NoaFloat {
|
||||
val tensorInt = copyFromArray(floatArrayOf(1f, 2f, 3f), intArrayOf(3)).asInt()
|
||||
NoaInt {
|
||||
val temporalTensor = copyFromArray(intArrayOf(4, 5, 6), intArrayOf(3))
|
||||
tensorInt swap temporalTensor
|
||||
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1, 2, 3))
|
||||
}
|
||||
assertTrue(tensorInt.asFloat().copyToArray() contentEquals floatArrayOf(4f, 5f, 6f))
|
||||
}!!
|
||||
|
||||
@Test
|
||||
fun testViewWithNoCopy() = NoaInt {
|
||||
withCuda { device ->
|
||||
testingViewWithNoCopy(device)
|
||||
}
|
||||
}!!
|
||||
|
||||
}
|
@ -53,7 +53,7 @@ class TestUtils {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSetSeed(): Unit = NoaFloat {
|
||||
fun testSetSeed() = NoaFloat {
|
||||
withCuda { device ->
|
||||
testingSetSeed(device)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user