From f66b2e4de1df0ec516c7e4b4fe69aa6ed59e2587 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 9 Feb 2021 11:42:41 +0800 Subject: [PATCH 1/3] [DLMED] fix classification saver issue Signed-off-by: Nic Ma --- monai/handlers/classification_saver.py | 4 ++-- monai/handlers/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index a1c76dd338..33ce7c7ec8 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -88,8 +88,8 @@ def __call__(self, engine: Engine) -> None: """ _meta_data = self.batch_transform(engine.state.batch) if Key.FILENAME_OR_OBJ in _meta_data: - # all gather filenames across ranks - _meta_data[Key.FILENAME_OR_OBJ] = string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ]) + # all gather filenames across ranks, only filenames are necessary + _meta_data = {Key.FILENAME_OR_OBJ: string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ])} # all gather predictions across ranks _engine_output = evenly_divisible_all_gather(self.output_transform(engine.state.output)) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index d0179e7f49..f4f0250b49 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -103,7 +103,7 @@ def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[st # all gather across all ranks _joined = delimiter.join(idist.all_gather(_joined)) else: - raise RuntimeError("MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0.") + raise RuntimeError("string all_gather can not be supported in PyTorch < 1.7.0.") return _joined.split(delimiter) From 3afea2c9c6ffed3a70148e589161e8c89ccc6357 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 9 Feb 2021 18:52:32 +0800 Subject: [PATCH 2/3] [DLMED] enhance distributed tests Signed-off-by: Nic Ma --- tests/test_handler_classification_saver_dist.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index 275d5b2231..b08baff67d 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -41,7 +41,10 @@ def _train_func(engine, batch): saver.attach(engine) # rank 0 has 8 images, rank 1 has 10 images - data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))]}] + data = [{ + "filename_or_obj": ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))], + "data_shape": [(1, 1) for _ in range(8 * rank, (8 + rank) * (rank + 1))], + }] engine.run(data, max_epochs=1) filepath = os.path.join(tempdir, "predictions.csv") if rank == 1: From b51b3b231c646ebc150ce4a7b9a6873082a94b7b Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 9 Feb 2021 10:56:14 +0000 Subject: [PATCH 3/3] [MONAI] python code formatting Signed-off-by: monai-bot --- tests/test_handler_classification_saver_dist.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index b08baff67d..a33cba923a 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -41,10 +41,12 @@ def _train_func(engine, batch): saver.attach(engine) # rank 0 has 8 images, rank 1 has 10 images - data = [{ - "filename_or_obj": ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))], - "data_shape": [(1, 1) for _ in range(8 * rank, (8 + rank) * (rank + 1))], - }] + data = [ + { + "filename_or_obj": ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))], + "data_shape": [(1, 1) for _ in range(8 * rank, (8 + rank) * (rank + 1))], + } + ] engine.run(data, max_epochs=1) filepath = os.path.join(tempdir, "predictions.csv") if rank == 1: