diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index a1c76dd338..33ce7c7ec8 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -88,8 +88,8 @@ def __call__(self, engine: Engine) -> None: """ _meta_data = self.batch_transform(engine.state.batch) if Key.FILENAME_OR_OBJ in _meta_data: - # all gather filenames across ranks - _meta_data[Key.FILENAME_OR_OBJ] = string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ]) + # all gather filenames across ranks, only filenames are necessary + _meta_data = {Key.FILENAME_OR_OBJ: string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ])} # all gather predictions across ranks _engine_output = evenly_divisible_all_gather(self.output_transform(engine.state.output)) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index d0179e7f49..f4f0250b49 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -103,7 +103,7 @@ def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[st # all gather across all ranks _joined = delimiter.join(idist.all_gather(_joined)) else: - raise RuntimeError("MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0.") + raise RuntimeError("string all_gather can not be supported in PyTorch < 1.7.0.") return _joined.split(delimiter) diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index 275d5b2231..a33cba923a 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -41,7 +41,12 @@ def _train_func(engine, batch): saver.attach(engine) # rank 0 has 8 images, rank 1 has 10 images - data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))]}] + data = [ + { + "filename_or_obj": ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))], + "data_shape": [(1, 1) for _ in range(8 * rank, (8 + rank) * (rank + 1))], + } + ] engine.run(data, max_epochs=1) filepath = os.path.join(tempdir, "predictions.csv") if rank == 1: