tensors testing

This commit is contained in:
Roland Grinis 2021-07-13 13:56:34 +01:00
parent face60824d
commit ea6cd01b89
2 changed files with 43 additions and 4 deletions

View File

@ -18,23 +18,62 @@ internal fun NoaFloat.testingCopying(device: Device = Device.CPU): Unit {
tensor[intArrayOf(1, 2, 3)] = 0.1f
assertTrue(copyOfTensor.copyToArray() contentEquals array)
assertEquals(0.1f, tensor[intArrayOf(1, 2, 3)])
if(device != Device.CPU){
if (device != Device.CPU) {
val normalCpu = randNormal(intArrayOf(2, 3))
val normalGpu = normalCpu.copyToDevice(device)
assertTrue(normalCpu.copyToArray() contentEquals normalGpu.copyToArray())
val uniformGpu = randUniform(intArrayOf(3,2),device)
val uniformGpu = randUniform(intArrayOf(3, 2), device)
val uniformCpu = uniformGpu.copyToDevice(Device.CPU)
assertTrue(uniformGpu.copyToArray() contentEquals uniformCpu.copyToArray())
}
}
internal fun NoaInt.testingViewWithNoCopy(device: Device = Device.CPU) {
val tensor = copyFromArray(intArrayOf(1, 2, 3, 4, 5, 6), shape = intArrayOf(6), device)
val viewTensor = tensor.view(intArrayOf(2, 3))
assertTrue(viewTensor.shape contentEquals intArrayOf(2, 3))
viewTensor[intArrayOf(0, 0)] = 10
assertEquals(tensor[intArrayOf(0)], 10)
}
class TestTensor {
@Test
fun testCopying(): Unit = NoaFloat {
fun testCopying() = NoaFloat {
withCuda { device ->
testingCopying(device)
}
}!!
@Test
fun testRequiresGrad() = NoaFloat {
val tensor = randNormal(intArrayOf(3))
assertTrue(!tensor.requiresGrad)
tensor.requiresGrad = true
assertTrue(tensor.requiresGrad)
tensor.requiresGrad = false
assertTrue(!tensor.requiresGrad)
tensor.requiresGrad = true
val detachedTensor = tensor.detachFromGraph()
assertTrue(!detachedTensor.requiresGrad)
}!!
@Test
fun testTypeMoving() = NoaFloat {
val tensorInt = copyFromArray(floatArrayOf(1f, 2f, 3f), intArrayOf(3)).asInt()
NoaInt {
val temporalTensor = copyFromArray(intArrayOf(4, 5, 6), intArrayOf(3))
tensorInt swap temporalTensor
assertTrue(temporalTensor.copyToArray() contentEquals intArrayOf(1, 2, 3))
}
assertTrue(tensorInt.asFloat().copyToArray() contentEquals floatArrayOf(4f, 5f, 6f))
}!!
@Test
fun testViewWithNoCopy() = NoaInt {
withCuda { device ->
testingViewWithNoCopy(device)
}
}!!
}

View File

@ -53,7 +53,7 @@ class TestUtils {
}
@Test
fun testSetSeed(): Unit = NoaFloat {
fun testSetSeed() = NoaFloat {
withCuda { device ->
testingSetSeed(device)
}