diff --git a/monai/data/csv_saver.py b/monai/data/csv_saver.py
index ec9ec562cd..830c6a4f0d 100644
--- a/monai/data/csv_saver.py
+++ b/monai/data/csv_saver.py
@@ -17,6 +17,8 @@
 import numpy as np
 import torch
 
+from monai.utils import ImageMetaKey as Key
+
 
 class CSVSaver:
     """
@@ -73,7 +75,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
             meta_data: the meta data information corresponding to the data.
 
         """
-        save_key = meta_data["filename_or_obj"] if meta_data else str(self._data_index)
+        save_key = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
         self._data_index += 1
         if isinstance(data, torch.Tensor):
             data = data.detach().cpu().numpy()
diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py
index db559f97f4..e699a0ce9b 100644
--- a/monai/data/nifti_saver.py
+++ b/monai/data/nifti_saver.py
@@ -18,6 +18,7 @@
 from monai.data.nifti_writer import write_nifti
 from monai.data.utils import create_file_basename
 from monai.utils import GridSampleMode, GridSamplePadMode
+from monai.utils import ImageMetaKey as Key
 
 
 class NiftiSaver:
@@ -95,7 +96,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
         See Also
             :py:meth:`monai.data.nifti_writer.write_nifti`
         """
-        filename = meta_data["filename_or_obj"] if meta_data else str(self._data_index)
+        filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
         self._data_index += 1
         original_affine = meta_data.get("original_affine", None) if meta_data else None
         affine = meta_data.get("affine", None) if meta_data else None
diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py
index 8ed8b234f4..4c4c847824 100644
--- a/monai/data/png_saver.py
+++ b/monai/data/png_saver.py
@@ -16,6 +16,7 @@
 from monai.data.png_writer import write_png
 from monai.data.utils import create_file_basename
+from monai.utils import ImageMetaKey as Key
 from monai.utils import InterpolateMode
 
 
@@ -82,7 +83,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
             :py:meth:`monai.data.png_writer.write_png`
 
         """
-        filename = meta_data["filename_or_obj"] if meta_data else str(self._data_index)
+        filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
         self._data_index += 1
 
         spatial_shape = meta_data.get("spatial_shape", None) if meta_data and self.resample else None
diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index 6b190518fb..81c65ed580 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -25,5 +25,11 @@
 from .stats_handler import StatsHandler
 from .surface_distance import SurfaceDistance
 from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler
-from .utils import evenly_divisible_all_gather, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
+from .utils import (
+    evenly_divisible_all_gather,
+    stopping_fn_from_loss,
+    stopping_fn_from_metric,
+    string_list_all_gather,
+    write_metrics_reports,
+)
 from .validation_handler import ValidationHandler
diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py
index 6753cafcb0..a1c76dd338 100644
--- a/monai/handlers/classification_saver.py
+++ b/monai/handlers/classification_saver.py
@@ -13,8 +13,11 @@
 from typing import TYPE_CHECKING, Callable, Optional
 
 from monai.data import CSVSaver
+from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather
+from monai.utils import ImageMetaKey as Key
 from monai.utils import exact_version, optional_import
 
+idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
 Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
 if TYPE_CHECKING:
     from ignite.engine import Engine
@@ -25,6 +28,8 @@
 class ClassificationSaver:
     """
     Event handler triggered on completing every iteration to save the classification predictions as CSV file.
+    If running in distributed data parallel, only saves the CSV file on the specified rank.
+
     """
 
     def __init__(
@@ -35,6 +40,7 @@ def __init__(
         batch_transform: Callable = lambda x: x,
        output_transform: Callable = lambda x: x,
         name: Optional[str] = None,
+        save_rank: int = 0,
     ) -> None:
         """
         Args:
@@ -49,8 +55,11 @@ def __init__(
                 The first dimension of this transform's output will be treated as the batch dimension.
                 Each item in the batch will be saved individually.
             name: identifier of logging.logger to use, defaulting to `engine.logger`.
+            save_rank: only the handler on the specified rank will save to the CSV file in multi-GPU validation,
+                defaults to 0.
 
         """
+        self._expected_rank: bool = idist.get_rank() == save_rank
         self.saver = CSVSaver(output_dir, filename, overwrite)
         self.batch_transform = batch_transform
         self.output_transform = output_transform
@@ -67,7 +76,7 @@ def attach(self, engine: Engine) -> None:
             self.logger = engine.logger
         if not engine.has_event_handler(self, Events.ITERATION_COMPLETED):
             engine.add_event_handler(Events.ITERATION_COMPLETED, self)
-        if not engine.has_event_handler(self.saver.finalize, Events.COMPLETED):
+        if self._expected_rank and not engine.has_event_handler(self.saver.finalize, Events.COMPLETED):
             engine.add_event_handler(Events.COMPLETED, lambda engine: self.saver.finalize())
 
     def __call__(self, engine: Engine) -> None:
@@ -77,6 +86,12 @@ def __call__(self, engine: Engine) -> None:
         Args:
             engine: Ignite Engine, it can be a trainer, validator or evaluator.
         """
-        meta_data = self.batch_transform(engine.state.batch)
-        engine_output = self.output_transform(engine.state.output)
-        self.saver.save_batch(engine_output, meta_data)
+        _meta_data = self.batch_transform(engine.state.batch)
+        if Key.FILENAME_OR_OBJ in _meta_data:
+            # all gather filenames across ranks
+            _meta_data[Key.FILENAME_OR_OBJ] = string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ])
+        # all gather predictions across ranks
+        _engine_output = evenly_divisible_all_gather(self.output_transform(engine.state.output))
+
+        if self._expected_rank:
+            self.saver.save_batch(_engine_output, _meta_data)
diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py
index d67f0f6c39..b9ea296821 100644
--- a/monai/handlers/metrics_saver.py
+++ b/monai/handlers/metrics_saver.py
@@ -11,9 +11,9 @@
 
 from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union
 
-from monai.handlers.utils import write_metrics_reports
+from monai.handlers.utils import string_list_all_gather, write_metrics_reports
+from monai.utils import ImageMetaKey as Key
 from monai.utils import ensure_tuple, exact_version, optional_import
-from monai.utils.module import get_torch_version_tuple
 
 Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
 idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
@@ -93,7 +93,7 @@ def _started(self, engine: Engine) -> None:
 
     def _get_filenames(self, engine: Engine) -> None:
         if self.metric_details is not None:
-            _filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)["filename_or_obj"]))
+            _filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)[Key.FILENAME_OR_OBJ]))
             self._filenames += _filenames
 
     def __call__(self, engine: Engine) -> None:
@@ -105,15 +105,8 @@ def __call__(self, engine: Engine) -> None:
         if self.save_rank >= ws:
             raise ValueError("target rank is greater than the distributed group size.")
 
-        _images = self._filenames
-        if ws > 1:
-            _filenames = self.deli.join(_images)
-            if get_torch_version_tuple() > (1, 6, 0):
-                # all gather across all processes
-                _filenames = self.deli.join(idist.all_gather(_filenames))
-            else:
-                raise RuntimeError("MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0.")
-            _images = _filenames.split(self.deli)
+        # all gather file names across ranks
+        _images = string_list_all_gather(strings=self._filenames) if ws > 1 else self._filenames
 
         # only save metrics to file in specified rank
         if idist.get_rank() == self.save_rank:
diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py
index a4b5c02f61..2165ad8860 100644
--- a/monai/handlers/utils.py
+++ b/monai/handlers/utils.py
@@ -11,12 +11,12 @@
 
 import os
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union
 
 import numpy as np
 import torch
 
-from monai.utils import ensure_tuple, exact_version, optional_import
+from monai.utils import ensure_tuple, exact_version, get_torch_version_tuple, optional_import
 
 idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
 if TYPE_CHECKING:
@@ -28,6 +28,7 @@
     "stopping_fn_from_metric",
     "stopping_fn_from_loss",
     "evenly_divisible_all_gather",
+    "string_list_all_gather",
     "write_metrics_reports",
 ]
 
@@ -81,6 +82,29 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor:
     return torch.cat([data[i * max_len : i * max_len + l, ...] for i, l in enumerate(all_lens)], dim=0)
 
 
+def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[str]:
+    """
+    Utility function for distributed data parallel to all gather a list of strings.
+
+    Args:
+        strings: a list of strings to all gather.
+        delimiter: the delimiter used to join the string list into one long string,
+            which is all gathered across ranks and then split back into a list. Defaults to "\t".
+
+    """
+    if idist.get_world_size() <= 1:
+        return strings
+
+    _joined = delimiter.join(strings)
+    if get_torch_version_tuple() > (1, 6, 0):
+        # all gather across all ranks
+        _joined = delimiter.join(idist.all_gather(_joined))
+    else:
+        raise RuntimeError("MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0.")
+
+    return _joined.split(delimiter)
+
+
 def write_metrics_reports(
     save_dir: str,
     images: Optional[Sequence[str]],
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index 772c7cf74f..f57b2dd27a 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -20,6 +20,7 @@
 from monai.config import DtypeLike
 from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
 from monai.transforms.compose import Transform
+from monai.utils import ImageMetaKey as Key
 from monai.utils import ensure_tuple, optional_import
 
 nib, _ = optional_import("nibabel")
@@ -126,5 +127,5 @@ def __call__(
 
         if self.image_only:
             return img_array
-        meta_data["filename_or_obj"] = ensure_tuple(filename)[0]
+        meta_data[Key.FILENAME_OR_OBJ] = ensure_tuple(filename)[0]
         return img_array, meta_data
diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py
index e5567f9f16..1e17d44029 100644
--- a/monai/utils/__init__.py
+++ b/monai/utils/__init__.py
@@ -32,6 +32,7 @@
 )
 from .misc import (
     MAX_SEED,
+    ImageMetaKey,
     copy_to_device,
     dtype_numpy_to_torch,
     dtype_torch_to_numpy,
diff --git a/monai/utils/misc.py b/monai/utils/misc.py
index c5e8318db3..f9346340cf 100644
--- a/monai/utils/misc.py
+++ b/monai/utils/misc.py
@@ -41,6 +41,7 @@
     "dtype_numpy_to_torch",
     "MAX_SEED",
     "copy_to_device",
+    "ImageMetaKey",
 ]
 
 _seed = None
@@ -349,3 +350,11 @@ def copy_to_device(
         warnings.warn(f"{fn_name} called with incompatible type: " + f"{type(obj)}. Data will be returned unchanged.")
 
     return obj
+
+
+class ImageMetaKey:
+    """
+    Common key names in the metadata header of images.
+    """
+
+    FILENAME_OR_OBJ = "filename_or_obj"
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 665ead6cc6..0fd6985067 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -103,6 +103,7 @@ def run_testsuit():
         "test_handler_metrics_saver",
         "test_handler_metrics_saver_dist",
         "test_evenly_divisible_all_gather_dist",
+        "test_handler_classification_saver_dist",
     ]
 
     assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py
new file mode 100644
index 0000000000..275d5b2231
--- /dev/null
+++ b/tests/test_handler_classification_saver_dist.py
@@ -0,0 +1,60 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import csv
+import os
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from ignite.engine import Engine
+
+from monai.handlers import ClassificationSaver
+from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion
+
+
+@SkipIfBeforePyTorchVersion((1, 7))
+class DistributedHandlerClassificationSaver(DistTestCase):
+    @DistCall(nnodes=1, nproc_per_node=2)
+    def test_saved_content(self):
+        with tempfile.TemporaryDirectory() as tempdir:
+            rank = dist.get_rank()
+
+            # set up engine
+            def _train_func(engine, batch):
+                return torch.zeros(8 + rank * 2)
+
+            engine = Engine(_train_func)
+
+            # set up testing handler
+            saver = ClassificationSaver(output_dir=tempdir, filename="predictions.csv", save_rank=1)
+            saver.attach(engine)
+
+            # rank 0 has 8 images, rank 1 has 10 images
+            data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))]}]
+            engine.run(data, max_epochs=1)
+            filepath = os.path.join(tempdir, "predictions.csv")
+            if rank == 1:
+                self.assertTrue(os.path.exists(filepath))
+                with open(filepath, "r") as f:
+                    reader = csv.reader(f)
+                    i = 0
+                    for row in reader:
+                        self.assertEqual(row[0], "testfile" + str(i))
+                        self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0)
+                        i += 1
+                    self.assertEqual(i, 18)
+
+
+if __name__ == "__main__":
+    unittest.main()
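
For reference, a minimal usage sketch of the new `save_rank` option (a hypothetical evaluation script, not part of this patch): every rank runs inference on its own shard, the handler all-gathers file names and predictions across ranks, and only the chosen rank writes `predictions.csv`. It assumes ignite 0.4.2 is installed; the file names, output directory, and inference step are illustrative only, and the script also runs single-process since the gather utilities short-circuit when the world size is 1.

```python
import torch
import torch.distributed as dist
from ignite.engine import Engine

from monai.handlers import ClassificationSaver


# hypothetical per-rank inference step: one scalar prediction per input file
def _infer_step(engine, batch):
    return torch.zeros(len(batch["filename_or_obj"]))


evaluator = Engine(_infer_step)

# all ranks gather their results; only rank 0 writes ./output/predictions.csv
saver = ClassificationSaver(output_dir="./output", filename="predictions.csv", save_rank=0)
saver.attach(evaluator)

# each rank feeds its own shard of file names (illustrative names only)
rank = dist.get_rank() if dist.is_initialized() else 0
data = [{"filename_or_obj": [f"image_{rank}_{i}.nii.gz" for i in range(4)]}]
evaluator.run(data, max_epochs=1)
```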
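
The string gathering itself boils down to a join / all-gather / split round trip. The helper below is a simplified re-implementation for illustration only, not the MONAI function itself; the real `string_list_all_gather` additionally guards on PyTorch >= 1.7.0, since gathering Python strings relies on object-based collectives.

```python
from typing import List

import ignite.distributed as idist


def gather_string_lists(strings: List[str], delimiter: str = "\t") -> List[str]:
    # single process: nothing to gather
    if idist.get_world_size() <= 1:
        return strings
    # join this rank's strings, all gather one joined string per rank, then flatten back to a list
    joined = delimiter.join(strings)
    gathered = idist.all_gather(joined)  # list with one joined string per rank
    return delimiter.join(gathered).split(delimiter)


# single-process example: returns the input list unchanged
print(gather_string_lists(["pred_0.nii.gz", "pred_1.nii.gz"]))
```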