Add withChecks to most tests

This commit is contained in:
Roland Grinis 2021-01-20 14:01:45 +00:00
parent 391eb28cad
commit c9dfb6a08c
4 changed files with 41 additions and 19 deletions

View File

@ -188,3 +188,7 @@ tasks {
systemProperty("java.library.path", cppBuildDir.toString()) systemProperty("java.library.path", cppBuildDir.toString())
} }
} }
// No JS implementation
project.gradle.startParameter.excludedTaskNames.add("jsTest")
project.gradle.startParameter.excludedTaskNames.add("jsBrowserTest")

View File

@ -6,15 +6,19 @@ import kotlin.test.*
internal class TestAutograd { internal class TestAutograd {
@Test @Test
fun testAutoGrad() = TorchTensorFloatAlgebra { fun testAutoGrad() = TorchTensorFloatAlgebra {
withChecks {
withCuda { device -> withCuda { device ->
testingAutoGrad(device) testingAutoGrad(device)
} }
} }
}
@Test @Test
fun testBatchedAutoGrad() = TorchTensorFloatAlgebra { fun testBatchedAutoGrad() = TorchTensorFloatAlgebra {
withChecks {
withCuda { device -> withCuda { device ->
testingBatchedAutoGrad(device) testingBatchedAutoGrad(device)
} }
} }
} }
}

View File

@ -47,9 +47,11 @@ class TestTorchTensor {
@Test @Test
fun testViewWithNoCopy() = TorchTensorIntAlgebra { fun testViewWithNoCopy() = TorchTensorIntAlgebra {
withChecks {
withCuda { withCuda {
device -> testingViewWithNoCopy(device) device -> testingViewWithNoCopy(device)
} }
} }
}
} }

View File

@ -7,44 +7,56 @@ internal class TestTorchTensorAlgebra {
@Test @Test
fun testScalarProduct() = TorchTensorRealAlgebra { fun testScalarProduct() = TorchTensorRealAlgebra {
withChecks {
withCuda { device -> withCuda { device ->
testingScalarProduct(device) testingScalarProduct(device)
} }
} }
}
@Test @Test
fun testMatrixMultiplication() = TorchTensorRealAlgebra { fun testMatrixMultiplication() = TorchTensorRealAlgebra {
withChecks {
withCuda { device -> withCuda { device ->
testingMatrixMultiplication(device) testingMatrixMultiplication(device)
} }
} }
}
@Test @Test
fun testLinearStructure() = TorchTensorRealAlgebra { fun testLinearStructure() = TorchTensorRealAlgebra {
withChecks {
withCuda { device -> withCuda { device ->
testingLinearStructure(device) testingLinearStructure(device)
} }
} }
}
@Test @Test
fun testTensorTransformations() = TorchTensorRealAlgebra { fun testTensorTransformations() = TorchTensorRealAlgebra {
withChecks {
withCuda { device -> withCuda { device ->
testingTensorTransformations(device) testingTensorTransformations(device)
} }
} }
}
@Test @Test
fun testBatchedSVD() = TorchTensorRealAlgebra { fun testBatchedSVD() = TorchTensorRealAlgebra {
withChecks {
withCuda { device -> withCuda { device ->
testingBatchedSVD(device) testingBatchedSVD(device)
} }
} }
}
@Test @Test
fun testBatchedSymEig() = TorchTensorRealAlgebra { fun testBatchedSymEig() = TorchTensorRealAlgebra {
withChecks {
withCuda { device -> withCuda { device ->
testingBatchedSymEig(device) testingBatchedSymEig(device)
} }
} }
}
} }