Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion tests/test_parallel_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,26 @@ def fake_data_stream():
while True:
yield torch.rand((10, 1, 64, 64)), torch.rand((10, 1, 64, 64))

def expect_failure_if_no_gpu(test):
if not torch.cuda.is_available():
return unittest.expectedFailure(test)
else:
return test


class TestParallelExecution(unittest.TestCase):
"""
Tests single GPU, multi GPU, and CPU execution with the Ignite supervised trainer.
"""

@expect_failure_if_no_gpu
def test_single_gpu(self):
net = torch.nn.Conv2d(1, 1, 3, padding=1)
opt = torch.optim.Adam(net.parameters(), 1e-3)
trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [0])
trainer = create_multigpu_supervised_trainer(net, opt, fake_loss, [torch.device("cuda:0")])
trainer.run(fake_data_stream(), 2, 2)

@expect_failure_if_no_gpu
def test_multi_gpu(self):
net = torch.nn.Conv2d(1, 1, 3, padding=1)
opt = torch.optim.Adam(net.parameters(), 1e-3)
Expand Down