diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py
index a0717169aa..9ed13d292c 100644
--- a/monai/handlers/utils.py
+++ b/monai/handlers/utils.py
@@ -163,7 +163,6 @@ def write_metrics_reports(
         with open(os.path.join(save_dir, "metrics.csv"), "w") as f:
             for k, v in metrics.items():
                 f.write(f"{k}{deli}{str(v)}\n")
-
     if metric_details is not None and len(metric_details) > 0:
         for k, v in metric_details.items():
             if isinstance(v, torch.Tensor):
diff --git a/tests/test_distcall.py b/tests/test_distcall.py
new file mode 100644
index 0000000000..1830a85654
--- /dev/null
+++ b/tests/test_distcall.py
@@ -0,0 +1,29 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from tests.utils import DistCall, DistTestCase
+
+
+class DistributedCallTest(DistTestCase):
+    def test_constructor(self):
+        with self.assertRaises(ValueError):
+            DistCall(nnodes=1, nproc_per_node=0)
+        with self.assertRaises(ValueError):
+            DistCall(nnodes=0, nproc_per_node=0)
+        with self.assertRaises(ValueError):
+            DistCall(nnodes=0, nproc_per_node=1)
+        _ = DistCall(nnodes=1, nproc_per_node=1)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py
index dfdaa16526..0868ec5ff3 100644
--- a/tests/test_handler_metrics_saver_dist.py
+++ b/tests/test_handler_metrics_saver_dist.py
@@ -12,7 +12,6 @@
 
 import csv
 import os
-import random
 import tempfile
 import unittest
 
@@ -28,85 +27,82 @@
 class DistributedMetricsSaver(DistTestCase):
     @DistCall(nnodes=1, nproc_per_node=2)
     def test_content(self):
-        self._run()
-
-    def _run(self):
         with tempfile.TemporaryDirectory() as tempdir:
-            metrics_saver = MetricsSaver(
-                save_dir=tempdir,
-                metrics=["metric1", "metric2"],
-                metric_details=["metric3", "metric4"],
-                batch_transform=lambda x: x["image_meta_dict"],
-                summary_ops="*",
-            )
-
-            def _val_func(engine, batch):
-                pass
-
-            engine = Engine(_val_func)
-
-            # test the case that all_gather with string length > 1024 chars
-            filename_postfix = "abcdefghigklmnopqrstuvwxyz"
-            for _ in range(1100):
-                filename_postfix += filename_postfix[random.randint(0, 26)]
-
-            if dist.get_rank() == 0:
-                data = [{"image_meta_dict": {"filename_or_obj": [f"1{filename_postfix}"]}}]
-
-                @engine.on(Events.EPOCH_COMPLETED)
-                def _save_metrics0(engine):
-                    engine.state.metrics = {"metric1": 1, "metric2": 2}
-                    engine.state.metric_details = {
-                        "metric3": torch.tensor([[1, 2]]),
-                        "metric4": torch.tensor([[5, 6]]),
-                    }
-
-            if dist.get_rank() == 1:
-                # different ranks have different data length
-                data = [
-                    {"image_meta_dict": {"filename_or_obj": [f"2{filename_postfix}"]}},
-                    {"image_meta_dict": {"filename_or_obj": [f"3{filename_postfix}"]}},
-                ]
-
-                @engine.on(Events.EPOCH_COMPLETED)
-                def _save_metrics1(engine):
-                    engine.state.metrics = {"metric1": 1, "metric2": 2}
-                    engine.state.metric_details = {
-                        "metric3": torch.tensor([[2, 3], [3, 4]]),
-                        "metric4": torch.tensor([[6, 7], [7, 8]]),
-                    }
-
-            metrics_saver.attach(engine)
-            engine.run(data, max_epochs=1)
-
-            if dist.get_rank() == 0:
-                # check the metrics.csv and content
-                self.assertTrue(os.path.exists(os.path.join(tempdir, "metrics.csv")))
-                with open(os.path.join(tempdir, "metrics.csv")) as f:
-                    f_csv = csv.reader(f)
-                    for i, row in enumerate(f_csv):
-                        self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"])
-                self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_raw.csv")))
-                # check the metric_raw.csv and content
-                with open(os.path.join(tempdir, "metric3_raw.csv")) as f:
-                    f_csv = csv.reader(f)
-                    for i, row in enumerate(f_csv):
-                        if i > 0:
-                            expected = [f"{i}{filename_postfix[0: 1023]}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]
-                            self.assertEqual(row, expected)
-                self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_summary.csv")))
-                # check the metric_summary.csv and content
-                with open(os.path.join(tempdir, "metric3_summary.csv")) as f:
-                    f_csv = csv.reader(f)
-                    for i, row in enumerate(f_csv):
-                        if i == 1:
-                            self.assertEqual(row, ["class0\t1.0000\t1.0000\t1.0000\t1.0000\t1.0000\t0.0000"])
-                        elif i == 2:
-                            self.assertEqual(row, ["class1\t2.0000\t2.0000\t2.0000\t2.0000\t2.0000\t0.0000"])
-                        elif i == 3:
-                            self.assertEqual(row, ["mean\t1.5000\t1.5000\t1.5000\t1.5000\t1.5000\t0.0000"])
-                self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_raw.csv")))
-                self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_summary.csv")))
+            self._run(tempdir)
+
+    def _run(self, tempdir):
+        fnames = ["aaa" * 300, "bbb" * 301, "ccc" * 302]
+
+        metrics_saver = MetricsSaver(
+            save_dir=tempdir,
+            metrics=["metric1", "metric2"],
+            metric_details=["metric3", "metric4"],
+            batch_transform=lambda x: x["image_meta_dict"],
+            summary_ops="*",
+        )
+
+        def _val_func(engine, batch):
+            pass
+
+        engine = Engine(_val_func)
+
+        if dist.get_rank() == 0:
+            data = [{"image_meta_dict": {"filename_or_obj": [fnames[0]]}}]
+
+            @engine.on(Events.EPOCH_COMPLETED)
+            def _save_metrics0(engine):
+                engine.state.metrics = {"metric1": 1, "metric2": 2}
+                engine.state.metric_details = {
+                    "metric3": torch.tensor([[1, 2]]),
+                    "metric4": torch.tensor([[5, 6]]),
+                }
+
+        if dist.get_rank() == 1:
+            # different ranks have different data length
+            data = [
+                {"image_meta_dict": {"filename_or_obj": [fnames[1]]}},
+                {"image_meta_dict": {"filename_or_obj": [fnames[2]]}},
+            ]
+
+            @engine.on(Events.EPOCH_COMPLETED)
+            def _save_metrics1(engine):
+                engine.state.metrics = {"metric1": 1, "metric2": 2}
+                engine.state.metric_details = {
+                    "metric3": torch.tensor([[2, 3], [3, 4]]),
+                    "metric4": torch.tensor([[6, 7], [7, 8]]),
+                }
+
+        metrics_saver.attach(engine)
+        engine.run(data, max_epochs=1)
+
+        if dist.get_rank() == 0:
+            # check the metrics.csv and content
+            self.assertTrue(os.path.exists(os.path.join(tempdir, "metrics.csv")))
+            with open(os.path.join(tempdir, "metrics.csv")) as f:
+                f_csv = csv.reader(f)
+                for i, row in enumerate(f_csv):
+                    self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"])
+            self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_raw.csv")))
+            # check the metric_raw.csv and content
+            with open(os.path.join(tempdir, "metric3_raw.csv")) as f:
+                f_csv = csv.reader(f)
+                for i, row in enumerate(f_csv):
+                    if i > 0:
+                        expected = [f"{fnames[i-1]}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]
+                        self.assertEqual(row, expected)
+            self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_summary.csv")))
+            # check the metric_summary.csv and content
+            with open(os.path.join(tempdir, "metric3_summary.csv")) as f:
+                f_csv = csv.reader(f)
+                for i, row in enumerate(f_csv):
+                    if i == 1:
+                        self.assertEqual(row, ["class0\t1.0000\t1.0000\t1.0000\t1.0000\t1.0000\t0.0000"])
+                    elif i == 2:
+                        self.assertEqual(row, ["class1\t2.0000\t2.0000\t2.0000\t2.0000\t2.0000\t0.0000"])
+                    elif i == 3:
+                        self.assertEqual(row, ["mean\t1.5000\t1.5000\t1.5000\t1.5000\t1.5000\t0.0000"])
+            self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_raw.csv")))
+            self.assertTrue(os.path.exists(os.path.join(tempdir, "metric4_summary.csv")))
 
 
 if __name__ == "__main__":
diff --git a/tests/utils.py b/tests/utils.py
index 4597a18fbd..8b367158b2 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -237,7 +237,11 @@ def __init__(
         """
         self.nnodes = int(nnodes)
         self.nproc_per_node = int(nproc_per_node)
-        self.node_rank = int(os.environ.get("NODE_RANK", "0")) if node_rank is None else node_rank
+        if self.nnodes < 1 or self.nproc_per_node < 1:
+            raise ValueError(
+                f"number of nodes and processes per node must be >= 1, got {self.nnodes} and {self.nproc_per_node}"
+            )
+        self.node_rank = int(os.environ.get("NODE_RANK", "0")) if node_rank is None else int(node_rank)
         self.master_addr = master_addr
         self.master_port = np.random.randint(10000, 20000) if master_port is None else master_port
 
@@ -286,11 +290,20 @@ def run_process(self, func, local_rank, args, kwargs, results):
         finally:
             os.environ.clear()
             os.environ.update(_env)
-            dist.destroy_process_group()
+            try:
+                dist.destroy_process_group()
+            except RuntimeError as e:
+                warnings.warn(f"While closing process group: {e}.")
 
     def __call__(self, obj):
         if not torch.distributed.is_available():
             return unittest.skipIf(True, "Skipping distributed tests because not torch.distributed.is_available()")(obj)
+        if torch.cuda.is_available() and torch.cuda.device_count() < self.nproc_per_node:
+            return unittest.skipIf(
+                True,
+                f"Skipping distributed tests because it requires {self.nnodes} devices "
+                f"but got {torch.cuda.device_count()}",
+            )(obj)
 
         _cache_original_func(obj)