From 841d7ad89540ae2cb3df154f6957cfd01ae0c439 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Tue, 15 Dec 2020 18:27:53 +0000
Subject: [PATCH] part of #926

Signed-off-by: Wenqi Li
---
 tests/test_distributed_sampler.py | 47 +++++++++++++++----------------
 1 file changed, 23 insertions(+), 24 deletions(-)

diff --git a/tests/test_distributed_sampler.py b/tests/test_distributed_sampler.py
index 8c182dd9e6..06760e708f 100644
--- a/tests/test_distributed_sampler.py
+++ b/tests/test_distributed_sampler.py
@@ -9,39 +9,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import unittest
+
 import numpy as np
-import torch
 import torch.distributed as dist
 
 from monai.data import DistributedSampler
+from tests.utils import DistCall, DistTestCase
 
 
-def test(expected, **kwargs):
-    dist.init_process_group(backend="nccl", init_method="env://")
-
-    torch.cuda.set_device(dist.get_rank())
-    data = [1, 2, 3, 4, 5]
-    sampler = DistributedSampler(dataset=data, **kwargs)
-    samples = np.array([data[i] for i in list(sampler)])
-    if dist.get_rank() == 0:
-        np.testing.assert_allclose(samples, np.array(expected[0]))
-
-    if dist.get_rank() == 1:
-        np.testing.assert_allclose(samples, np.array(expected[1]))
-
-    dist.destroy_process_group()
+class DistributedSamplerTest(DistTestCase):
+    @DistCall(nnodes=1, nproc_per_node=2)
+    def test_even(self):
+        data = [1, 2, 3, 4, 5]
+        sampler = DistributedSampler(dataset=data, shuffle=False)
+        samples = np.array([data[i] for i in list(sampler)])
+        if dist.get_rank() == 0:
+            np.testing.assert_allclose(samples, np.array([1, 3, 5]))
+        if dist.get_rank() == 1:
+            np.testing.assert_allclose(samples, np.array([2, 4, 1]))
 
 
-def main():
-    test(shuffle=False, expected=[[1, 3, 5], [2, 4, 1]])
-    test(shuffle=False, even_divisible=False, expected=[[1, 3, 5], [2, 4]])
+    @DistCall(nnodes=1, nproc_per_node=2)
+    def test_uneven(self):
+        data = [1, 2, 3, 4, 5]
+        sampler = DistributedSampler(dataset=data, shuffle=False, even_divisible=False)
+        samples = np.array([data[i] for i in list(sampler)])
+        if dist.get_rank() == 0:
+            np.testing.assert_allclose(samples, np.array([1, 3, 5]))
+        if dist.get_rank() == 1:
+            np.testing.assert_allclose(samples, np.array([2, 4]))
 
 
-# suppose to execute on 2 rank processes
-# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
-# --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
-# --master_addr="localhost" --master_port=1234
-# test_distributed_sampler.py
 if __name__ == "__main__":
-    main()
+    unittest.main()
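
As a rough illustration of the split that the two converted tests assert, below is a short, torch-free Python sketch. The partition helper is hypothetical and written only to mimic the expected per-rank output; it is not MONAI's DistributedSampler implementation. With the default even-divisible behaviour (exercised by test_even) the last rank is padded by wrapping around to the start of the dataset, so both ranks see three samples; test_uneven passes even_divisible=False and rank 1 simply receives one sample fewer.

# Hypothetical sketch only: "partition" mimics the per-rank output asserted by
# test_even and test_uneven; it is not MONAI's DistributedSampler implementation.
def partition(data, rank, num_ranks, even_divisible=True):
    if even_divisible:
        # pad by wrapping around so every rank receives the same number of samples
        total = ((len(data) + num_ranks - 1) // num_ranks) * num_ranks
        data = [data[i % len(data)] for i in range(total)]
    return list(data)[rank::num_ranks]

data = [1, 2, 3, 4, 5]
assert partition(data, 0, 2) == [1, 3, 5]                        # rank 0 in test_even
assert partition(data, 1, 2) == [2, 4, 1]                        # rank 1, padded with data[0]
assert partition(data, 0, 2, even_divisible=False) == [1, 3, 5]  # rank 0 in test_uneven
assert partition(data, 1, 2, even_divisible=False) == [2, 4]     # rank 1, no padding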