From cd31b3db9f74e8da94822ea04b8d59272a9713d0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 24 Feb 2021 23:17:40 +0800 Subject: [PATCH 1/7] [DLMED] fix length > 1024 issue in string list all gather Signed-off-by: Nic Ma --- monai/handlers/utils.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 3e36af0652..24f4ba3ce0 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -85,27 +85,38 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor: return torch.cat([data[i * max_len : i * max_len + l, ...] for i, l in enumerate(all_lens)], dim=0) -def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[str]: +def string_list_all_gather(strings: List[str]) -> List[str]: """ Utility function for distributed data parallel to all gather a list of strings. + Note that if the item in `strings` is longer than 1024 chars, it will be truncated to 1024. Args: strings: a list of strings to all gather. - delimiter: use the delimiter to join the string list to be a long string, - then all gather across ranks and split to a list. default to "\t". """ - if idist.get_world_size() <= 1: + world_size = idist.get_world_size() + if world_size <= 1: return strings - _joined = delimiter.join(strings) + result = [[] for _ in range(world_size)] + # get length of strings + length = len(strings) + all_lens = idist.all_gather(length) + max_len = max(all_lens).item() + # pad the item to make sure the same length + if length < max_len: + strings = strings + ["" for _ in range(max_len - length)] + if get_torch_version_tuple() > (1, 6, 0): - # all gather across all ranks - _joined = delimiter.join(idist.all_gather(_joined)) + for s in strings: + gathered = idist.all_gather(s) + for i, g in enumerate(gathered): + if len(g) > 0: + result[i].append(g) else: raise RuntimeError("string all_gather can not be supported in PyTorch < 1.7.0.") - return _joined.split(delimiter) + return [i for k in result for i in k] def write_metrics_reports( From 8c229d7ce2f078e5eede8c0af867d7f1eb7058fd Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 24 Feb 2021 23:32:06 +0800 Subject: [PATCH 2/7] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/handlers/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 24f4ba3ce0..b7bf59517c 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -98,7 +98,7 @@ def string_list_all_gather(strings: List[str]) -> List[str]: if world_size <= 1: return strings - result = [[] for _ in range(world_size)] + result: List[List[str]] = [[] for _ in range(world_size)] # get length of strings length = len(strings) all_lens = idist.all_gather(length) From 02dcbc655749afa68b45fbc4928a0c5b41197e1a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 25 Feb 2021 15:14:02 +0800 Subject: [PATCH 3/7] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/handlers/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index b7bf59517c..a0717169aa 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -88,7 +88,8 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor: def string_list_all_gather(strings: List[str]) -> List[str]: """ Utility function for distributed data parallel to all gather a list of strings. - Note that if the item in `strings` is longer than 1024 chars, it will be truncated to 1024. + Note that if the item in `strings` is longer than 1024 chars, it will be truncated to 1024: + https://github.com/pytorch/ignite/blob/master/ignite/distributed/comp_models/base.py#L92 Args: strings: a list of strings to all gather. From 67e695fea3e70537544101918a84a712cdf7fe0a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 25 Feb 2021 22:53:55 +0800 Subject: [PATCH 4/7] [DLMED] update according to comments Signed-off-by: Nic Ma --- tests/test_handler_metrics_saver_dist.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 1b17d0adb4..438ce7bf38 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -13,6 +13,7 @@ import csv import os import tempfile +import random import unittest import torch @@ -44,8 +45,13 @@ def _val_func(engine, batch): engine = Engine(_val_func) + # test the case that all_gather length > 1024 chars + filename_prefix = "abcdefghigklmnopqrstuvwxyz" + for i in range(600): + filename_prefix += filename_prefix[random.randint(0, 26)] + if dist.get_rank() == 0: - data = [{"image_meta_dict": {"filename_or_obj": ["filepath1"]}}] + data = [{"image_meta_dict": {"filename_or_obj": [f"{filename_prefix}1"]}}] @engine.on(Events.EPOCH_COMPLETED) def _save_metrics0(engine): @@ -58,8 +64,8 @@ def _save_metrics0(engine): if dist.get_rank() == 1: # different ranks have different data length data = [ - {"image_meta_dict": {"filename_or_obj": ["filepath2"]}}, - {"image_meta_dict": {"filename_or_obj": ["filepath3"]}}, + {"image_meta_dict": {"filename_or_obj": [f"{filename_prefix}2"]}}, + {"image_meta_dict": {"filename_or_obj": [f"{filename_prefix}3"]}}, ] @engine.on(Events.EPOCH_COMPLETED) @@ -86,7 +92,7 @@ def _save_metrics1(engine): f_csv = csv.reader(f) for i, row in enumerate(f_csv): if i > 0: - self.assertEqual(row, [f"filepath{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]) + self.assertEqual(row, [f"{filename_prefix}{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]) self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_summary.csv"))) # check the metric_summary.csv and content with open(os.path.join(tempdir, "metric3_summary.csv")) as f: From b841875ecff3be6972e9d91d4f8ed8a32d082fee Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 25 Feb 2021 23:05:39 +0800 Subject: [PATCH 5/7] [DLMED] add more test Signed-off-by: Nic Ma --- tests/test_handler_metrics_saver_dist.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 438ce7bf38..859f5269d6 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -45,13 +45,13 @@ def _val_func(engine, batch): engine = Engine(_val_func) - # test the case that all_gather length > 1024 chars - filename_prefix = "abcdefghigklmnopqrstuvwxyz" - for i in range(600): - filename_prefix += filename_prefix[random.randint(0, 26)] + # test the case that all_gather with string length > 1024 chars + filename_postfix = "abcdefghigklmnopqrstuvwxyz" + for i in range(1100): + filename_postfix += filename_postfix[random.randint(0, 26)] if dist.get_rank() == 0: - data = [{"image_meta_dict": {"filename_or_obj": [f"{filename_prefix}1"]}}] + data = [{"image_meta_dict": {"filename_or_obj": [f"1{filename_postfix}"]}}] @engine.on(Events.EPOCH_COMPLETED) def _save_metrics0(engine): @@ -64,8 +64,8 @@ def _save_metrics0(engine): if dist.get_rank() == 1: # different ranks have different data length data = [ - {"image_meta_dict": {"filename_or_obj": [f"{filename_prefix}2"]}}, - {"image_meta_dict": {"filename_or_obj": [f"{filename_prefix}3"]}}, + {"image_meta_dict": {"filename_or_obj": [f"2{filename_postfix}"]}}, + {"image_meta_dict": {"filename_or_obj": [f"3{filename_postfix}"]}}, ] @engine.on(Events.EPOCH_COMPLETED) @@ -92,7 +92,8 @@ def _save_metrics1(engine): f_csv = csv.reader(f) for i, row in enumerate(f_csv): if i > 0: - self.assertEqual(row, [f"{filename_prefix}{i}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"]) + expected = [f"{i}{filename_postfix[0: 1023]}\t{float(i)}\t{float(i + 1)}\t{i + 0.5}"] + self.assertEqual(row, expected) self.assertTrue(os.path.exists(os.path.join(tempdir, "metric3_summary.csv"))) # check the metric_summary.csv and content with open(os.path.join(tempdir, "metric3_summary.csv")) as f: From c19d421b11df4da703f9cf6566e094edc03aef99 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Thu, 25 Feb 2021 15:10:14 +0000 Subject: [PATCH 6/7] [MONAI] python code formatting Signed-off-by: monai-bot --- tests/test_handler_metrics_saver_dist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 859f5269d6..19a75c2ce8 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -12,8 +12,8 @@ import csv import os -import tempfile import random +import tempfile import unittest import torch From 198008346047e841f3b61dbf836db7d7d597db7d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 26 Feb 2021 00:31:25 +0800 Subject: [PATCH 7/7] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- tests/test_handler_metrics_saver_dist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py index 19a75c2ce8..dfdaa16526 100644 --- a/tests/test_handler_metrics_saver_dist.py +++ b/tests/test_handler_metrics_saver_dist.py @@ -47,7 +47,7 @@ def _val_func(engine, batch): # test the case that all_gather with string length > 1024 chars filename_postfix = "abcdefghigklmnopqrstuvwxyz" - for i in range(1100): + for _ in range(1100): filename_postfix += filename_postfix[random.randint(0, 26)] if dist.get_rank() == 0: