From 518b9677466d93a3c2c6c8a8c8099ca69ee4f93c Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Mon, 7 Aug 2023 15:02:36 +0200 Subject: [PATCH 1/2] Add warning in a multiprocessing special case Signed-off-by: Matthias Hadlich --- monai/data/dataloader.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/monai/data/dataloader.py b/monai/data/dataloader.py index 9de0d28b96..7136fad6d8 100644 --- a/monai/data/dataloader.py +++ b/monai/data/dataloader.py @@ -11,14 +11,20 @@ from __future__ import annotations +import warnings + import torch from torch.utils.data import DataLoader as _TorchDataLoader from torch.utils.data import Dataset +from monai.apps.utils import get_logger +from monai.data.meta_obj import get_track_meta from monai.data.utils import list_data_collate, set_rnd, worker_init_fn __all__ = ["DataLoader"] +logger = get_logger(module_name=__name__) + class DataLoader(_TorchDataLoader): """ @@ -88,4 +94,16 @@ def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None: if "worker_init_fn" not in kwargs: kwargs["worker_init_fn"] = worker_init_fn + if ( + "multiprocessing_context" in kwargs + and kwargs["multiprocessing_context"] == "spawn" + and not get_track_meta() + ): + warnings.warn( + "Please be aware: Return type of the dataloader will not be a Tensor as expected but" + " a MetaTensor instead! This is because 'spawn' creates a new process where _TRACK_META" + " is initialized to True again. Context:_TRACK_META is set to False and" + " multiprocessing_context to spawn" + ) + super().__init__(dataset=dataset, num_workers=num_workers, **kwargs) From 0fe5e6dc099efe14629c940a48294e8b79e465d0 Mon Sep 17 00:00:00 2001 From: Matthias Hadlich Date: Mon, 7 Aug 2023 15:26:13 +0200 Subject: [PATCH 2/2] Remove unused logger in monai/data/dataloader.py Signed-off-by: Matthias Hadlich --- monai/data/dataloader.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/monai/data/dataloader.py b/monai/data/dataloader.py index 7136fad6d8..336f55b8c4 100644 --- a/monai/data/dataloader.py +++ b/monai/data/dataloader.py @@ -17,14 +17,11 @@ from torch.utils.data import DataLoader as _TorchDataLoader from torch.utils.data import Dataset -from monai.apps.utils import get_logger from monai.data.meta_obj import get_track_meta from monai.data.utils import list_data_collate, set_rnd, worker_init_fn __all__ = ["DataLoader"] -logger = get_logger(module_name=__name__) - class DataLoader(_TorchDataLoader): """