From 13a914ccc37f2365a72fe8dc9f3a5debf799ba2f Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 26 Oct 2023 17:18:58 +0800 Subject: [PATCH 1/8] fix #5917 Signed-off-by: KumoLiu --- monai/data/utils.py | 40 ++++++++++++++++------------------------ 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 8c5ae88289..56ab3f336c 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -28,7 +28,7 @@ import numpy as np import torch -from torch.utils.data._utils.collate import default_collate +from torch.utils.data._utils.collate import collate_tensor_fn, default_collate, default_collate_fn_map from monai import config from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike @@ -444,29 +444,18 @@ def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): return data -def collate_meta_tensor(batch): +def collate_meta_tensor(batch, *, collate_fn_map=None): """collate a sequence of meta tensor sequences/dictionaries into a single batched metatensor or a dictionary of batched metatensor""" - if not isinstance(batch, Sequence): - raise NotImplementedError() - elem_0 = first(batch) - if isinstance(elem_0, MetaObj): - collated = default_collate(batch) - meta_dicts = [i.meta or TraceKeys.NONE for i in batch] - common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)]) - if common_: - meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts] - collated.meta = default_collate(meta_dicts) - collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch] - collated.is_batch = True - return collated - if isinstance(elem_0, Mapping): - return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0} - if isinstance(elem_0, (tuple, list)): - return [collate_meta_tensor([d[i] for d in batch]) for i in range(len(elem_0))] - - # no more recursive search for MetaTensor - return default_collate(batch) + collated = collate_tensor_fn(batch) + meta_dicts = [i.meta or TraceKeys.NONE for i in batch] + common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)]) + if common_: + meta_dicts = [{k: d[k] for k in common_} if isinstance(d, dict) else TraceKeys.NONE for d in meta_dicts] + collated.meta = default_collate(meta_dicts) + collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch] + collated.is_batch = True + return collated def list_data_collate(batch: Sequence): @@ -479,6 +468,9 @@ def list_data_collate(batch: Sequence): Need to use this collate if apply some transforms that can generate batch data. """ + from monai.data.meta_tensor import MetaTensor + + default_collate_fn_map.update({MetaTensor: collate_meta_tensor}) elem = batch[0] data = [i for k in batch for i in k] if isinstance(elem, list) else batch key = None @@ -490,9 +482,9 @@ def list_data_collate(batch: Sequence): for k in elem: key = k data_for_batch = [d[key] for d in data] - ret[key] = collate_meta_tensor(data_for_batch) + ret[key] = default_collate(data_for_batch) else: - ret = collate_meta_tensor(data) + ret = default_collate(data) return ret except RuntimeError as re: re_str = str(re) From bbc1f7848eb989f85c1431a315b7ae43150a2dd7 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 26 Oct 2023 18:01:04 +0800 Subject: [PATCH 2/8] backwards compatibility Signed-off-by: KumoLiu --- monai/data/utils.py | 38 +++++++++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 56ab3f336c..9a5e35fe6b 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -28,7 +28,7 @@ import numpy as np import torch -from torch.utils.data._utils.collate import collate_tensor_fn, default_collate, default_collate_fn_map +from torch.utils.data._utils.collate import default_collate from monai import config from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike @@ -50,8 +50,11 @@ issequenceiterable, look_up_option, optional_import, + pytorch_after, ) +if pytorch_after(1, 13): + from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame") nib, _ = optional_import("nibabel") @@ -444,10 +447,11 @@ def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): return data -def collate_meta_tensor(batch, *, collate_fn_map=None): +def collate_meta_tensor_fn(batch, *, collate_fn_map=None): """collate a sequence of meta tensor sequences/dictionaries into a single batched metatensor or a dictionary of batched metatensor""" - collated = collate_tensor_fn(batch) + collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate + collated = collate_fn(batch) meta_dicts = [i.meta or TraceKeys.NONE for i in batch] common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)]) if common_: @@ -458,6 +462,23 @@ def collate_meta_tensor(batch, *, collate_fn_map=None): return collated +def collate_meta_tensor(batch): + """collate a sequence of meta tensor sequences/dictionaries into + a single batched metatensor or a dictionary of batched metatensor""" + if not isinstance(batch, Sequence): + raise NotImplementedError() + elem_0 = first(batch) + if isinstance(elem_0, MetaObj): + return collate_meta_tensor_fn(batch) + if isinstance(elem_0, Mapping): + return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0} + if isinstance(elem_0, (tuple, list)): + return [collate_meta_tensor([d[i] for d in batch]) for i in range(len(elem_0))] + + # no more recursive search for MetaTensor + return default_collate(batch) + + def list_data_collate(batch: Sequence): """ Enhancement for PyTorch DataLoader default collate. @@ -468,12 +489,15 @@ def list_data_collate(batch: Sequence): Need to use this collate if apply some transforms that can generate batch data. """ - from monai.data.meta_tensor import MetaTensor - default_collate_fn_map.update({MetaTensor: collate_meta_tensor}) + if pytorch_after(1, 13): + from monai.data.meta_tensor import MetaTensor + + default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn}) elem = batch[0] data = [i for k in batch for i in k] if isinstance(elem, list) else batch key = None + collate_fn = default_collate if pytorch_after(1, 13) else collate_meta_tensor try: if config.USE_META_DICT: data = pickle_operations(data) # bc 0.9.0 @@ -482,9 +506,9 @@ def list_data_collate(batch: Sequence): for k in elem: key = k data_for_batch = [d[key] for d in data] - ret[key] = default_collate(data_for_batch) + ret[key] = collate_fn(data_for_batch) else: - ret = default_collate(data) + ret = collate_fn(data) return ret except RuntimeError as re: re_str = str(re) From 70c058db05b6de02597f8d02e43e0aaa5063d1a1 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 26 Oct 2023 18:04:20 +0800 Subject: [PATCH 3/8] fix ci Signed-off-by: KumoLiu --- monai/data/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 9a5e35fe6b..0a2ccf1989 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -448,10 +448,9 @@ def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): def collate_meta_tensor_fn(batch, *, collate_fn_map=None): - """collate a sequence of meta tensor sequences/dictionaries into - a single batched metatensor or a dictionary of batched metatensor""" + """collate a sequence of meta tensor into a single batched metatensor""" collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate - collated = collate_fn(batch) + collated = collate_fn(batch) # type: ignore meta_dicts = [i.meta or TraceKeys.NONE for i in batch] common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)]) if common_: From ab3009b1f594612164b2f9a72c1f4e9a0382c993 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 26 Oct 2023 19:39:59 +0800 Subject: [PATCH 4/8] Update monai/data/utils.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/data/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 0a2ccf1989..7cef572cdf 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -448,7 +448,9 @@ def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): def collate_meta_tensor_fn(batch, *, collate_fn_map=None): - """collate a sequence of meta tensor into a single batched metatensor""" + """ + Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor` and so should not be used as a collate function directly in dataloaders. + """ collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate collated = collate_fn(batch) # type: ignore meta_dicts = [i.meta or TraceKeys.NONE for i in batch] From 02188d6d34c32464dcb2a908936d09603a9987d1 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 26 Oct 2023 21:23:40 +0800 Subject: [PATCH 5/8] minor fix Signed-off-by: KumoLiu --- monai/data/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/data/utils.py b/monai/data/utils.py index 0a2ccf1989..4b3a7846a4 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -490,6 +490,7 @@ def list_data_collate(batch: Sequence): """ if pytorch_after(1, 13): + # needs to go here to avoid circular import from monai.data.meta_tensor import MetaTensor default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn}) From 357769d7e4b0e8c25bddb26b517e4c2bf06bc1ed Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 26 Oct 2023 21:24:30 +0800 Subject: [PATCH 6/8] fix flake8 Signed-off-by: KumoLiu --- monai/data/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index aae6143356..b7e8ee4a3a 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -449,7 +449,8 @@ def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): def collate_meta_tensor_fn(batch, *, collate_fn_map=None): """ - Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor` and so should not be used as a collate function directly in dataloaders. + Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor` + and so should not be used as a collate function directly in dataloaders. """ collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate collated = collate_fn(batch) # type: ignore From e7f008667ddde91db5a6538f80a448fdeea7739e Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 26 Oct 2023 22:19:49 +0800 Subject: [PATCH 7/8] address comments Signed-off-by: KumoLiu --- monai/data/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/data/utils.py b/monai/data/utils.py index b7e8ee4a3a..5a28a30cb6 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -54,6 +54,7 @@ ) if pytorch_after(1, 13): + # import private code for reuse purposes from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame") From 46ef51821baf8166d9cdc9942378be940581b291 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Thu, 26 Oct 2023 22:21:38 +0800 Subject: [PATCH 8/8] minor fix Signed-off-by: KumoLiu --- monai/data/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 5a28a30cb6..164fa78814 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -54,7 +54,7 @@ ) if pytorch_after(1, 13): - # import private code for reuse purposes + # import private code for reuse purposes, comment in case things break in the future from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map pd, _ = optional_import("pandas") DataFrame, _ = optional_import("pandas", name="DataFrame")