tensors testing

This commit is contained in:
Roland Grinis 2021-07-13 13:56:34 +01:00
parent face60824d
commit ea6cd01b89
2 changed files with 43 additions and 4 deletions

View File

@ -29,12 +29,51 @@ internal fun NoaFloat.testingCopying(device: Device = Device.CPU): Unit {
}
}
internal fun NoaInt.testingViewWithNoCopy(device: Device = Device.CPU) {
val tensor = copyFromArray(intArrayOf(1, 2, 3, 4, 5, 6), shape = intArrayOf(6), device)
val viewTensor = tensor.view(intArrayOf(2, 3))
assertTrue(viewTensor.shape contentEquals intArrayOf(2, 3))
viewTensor[intArrayOf(0, 0)] = 10
assertEquals(tensor[intArrayOf(0)], 10)
}
class TestTensor {
@Test
fun testCopying(): Unit = NoaFloat {
fun testCopying() = NoaFloat {
withCuda { device ->
testingCopying(device)
}
}!!
@Test
fun testRequiresGrad() = NoaFloat {
val tensor = randNormal(intArrayOf(3))
assertTrue(!tensor.requiresGrad)
tensor.requiresGrad = true
assertTrue(tensor.requiresGrad)
tensor.requiresGrad = false
assertTrue(!tensor.requiresGrad)
tensor.requiresGrad = true
val detachedTensor = tensor.detachFromGraph()
assertTrue(!detachedTensor.requiresGrad)
}!!
@Test
fun testTypeMoving() = NoaFloat {
val tensorInt = copyFromArray(floatArrayOf(1f, 2f, 3f), intArrayOf(3)).asInt()
NoaInt {
val temporalTensor = copyFromArray(intArrayOf(4, 5, 6), intArrayOf(3))
tensorInt swap temporalTensor
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1, 2, 3))
}
assertTrue(tensorInt.asFloat().copyToArray() contentEquals floatArrayOf(4f, 5f, 6f))
}!!
@Test
fun testViewWithNoCopy() = NoaInt {
withCuda { device ->
testingViewWithNoCopy(device)
}
}!!
}

View File

@ -53,7 +53,7 @@ class TestUtils {
}
@Test
fun testSetSeed(): Unit = NoaFloat {
fun testSetSeed() = NoaFloat {
withCuda { device ->
testingSetSeed(device)
}