From 1402673b32cdcdeb8930ee594b3152f089473546 Mon Sep 17 00:00:00 2001
From: Nic Ma
Date: Wed, 22 Sep 2021 19:52:11 +0800
Subject: [PATCH 1/3] [DLMED] add copy option

Signed-off-by: Nic Ma
---
 monai/data/dataset.py      | 16 ++++++++++++++--
 tests/test_cachedataset.py |  4 ++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index c970e83d0d..2549041fbd 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -575,6 +575,7 @@ def __init__(
         cache_rate: float = 1.0,
         num_workers: Optional[int] = None,
         progress: bool = True,
+        copy_cache: bool = True,
     ) -> None:
         """
         Args:
@@ -587,11 +588,16 @@ def __init__(
             num_workers: the number of worker processes to use.
                 If num_workers is None then the number returned by os.cpu_count() is used.
             progress: whether to display a progress bar.
+            copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
+                defaults to `True`. If the random transforms don't modify the cache content,
+                or if every cache item is only used once in a `multi-processing` environment,
+                you may set `copy_cache=False` for better performance.
         """
         if not isinstance(transform, Compose):
             transform = Compose(transform)
         super().__init__(data=data, transform=transform)
         self.progress = progress
+        self.copy_cache = copy_cache
         self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data))
         self.num_workers = num_workers
         if self.num_workers is not None:
@@ -656,7 +662,8 @@ def _transform(self, index: int):
                 # only need to deep copy data on first non-deterministic transform
                 if not start_run:
                     start_run = True
-                    data = deepcopy(data)
+                    if self.copy_cache:
+                        data = deepcopy(data)
                 data = apply_transform(_transform, data)
         return data

@@ -722,6 +729,10 @@ class SmartCacheDataset(Randomizable, CacheDataset):
         shuffle: whether to shuffle the whole data list before preparing the cache content for first epoch.
             it will not modify the original input data sequence in-place.
         seed: random seed if shuffle is `True`, default to `0`.
+        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
+            defaults to `True`. If the random transforms don't modify the cache content,
+            or if every cache item is only used once in a `multi-processing` environment,
+            you may set `copy_cache=False` for better performance.
     """

     def __init__(
@@ -736,6 +747,7 @@ def __init__(
         progress: bool = True,
         shuffle: bool = True,
         seed: int = 0,
+        copy_cache: bool = True,
     ) -> None:
         if shuffle:
             self.set_random_state(seed=seed)
@@ -743,7 +755,7 @@
             self.randomize(data)
         self.shuffle = shuffle

-        super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress)
+        super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache)
         if self._cache is None:
             self._cache = self._fill_cache()
         if self.cache_num >= len(data):
diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py
index bbb8143631..1977283fec 100644
--- a/tests/test_cachedataset.py
+++ b/tests/test_cachedataset.py
@@ -92,10 +92,14 @@ def test_set_data(self):
             cache_rate=1.0,
             num_workers=4,
             progress=True,
+            copy_cache=False if sys.platform == "linux" else True,
         )
         num_workers = 2 if sys.platform == "linux" else 0
         dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1)
+        for i, d in enumerate(dataloader):
+            np.testing.assert_allclose([[data_list1[i] * 10]], d)
+        # simulate another epoch, the cache content should not be modified
         for i, d in enumerate(dataloader):
             np.testing.assert_allclose([[data_list1[i] * 10]], d)
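The `copy_cache` flag added above controls whether `CacheDataset` deep-copies a cached item before the first random (non-deterministic) transform runs on it. A minimal usage sketch follows; the data list and transform chain are illustrative and not part of this patch:

    import numpy as np

    from monai.data import CacheDataset, DataLoader
    from monai.transforms import Compose, Lambda, RandLambda

    # the deterministic Lambda result is computed once and cached;
    # RandLambda runs on every fetch, after the (optional) deepcopy of the cached item
    transform = Compose(
        [
            Lambda(func=lambda x: np.array([x * 10])),  # cached
            RandLambda(func=lambda x: x + 1),           # applied per fetch
        ]
    )

    # copy_cache=False skips the per-item deepcopy; only safe when the random transforms
    # never modify the cached content in place, or when every cache item is consumed
    # only once per epoch in a multi-processing setting
    dataset = CacheDataset(data=list(range(10)), transform=transform, cache_rate=1.0, copy_cache=False)

    for batch in DataLoader(dataset, batch_size=2, num_workers=0):
        print(batch)

Keeping `copy_cache=True` (the default) preserves the previous behaviour of always deep-copying before the random transforms.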
""" def __init__( @@ -736,6 +747,7 @@ def __init__( progress: bool = True, shuffle: bool = True, seed: int = 0, + copy_cache: bool = True, ) -> None: if shuffle: self.set_random_state(seed=seed) @@ -743,7 +755,7 @@ def __init__( self.randomize(data) self.shuffle = shuffle - super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress) + super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache) if self._cache is None: self._cache = self._fill_cache() if self.cache_num >= len(data): diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index bbb8143631..1977283fec 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -92,10 +92,14 @@ def test_set_data(self): cache_rate=1.0, num_workers=4, progress=True, + copy_cache=False if sys.platform == "linux" else True, ) num_workers = 2 if sys.platform == "linux" else 0 dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1) + for i, d in enumerate(dataloader): + np.testing.assert_allclose([[data_list1[i] * 10]], d) + # simulate another epoch, the cache content should not be modified for i, d in enumerate(dataloader): np.testing.assert_allclose([[data_list1[i] * 10]], d) From b5d9245abaca545da001bfab2261505b9a6d05b0 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Sep 2021 19:59:54 +0800 Subject: [PATCH 2/3] [DLMED] enhance test Signed-off-by: Nic Ma --- tests/test_cachedataset.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index 1977283fec..27d17659ab 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -19,7 +19,7 @@ from parameterized import parameterized from monai.data import CacheDataset, DataLoader, PersistentDataset, SmartCacheDataset -from monai.transforms import Compose, Lambda, LoadImaged, ThreadUnsafe, Transform +from monai.transforms import Compose, Lambda, LoadImaged, RandLambda, ThreadUnsafe, Transform from monai.utils import get_torch_version_tuple TEST_CASE_1 = [Compose([LoadImaged(keys=["image", "label", "extra"])]), (128, 128, 128)] @@ -84,7 +84,10 @@ def test_shape(self, transform, expected_shape): def test_set_data(self): data_list1 = list(range(10)) - transform = Lambda(func=lambda x: np.array([x * 10])) + transform = Compose([ + Lambda(func=lambda x: np.array([x * 10])), + RandLambda(func=lambda x: x + 1), + ]) dataset = CacheDataset( data=data_list1, @@ -98,17 +101,17 @@ def test_set_data(self): num_workers = 2 if sys.platform == "linux" else 0 dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1) for i, d in enumerate(dataloader): - np.testing.assert_allclose([[data_list1[i] * 10]], d) + np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d) # simulate another epoch, the cache content should not be modified for i, d in enumerate(dataloader): - np.testing.assert_allclose([[data_list1[i] * 10]], d) + np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d) # update the datalist and fill the cache content data_list2 = list(range(-10, 0)) dataset.set_data(data=data_list2) # rerun with updated cache content for i, d in enumerate(dataloader): - np.testing.assert_allclose([[data_list2[i] * 10]], d) + np.testing.assert_allclose([[data_list2[i] * 10 + 1]], d) class _StatefulTransform(Transform, ThreadUnsafe): From 471079d1af53fa9d9c4c07c4bb10e32c3a46e8c7 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Wed, 22 Sep 2021 12:05:10 +0000 Subject: [PATCH 3/3] [MONAI] python 
From 471079d1af53fa9d9c4c07c4bb10e32c3a46e8c7 Mon Sep 17 00:00:00 2001
From: monai-bot
Date: Wed, 22 Sep 2021 12:05:10 +0000
Subject: [PATCH 3/3] [MONAI] python code formatting

Signed-off-by: monai-bot
---
 tests/test_cachedataset.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py
index 27d17659ab..e5bb1b9a90 100644
--- a/tests/test_cachedataset.py
+++ b/tests/test_cachedataset.py
@@ -84,10 +84,12 @@ def test_shape(self, transform, expected_shape):

     def test_set_data(self):
         data_list1 = list(range(10))
-        transform = Compose([
-            Lambda(func=lambda x: np.array([x * 10])),
-            RandLambda(func=lambda x: x + 1),
-        ])
+        transform = Compose(
+            [
+                Lambda(func=lambda x: np.array([x * 10])),
+                RandLambda(func=lambda x: x + 1),
+            ]
+        )

         dataset = CacheDataset(
             data=data_list1,