Add withChecks to most tests

This commit is contained in:
Roland Grinis 2021-01-20 14:01:45 +00:00
parent 391eb28cad
commit c9dfb6a08c
4 changed files with 41 additions and 19 deletions

View File

@ -188,3 +188,7 @@ tasks {
systemProperty("java.library.path", cppBuildDir.toString()) systemProperty("java.library.path", cppBuildDir.toString())
} }
} }
// No JS implementation
project.gradle.startParameter.excludedTaskNames.add("jsTest")
project.gradle.startParameter.excludedTaskNames.add("jsBrowserTest")

View File

@ -6,15 +6,19 @@ import kotlin.test.*
internal class TestAutograd { internal class TestAutograd {
@Test @Test
fun testAutoGrad() = TorchTensorFloatAlgebra { fun testAutoGrad() = TorchTensorFloatAlgebra {
withCuda { device -> withChecks {
testingAutoGrad(device) withCuda { device ->
testingAutoGrad(device)
}
} }
} }
@Test @Test
fun testBatchedAutoGrad() = TorchTensorFloatAlgebra { fun testBatchedAutoGrad() = TorchTensorFloatAlgebra {
withCuda { device -> withChecks {
testingBatchedAutoGrad(device) withCuda { device ->
testingBatchedAutoGrad(device)
}
} }
} }
} }

View File

@ -47,8 +47,10 @@ class TestTorchTensor {
@Test @Test
fun testViewWithNoCopy() = TorchTensorIntAlgebra { fun testViewWithNoCopy() = TorchTensorIntAlgebra {
withCuda { withChecks {
device -> testingViewWithNoCopy(device) withCuda {
device -> testingViewWithNoCopy(device)
}
} }
} }

View File

@ -7,43 +7,55 @@ internal class TestTorchTensorAlgebra {
@Test @Test
fun testScalarProduct() = TorchTensorRealAlgebra { fun testScalarProduct() = TorchTensorRealAlgebra {
withCuda { device -> withChecks {
testingScalarProduct(device) withCuda { device ->
testingScalarProduct(device)
}
} }
} }
@Test @Test
fun testMatrixMultiplication() = TorchTensorRealAlgebra { fun testMatrixMultiplication() = TorchTensorRealAlgebra {
withCuda { device -> withChecks {
testingMatrixMultiplication(device) withCuda { device ->
testingMatrixMultiplication(device)
}
} }
} }
@Test @Test
fun testLinearStructure() = TorchTensorRealAlgebra { fun testLinearStructure() = TorchTensorRealAlgebra {
withCuda { device -> withChecks {
testingLinearStructure(device) withCuda { device ->
testingLinearStructure(device)
}
} }
} }
@Test @Test
fun testTensorTransformations() = TorchTensorRealAlgebra { fun testTensorTransformations() = TorchTensorRealAlgebra {
withCuda { device -> withChecks {
testingTensorTransformations(device) withCuda { device ->
testingTensorTransformations(device)
}
} }
} }
@Test @Test
fun testBatchedSVD() = TorchTensorRealAlgebra { fun testBatchedSVD() = TorchTensorRealAlgebra {
withCuda { device -> withChecks {
testingBatchedSVD(device) withCuda { device ->
testingBatchedSVD(device)
}
} }
} }
@Test @Test
fun testBatchedSymEig() = TorchTensorRealAlgebra { fun testBatchedSymEig() = TorchTensorRealAlgebra {
withCuda { device -> withChecks {
testingBatchedSymEig(device) withCuda { device ->
testingBatchedSymEig(device)
}
} }
} }