From 4e71bf2b22730e9c1a8967336d3a51d5a700f814 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Wed, 25 Nov 2020 23:24:27 +0000
Subject: [PATCH 1/4] fixes 926

Signed-off-by: Wenqi Li
---
 monai/handlers/confusion_matrix.py          |  16 +--
 tests/test_handler_confusion_matrix_dist.py |  35 +++---
 tests/test_handler_rocauc_dist.py           |  44 ++++----
 tests/utils.py                              | 114 ++++++++++++++++++++
 4 files changed, 159 insertions(+), 50 deletions(-)

diff --git a/monai/handlers/confusion_matrix.py b/monai/handlers/confusion_matrix.py
index fbba7fcce1..7bb68a25fd 100644
--- a/monai/handlers/confusion_matrix.py
+++ b/monai/handlers/confusion_matrix.py
@@ -67,19 +67,19 @@ def __init__(
         self._num_examples = 0
         self.compute_sample = compute_sample
         self.metric_name = metric_name
-        self._total_tp = 0
-        self._total_fp = 0
-        self._total_tn = 0
-        self._total_fn = 0
+        self._total_tp = 0.0
+        self._total_fp = 0.0
+        self._total_tn = 0.0
+        self._total_fn = 0.0
 
     @reinit__is_reduced
     def reset(self) -> None:
         self._sum = 0.0
         self._num_examples = 0
-        self._total_tp = 0
-        self._total_fp = 0
-        self._total_tn = 0
-        self._total_fn = 0
+        self._total_tp = 0.0
+        self._total_fp = 0.0
+        self._total_tn = 0.0
+        self._total_fn = 0.0
 
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py
index c898148305..70364ca07c 100644
--- a/tests/test_handler_confusion_matrix_dist.py
+++ b/tests/test_handler_confusion_matrix_dist.py
@@ -10,18 +10,27 @@
 # limitations under the License.
 
+import unittest
+
 import numpy as np
 import torch
 import torch.distributed as dist
 
 from monai.handlers import ConfusionMatrix
+from tests.utils import DistCall
+
 
+class DistributedConfusionMatrix(unittest.TestCase):
+    @DistCall(nnodes=1, nproc_per_node=2)
+    def test_compute_sample(self):
+        self._compute(True)
 
-def main():
-    for compute_sample in [True, False]:
-        dist.init_process_group(backend="nccl", init_method="env://")
+    @DistCall(nnodes=1, nproc_per_node=2)
+    def test_compute(self):
+        self._compute(False)
 
-        torch.cuda.set_device(dist.get_rank())
+    def _compute(self, compute_sample=True):
+        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
         metric = ConfusionMatrix(include_background=True, metric_name="tpr", compute_sample=compute_sample)
 
         if dist.get_rank() == 0:
@@ -30,25 +39,25 @@
                     [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],
                     [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],
                 ],
