diff --git a/tests/test_parallel_execution.py b/tests/test_parallel_execution.py index f8e0e07d4a..0ef0dccd2a 100644 --- a/tests/test_parallel_execution.py +++ b/tests/test_parallel_execution.py @@ -25,18 +25,26 @@ def fake_data_stream(): while True: yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64)) +def expect_failure_if_no_gpu(test): + if not torch.cuda.is_available(): + return unittest.expectedFailure(test) + else: + return test + class TestParallelExecution(unittest.TestCase): """ Tests single GPU, multi GPU, and CPU execution with the Ignite supervised trainer. """ + @expect_failure_if_no_gpu def test_single_gpu(self): net = torch.nn.Conv2d(1, 1, 3, padding=1) opt = torch.optim.Adam(net.parameters(), 1e-3) - trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [0]) + trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [torch.device("cuda:0")]) trainer.run(fake_data_stream(), 2, 2) + @expect_failure_if_no_gpu def test_multi_gpu(self): net = torch.nn.Conv2d(1, 1, 3, padding=1) opt = torch.optim.Adam(net.parameters(), 1e-3)