diff --git a/fairscale/utils/testing.py b/fairscale/utils/testing.py index d902383ba..85e06c6a3 100644 --- a/fairscale/utils/testing.py +++ b/fairscale/utils/testing.py @@ -31,6 +31,7 @@ import multiprocessing import os import random +import sys import tempfile from typing import Any, Callable, Dict, List, Optional, Tuple @@ -53,6 +54,10 @@ not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="multiple GPUs required" ) +skip_if_py38 = pytest.mark.skipif( + sys.version_info.major == 3 and sys.version_info.minor == 8, reason="Python3.8 is skipped" +) + _, filename_mpi = tempfile.mkstemp() diff --git a/tests/nn/data_parallel/test_sharded_ddp.py b/tests/nn/data_parallel/test_sharded_ddp.py index 2277067a7..b70baf3d1 100644 --- a/tests/nn/data_parallel/test_sharded_ddp.py +++ b/tests/nn/data_parallel/test_sharded_ddp.py @@ -21,7 +21,7 @@ from fairscale.nn.data_parallel import ShardedDataParallel from fairscale.optim import OSS -from fairscale.utils.testing import GPT2, skip_if_no_cuda, skip_if_single_gpu +from fairscale.utils.testing import GPT2, skip_if_no_cuda, skip_if_py38, skip_if_single_gpu def run_one_step(rank, world_size, backend, device, temp_file_name): @@ -112,16 +112,17 @@ def run_test(backend, device, world_size=2): mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True) -def test_step_on_cpu(): - run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"), world_size=4) - - @skip_if_no_cuda @skip_if_single_gpu -def test_step_on_gpu(): +def test_step_gpu(): run_test(backend=dist.Backend.NCCL, device=torch.device("cuda")) +@skip_if_py38 +def test_step_cpu(): + run_test(backend=dist.Backend.GLOO, device=torch.device("cpu")) + + def run_ddp_parity(rank, world_size, backend, temp_file_name): url = "file://" + temp_file_name dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)