-                device=torch.device("cuda:0"),
+                device=device,
             )
             y = torch.tensor(
                 [
                     [[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],
                     [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],
                 ],
-                device=torch.device("cuda:0"),
+                device=device,
            )
             metric.update([y_pred, y])
 
         if dist.get_rank() == 1:
             y_pred = torch.tensor(
                 [[[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [0.0, 0.0]]]],
-                device=torch.device("cuda:1"),
+                device=device,
             )
             y = torch.tensor(
                 [[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]],
-                device=torch.device("cuda:1"),
+                device=device,
             )
             metric.update([y_pred, y])
 
@@ -59,14 +68,6 @@
         else:
             np.testing.assert_allclose(avg_metric, 0.8333, rtol=1e-04, atol=1e-04)
 
-        dist.destroy_process_group()
-
-
-# suppose to execute on 2 rank processes
-# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
-#        --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
-#        --master_addr="192.168.1.1" --master_port=1234
-#        test_handler_confusion_matrix_dist.py
 if __name__ == "__main__":
-    main()
+    unittest.main()
 
diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py
index 13d141dc73..a91f6a950b 100644
--- a/tests/test_handler_rocauc_dist.py
+++ b/tests/test_handler_rocauc_dist.py
@@ -10,40 +10,34 @@
 # limitations under the License.
 
+import unittest
+
 import numpy as np
 import torch
 import torch.distributed as dist
 
 from monai.handlers import ROCAUC
+from tests.utils import DistCall
 
 
-def main():
-    dist.init_process_group(backend="nccl", init_method="env://")
-
-    torch.cuda.set_device(dist.get_rank())
-    auc_metric = ROCAUC(to_onehot_y=True, softmax=True)
-
-    if dist.get_rank() == 0:
-        y_pred = torch.tensor([[0.1, 0.9], [0.3, 1.4]], device=torch.device("cuda:0"))
-        y = torch.tensor([[0], [1]], device=torch.device("cuda:0"))
-        auc_metric.update([y_pred, y])
-
-    if dist.get_rank() == 1:
-        y_pred = torch.tensor([[0.2, 0.1], [0.1, 0.5]], device=torch.device("cuda:1"))
-        y = torch.tensor([[0], [1]], device=torch.device("cuda:1"))
-        auc_metric.update([y_pred, y])
-
-    result = auc_metric.compute()
-    np.testing.assert_allclose(0.75, result)
+class DistributedROCAUC(unittest.TestCase):
+    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
+    def test_compute(self):
+        auc_metric = ROCAUC(to_onehot_y=True, softmax=True)
+        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
+        if dist.get_rank() == 0:
+            y_pred = torch.tensor([[0.1, 0.9], [0.3, 1.4]], device=device)
+            y = torch.tensor([[0], [1]], device=device)
+            auc_metric.update([y_pred, y])
 
-    dist.destroy_process_group()
+        if dist.get_rank() == 1:
+            y_pred = torch.tensor([[0.2, 0.1], [0.1, 0.5]], device=device)
+            y = torch.tensor([[0], [1]], device=device)
+            auc_metric.update([y_pred, y])
 
+        result = auc_metric.compute()
+        np.testing.assert_allclose(0.75, result)
 
-# suppose to execute on 2 rank processes
-# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
-#        --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
-#        --master_addr="192.168.1.1" --master_port=1234
-#        test_handler_rocauc_dist.py
 if __name__ == "__main__":
-    main()
+    unittest.main()
diff --git a/tests/utils.py b/tests/utils.py
index cbfa5194b0..68a02a7fac 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -9,16 +9,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import datetime
 import os
 import sys
 import tempfile
 import unittest
 from io import BytesIO
 from subprocess import PIPE, Popen
+from typing import Optional
 from urllib.error import ContentTooShortError, HTTPError, URLError
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from monai.data import create_test_image_2d, create_test_image_3d
 from monai.utils import optional_import, set_determinism
@@ -87,6 +90,117 @@ def make_nifti_image(array, affine=None):
     return image_name
 
 
+class DistCall:
+    """
+    Wrap a test case so that it will run in multiple processes on a single machine using `torch.distributed`.
+
+    Usage:
+
+        decorate a unittest testcase method with a `DistCall` instance::
+
+            class MyTests(unittest.TestCase):
+                @DistCall(nnodes=1, nproc_per_node=3, master_addr="localhost")
+                def test_compute(self):
+                    ...
+
+        the `test_compute` method should trigger different worker logic according to `dist.get_rank()`.
+
+    Multi-node tests require a fixed master_addr:master_port, with node_rank set manually in multiple scripts
+    or from environment variable "NODE_RANK".
+    """
+
+    def __init__(
+        self,
+        nnodes: int = 1,
+        nproc_per_node: int = 1,
+        master_addr: str = "localhost",
+        master_port: Optional[int] = None,
+        node_rank: Optional[int] = None,
+        timeout=1000,
+        init_method=None,
+        backend: Optional[str] = None,
+        verbose: bool = False,
+    ):
+        """
+
+        Args:
+            nnodes: The number of nodes to use for distributed call.
+            nproc_per_node: The number of processes to call on each node.
+            master_addr: Master node (rank 0)'s address, should be either the IP address or the hostname of node 0.
+            master_port: Master node (rank 0)'s free port.
+            node_rank: The rank of the node, this could be set via environment variable "NODE_RANK".
+            timeout: Timeout for operations executed against the process group.
+            init_method: URL specifying how to initialize the process group. Default is "env://" if unspecified.
+            backend: The backend to use. Depending on build-time configurations,
+                valid values include ``mpi``, ``gloo``, and ``nccl``.
+            verbose: whether to print NCCL debug info.
+        """
+        self.nnodes = int(nnodes)
+        self.nproc_per_node = int(nproc_per_node)
+        self.node_rank = int(os.environ.get("NODE_RANK", "0")) if node_rank is None else node_rank
+        self.master_addr = master_addr
+        self.master_port = np.random.randint(10000, 20000) if master_port is None else master_port
+
+        if backend is None:
+            self.backend = "nccl" if torch.distributed.is_nccl_available() else "gloo"
+        else:
+            self.backend = backend
+        self.init_method = init_method
+        self.timeout = datetime.timedelta(0, timeout)
+        self.verbose = verbose
+
+    def run_process(self, fn, local_rank, instance, args, kwargs, results):
+        _env = os.environ.copy()  # keep the original system env
+        try:
+            os.environ["MASTER_ADDR"] = self.master_addr
+            os.environ["MASTER_PORT"] = str(self.master_port)
+            os.environ["LOCAL_RANK"] = str(local_rank)
+            if self.verbose:
+                os.environ["NCCL_DEBUG"] = "INFO"
+                os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"
+            os.environ["NCCL_BLOCKING_WAIT"] = str(1)
+            os.environ["OMP_NUM_THREADS"] = str(1)
+            os.environ["WORLD_SIZE"] = str(self.nproc_per_node * self.nnodes)
+            os.environ["RANK"] = str(self.nproc_per_node * self.node_rank + local_rank)
+
+            if torch.cuda.is_available():
+                torch.cuda.set_device(int(local_rank))
+
+            dist.init_process_group(
+                backend=self.backend,
+                init_method=self.init_method,
+                timeout=self.timeout,
+            )
+            fn(instance, *args, **kwargs)
+            results.put(True)
+        except Exception as e:
+            results.put(False)
+            raise e
+        finally:
+            os.environ.clear()
+            os.environ.update(_env)
+            dist.destroy_process_group()
+
+    def __call__(self, obj):
+        if not torch.distributed.is_available():
+            return unittest.skipIf(True, "Skipping distributed tests because not torch.distributed.is_available()")(obj)
+
+        def _wrapper(cls_inst, *args, **kwargs):
+            processes = []
+            results = torch.multiprocessing.Queue()
+            for proc_rank in range(self.nproc_per_node):
+                p = torch.multiprocessing.Process(
+                    target=self.run_process, args=(obj, proc_rank, cls_inst, args, kwargs, results)
+                )
+                p.start()
+                processes.append(p)
+            for p in processes:
+                p.join()
+            assert results.get(), "Distributed call failed."
+
+        return _wrapper
+
+
 class NumpyImageTestCase2D(unittest.TestCase):
     im_shape = (128, 64)
     input_channels = 1

From f1dc5187cf655cc13832cd3797775277c6a8f89f Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Thu, 26 Nov 2020 00:18:50 +0000
Subject: [PATCH 2/4] fixes tests - compatibility

Signed-off-by: Wenqi Li
---
 tests/test_handler_confusion_matrix_dist.py |  5 +-
 tests/test_handler_rocauc_dist.py           |  5 +-
 tests/utils.py                              | 66 +++++++++++++++++++--
 3 files changed, 66 insertions(+), 10 deletions(-)

diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py
index 70364ca07c..22f7f10741 100644
--- a/tests/test_handler_confusion_matrix_dist.py
+++ b/tests/test_handler_confusion_matrix_dist.py
@@ -17,11 +17,12 @@
 import torch.distributed as dist
 
 from monai.handlers import ConfusionMatrix
-from tests.utils import DistCall
+from tests.utils import DistCall, DistTestCase, skip_if_windows
 
 
-class DistributedConfusionMatrix(unittest.TestCase):
+class DistributedConfusionMatrix(DistTestCase):
     @DistCall(nnodes=1, nproc_per_node=2)
+    @skip_if_windows
     def test_compute_sample(self):
         self._compute(True)
 
diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py
index a91f6a950b..ca8ffdaf56 100644
--- a/tests/test_handler_rocauc_dist.py
+++ b/tests/test_handler_rocauc_dist.py
@@ -17,11 +17,12 @@
 import torch.distributed as dist
 
 from monai.handlers import ROCAUC
-from tests.utils import DistCall
+from tests.utils import DistCall, DistTestCase, skip_if_windows
 
 
-class DistributedROCAUC(unittest.TestCase):
+class DistributedROCAUC(DistTestCase):
     @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
+    @skip_if_windows
     def test_compute(self):
         auc_metric = ROCAUC(to_onehot_y=True, softmax=True)
         device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
diff --git a/tests/utils.py b/tests/utils.py
index 68a02a7fac..c5ba76cdab 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -10,6 +10,8 @@
 # limitations under the License.
 
 import datetime
+import functools
+import importlib
 import os
 import sys
 import tempfile
@@ -90,6 +92,30 @@ def make_nifti_image(array, affine=None):
     return image_name
 
 
+class DistTestCase(unittest.TestCase):
+    """testcase without _outcome, so that it's picklable."""
+
+    original_mp = None
+
+    def setUp(self) -> None:
+        self.original_mp = torch.multiprocessing.get_start_method(allow_none=True)
+        try:
+            torch.multiprocessing.set_start_method("spawn", force=True)
+        except RuntimeError:
+            pass
+
+    def tearDown(self) -> None:
+        try:
+            torch.multiprocessing.set_start_method(str(self.original_mp), force=True)
+        except RuntimeError:
+            pass
+
+    def __getstate__(self):
+        self_dict = self.__dict__.copy()
+        del self_dict["_outcome"]
+        return self_dict
+
+
 class DistCall:
     """
     Wrap a test case so that it will run in multiple processes on a single machine using `torch.distributed`.
@@ -116,7 +142,7 @@ def __init__(
         master_addr: str = "localhost",
         master_port: Optional[int] = None,
         node_rank: Optional[int] = None,
-        timeout=1000,
+        timeout=60,
         init_method=None,
         backend: Optional[str] = None,
         verbose: bool = False,
@@ -142,14 +168,16 @@ def __init__(
         self.master_port = np.random.randint(10000, 20000) if master_port is None else master_port
 
         if backend is None:
-            self.backend = "nccl" if torch.distributed.is_nccl_available() else "gloo"
+            self.backend = "nccl" if torch.distributed.is_nccl_available() and torch.cuda.is_available() else "gloo"
         else:
             self.backend = backend
         self.init_method = init_method
+        if self.init_method is None and sys.platform == "win32":
+            self.init_method = "file:///d:/a_temp"
         self.timeout = datetime.timedelta(0, timeout)
         self.verbose = verbose
 
-    def run_process(self, fn, local_rank, instance, args, kwargs, results):
+    def run_process(self, func, local_rank, args, kwargs, results):
         _env = os.environ.copy()  # keep the original system env
         try:
             os.environ["MASTER_ADDR"] = self.master_addr
@@ -170,8 +198,10 @@ def run_process(self, fn, local_rank, instance, args, kwargs, results):
                 backend=self.backend,
                 init_method=self.init_method,
                 timeout=self.timeout,
+                world_size=int(os.environ["WORLD_SIZE"]),
+                rank=int(os.environ["RANK"]),
             )
-            fn(instance, *args, **kwargs)
+            func(*args, **kwargs)
             results.put(True)
         except Exception as e:
             results.put(False)
@@ -185,12 +215,17 @@ def __call__(self, obj):
         if not torch.distributed.is_available():
             return unittest.skipIf(True, "Skipping distributed tests because not torch.distributed.is_available()")(obj)
 
-        def _wrapper(cls_inst, *args, **kwargs):
+        _cache_original_func(obj)
+
+        @functools.wraps(obj)
+        def _wrapper(*args, **kwargs):
             processes = []
             results = torch.multiprocessing.Queue()
+            func = _call_original_func
+            args = [obj.__name__, obj.__module__] + list(args)
             for proc_rank in range(self.nproc_per_node):
                 p = torch.multiprocessing.Process(
-                    target=self.run_process, args=(obj, proc_rank, cls_inst, args, kwargs, results)
+                    target=self.run_process, args=(func, proc_rank, args, kwargs, results)
                 )
                 p.start()
                 processes.append(p)
@@ -201,6 +236,25 @@ def _wrapper(cls_inst, *args, **kwargs):
         return _wrapper
 
 
+_original_funcs = {}
+
+
+def _cache_original_func(obj) -> None:
+    """cache the original function by name, so that the decorator doesn't shadow it."""
+    global _original_funcs
+    _original_funcs[obj.__name__] = obj
+
+
+def _call_original_func(name, module, *args, **kwargs):
+    if name not in _original_funcs:
+        _original_module = importlib.import_module(module)  # reimport, refresh _original_funcs
+        if not hasattr(_original_module, name):
+            # refresh module doesn't work
+            raise RuntimeError(f"Could not recover the original {name} from {module}: {_original_funcs}.")
+    f = _original_funcs[name]
+    return f(*args, **kwargs)
+
+
 class NumpyImageTestCase2D(unittest.TestCase):
     im_shape = (128, 64)
     input_channels = 1

From 3626c1d3c2efd1a861ffe3b379b69ad73b125daa Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Thu, 26 Nov 2020 11:33:13 +0000
Subject: [PATCH 3/4] update windows tests

Signed-off-by: Wenqi Li
---
 tests/test_handler_confusion_matrix_dist.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py
index 22f7f10741..f991e80aaa 100644
--- a/tests/test_handler_confusion_matrix_dist.py
+++ b/tests/test_handler_confusion_matrix_dist.py
@@ -27,6 +27,7 @@ def test_compute_sample(self):
         self._compute(True)
 
     @DistCall(nnodes=1, nproc_per_node=2)
+    @skip_if_windows
     def test_compute(self):
         self._compute(False)
 

From da4092b026a34013a86b92e18586f9b6ded46b3e Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Thu, 26 Nov 2020 17:15:43 +0000
Subject: [PATCH 4/4] workaround pagefile config

Signed-off-by: Wenqi Li
---
 .github/workflows/pythonapp.yml             | 9 ++++++++-
 tests/test_handler_confusion_matrix_dist.py | 4 +---
 tests/test_handler_rocauc_dist.py           | 3 +--
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
index 2942774cc6..35ee5433cd 100644
--- a/.github/workflows/pythonapp.yml
+++ b/.github/workflows/pythonapp.yml
@@ -49,6 +49,13 @@ jobs:
         os: [windows-latest, macOS-latest, ubuntu-latest]
     timeout-minutes: 60
     steps:
+    - if: runner.os == 'windows'
+      name: Config pagefile (Windows only)
+      uses: al-cheb/configure-pagefile-action@v1.2
+      with:
+        minimum-size: 8
+        maximum-size: 16
+        disk-root: "D:"
    - uses: actions/checkout@v2
     - name: Set up Python 3.8
       uses: actions/setup-python@v2
@@ -73,7 +80,7 @@ jobs:
     - if: runner.os == 'windows'
       name: Install torch cpu from pytorch.org (Windows only)
       run: |
-        python -m pip install torch==1.7.0 torchvision==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
+        python -m pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
         # min. requirements for windows instances
         python -c "f=open('requirements-dev.txt', 'r'); txt=f.readlines(); f.close(); print(txt); f=open('requirements-dev.txt', 'w'); f.writelines(txt[1:12]); f.close()"
     - name: Install the dependencies
diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py
index f991e80aaa..b7718e15d2 100644
--- a/tests/test_handler_confusion_matrix_dist.py
+++ b/tests/test_handler_confusion_matrix_dist.py
@@ -17,17 +17,15 @@
 import torch.distributed as dist
 
 from monai.handlers import ConfusionMatrix
-from tests.utils import DistCall, DistTestCase, skip_if_windows
+from tests.utils import DistCall, DistTestCase
 
 
 class DistributedConfusionMatrix(DistTestCase):
     @DistCall(nnodes=1, nproc_per_node=2)
-    @skip_if_windows
     def test_compute_sample(self):
         self._compute(True)
 
     @DistCall(nnodes=1, nproc_per_node=2)
-    @skip_if_windows
     def test_compute(self):
         self._compute(False)
 
diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py
index ca8ffdaf56..7ff45185a6 100644
--- a/tests/test_handler_rocauc_dist.py
+++ b/tests/test_handler_rocauc_dist.py
@@ -17,12 +17,11 @@
 import torch.distributed as dist
 
 from monai.handlers import ROCAUC
-from tests.utils import DistCall, DistTestCase, skip_if_windows
+from tests.utils import DistCall, DistTestCase
 
 
 class DistributedROCAUC(DistTestCase):
     @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
-    @skip_if_windows
     def test_compute(self):
         auc_metric = ROCAUC(to_onehot_y=True, softmax=True)
         device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"