diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 75806ce120..5bccaba8a2 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -15,7 +15,7 @@ from .aliases import alias, resolve_name from .decorators import MethodReplacer, RestartGenerator from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default -from .dist import evenly_divisible_all_gather, get_dist_device, string_list_all_gather +from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather from .enums import ( Average, BlendMode, diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 546058c93e..c476ace73b 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -12,6 +12,9 @@ from __future__ import annotations import sys +import warnings +from collections.abc import Callable +from logging import Filter if sys.version_info >= (3, 8): from typing import Literal @@ -26,7 +29,7 @@ idist, has_ignite = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") -__all__ = ["get_dist_device", "evenly_divisible_all_gather", "string_list_all_gather"] +__all__ = ["get_dist_device", "evenly_divisible_all_gather", "string_list_all_gather", "RankFilter"] def get_dist_device(): @@ -174,3 +177,31 @@ def string_list_all_gather(strings: list[str], delimiter: str = "\t") -> list[st _gathered = [bytearray(g.tolist()).decode("utf-8").split(delimiter) for g in gathered] return [i for k in _gathered for i in k] + + +class RankFilter(Filter): + """ + The RankFilter class is a convenient filter that extends the Filter class in the Python logging module. + The purpose is to control which log records are processed based on the rank in a distributed environment. + + Args: + rank: the rank of the process in the torch.distributed. Default is None and then it will use dist.get_rank(). + filter_fn: an optional lambda function used as the filtering criteria. + The default function logs only if the rank of the process is 0, + but the user can define their own function to implement custom filtering logic. + """ + + def __init__(self, rank: int | None = None, filter_fn: Callable = lambda rank: rank == 0): + super().__init__() + self.filter_fn: Callable = filter_fn + if dist.is_available() and dist.is_initialized(): + self.rank: int = rank if rank is not None else dist.get_rank() + else: + warnings.warn( + "The torch.distributed is either unavailable and uninitiated when RankFilter is instiantiated. " + "If torch.distributed is used, please ensure that the RankFilter() is called " + "after torch.distributed.init_process_group() in the script." + ) + + def filter(self, *_args): + return self.filter_fn(self.rank) diff --git a/tests/min_tests.py b/tests/min_tests.py index 05f117013e..ab5c1db826 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -156,6 +156,7 @@ def run_testsuit(): "test_rand_zoom", "test_rand_zoomd", "test_randtorchvisiond", + "test_rankfilter_dist", "test_resample_backends", "test_resize", "test_resized", diff --git a/tests/test_rankfilter_dist.py b/tests/test_rankfilter_dist.py new file mode 100644 index 0000000000..4dcd637c56 --- /dev/null +++ b/tests/test_rankfilter_dist.py @@ -0,0 +1,53 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import os +import tempfile +import unittest + +import torch.distributed as dist + +from monai.utils import RankFilter +from tests.utils import DistCall, DistTestCase + + +class DistributedRankFilterTest(DistTestCase): + def setUp(self): + self.log_dir = tempfile.TemporaryDirectory() + + @DistCall(nnodes=1, nproc_per_node=2) + def test_rankfilter(self): + logger = logging.getLogger(__name__) + log_filename = os.path.join(self.log_dir.name, "records.log") + h1 = logging.FileHandler(filename=log_filename) + h1.setLevel(logging.WARNING) + + logger.addHandler(h1) + + logger.addFilter(RankFilter()) + logger.warning("test_warnings") + + dist.barrier() + if dist.get_rank() == 0: + with open(log_filename) as file: + lines = [line.rstrip() for line in file] + log_message = " ".join(lines) + assert log_message.count("test_warnings") == 1 + + def tearDown(self) -> None: + self.log_dir.cleanup() + + +if __name__ == "__main__": + unittest.main()