diff --git a/monai/utils/dist.py b/monai/utils/dist.py index 47e6de4a98..20f09628ac 100644 --- a/monai/utils/dist.py +++ b/monai/utils/dist.py @@ -197,11 +197,13 @@ def __init__(self, rank: int | None = None, filter_fn: Callable = lambda rank: r if dist.is_available() and dist.is_initialized(): self.rank: int = rank if rank is not None else dist.get_rank() else: - warnings.warn( - "The torch.distributed is either unavailable and uninitiated when RankFilter is instantiated. " - "If torch.distributed is used, please ensure that the RankFilter() is called " - "after torch.distributed.init_process_group() in the script." - ) + if torch.cuda.is_available() and torch.cuda.device_count() > 1: + warnings.warn( + "The torch.distributed is either unavailable and uninitiated when RankFilter is instantiated.\n" + "If torch.distributed is used, please ensure that the RankFilter() is called\n" + "after torch.distributed.init_process_group() in the script.\n" + ) + self.rank = 0 def filter(self, *_args): return self.filter_fn(self.rank) diff --git a/tests/test_rankfilter_dist.py b/tests/test_rankfilter_dist.py index 4dcd637c56..40cd36f31d 100644 --- a/tests/test_rankfilter_dist.py +++ b/tests/test_rankfilter_dist.py @@ -43,11 +43,35 @@ def test_rankfilter(self): with open(log_filename) as file: lines = [line.rstrip() for line in file] log_message = " ".join(lines) - assert log_message.count("test_warnings") == 1 + self.assertEqual(log_message.count("test_warnings"), 1) def tearDown(self) -> None: self.log_dir.cleanup() +class SingleRankFilterTest(unittest.TestCase): + def tearDown(self) -> None: + self.log_dir.cleanup() + + def setUp(self): + self.log_dir = tempfile.TemporaryDirectory() + + def test_rankfilter_single_proc(self): + logger = logging.getLogger(__name__) + log_filename = os.path.join(self.log_dir.name, "records_sp.log") + h1 = logging.FileHandler(filename=log_filename) + h1.setLevel(logging.WARNING) + logger.addHandler(h1) + logger.addFilter(RankFilter()) + logger.warning("test_warnings") + + with open(log_filename) as file: + lines = [line.rstrip() for line in file] + logger.removeHandler(h1) + h1.close() + log_message = " ".join(lines) + self.assertEqual(log_message.count("test_warnings"), 1) + + if __name__ == "__main__": unittest.main()