diff --git a/monai/data/dataset.py b/monai/data/dataset.py index c970e83d0d..2549041fbd 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -575,6 +575,7 @@ def __init__( cache_rate: float = 1.0, num_workers: Optional[int] = None, progress: bool = True, + copy_cache: bool = True, ) -> None: """ Args: @@ -587,11 +588,16 @@ def __init__( num_workers: the number of worker processes to use. If num_workers is None then the number returned by os.cpu_count() is used. progress: whether to display a progress bar. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + defaults to `True`. If the random transforms don't modify the cache content + or every cache item is only used once in a `multi-processing` environment, + may set `copy_cache=False` for better performance. """ if not isinstance(transform, Compose): transform = Compose(transform) super().__init__(data=data, transform=transform) self.progress = progress + self.copy_cache = copy_cache self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data)) self.num_workers = num_workers if self.num_workers is not None: @@ -656,7 +662,8 @@ def _transform(self, index: int): # only need to deep copy data on first non-deterministic transform if not start_run: start_run = True - data = deepcopy(data) + if self.copy_cache: + data = deepcopy(data) data = apply_transform(_transform, data) return data @@ -722,6 +729,10 @@ class SmartCacheDataset(Randomizable, CacheDataset): shuffle: whether to shuffle the whole data list before preparing the cache content for first epoch. it will not modify the original input data sequence in-place. seed: random seed if shuffle is `True`, default to `0`. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + defaults to `True`. 
If the random transforms don't modify the cache content + or every cache item is only used once in a `multi-processing` environment, + may set `copy_cache=False` for better performance. """ def __init__( @@ -736,6 +747,7 @@ def __init__( progress: bool = True, shuffle: bool = True, seed: int = 0, + copy_cache: bool = True, ) -> None: if shuffle: self.set_random_state(seed=seed) @@ -743,7 +755,7 @@ def __init__( self.randomize(data) self.shuffle = shuffle - super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress) + super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache) if self._cache is None: self._cache = self._fill_cache() if self.cache_num >= len(data): diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index bbb8143631..e5bb1b9a90 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -19,7 +19,7 @@ from parameterized import parameterized from monai.data import CacheDataset, DataLoader, PersistentDataset, SmartCacheDataset -from monai.transforms import Compose, Lambda, LoadImaged, ThreadUnsafe, Transform +from monai.transforms import Compose, Lambda, LoadImaged, RandLambda, ThreadUnsafe, Transform from monai.utils import get_torch_version_tuple TEST_CASE_1 = [Compose([LoadImaged(keys=["image", "label", "extra"])]), (128, 128, 128)] @@ -84,7 +84,12 @@ def test_shape(self, transform, expected_shape): def test_set_data(self): data_list1 = list(range(10)) - transform = Lambda(func=lambda x: np.array([x * 10])) + transform = Compose( + [ + Lambda(func=lambda x: np.array([x * 10])), + RandLambda(func=lambda x: x + 1), + ] + ) dataset = CacheDataset( data=data_list1, @@ -92,19 +97,23 @@ def test_set_data(self): cache_rate=1.0, num_workers=4, progress=True, + copy_cache=False if sys.platform == "linux" else True, ) num_workers = 2 if sys.platform == "linux" else 0 dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1) for i, d in 
enumerate(dataloader): - np.testing.assert_allclose([[data_list1[i] * 10]], d) + np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d) + # simulate another epoch, the cache content should not be modified + for i, d in enumerate(dataloader): + np.testing.assert_allclose([[data_list1[i] * 10 + 1]], d) # update the datalist and fill the cache content data_list2 = list(range(-10, 0)) dataset.set_data(data=data_list2) # rerun with updated cache content for i, d in enumerate(dataloader): - np.testing.assert_allclose([[data_list2[i] * 10]], d) + np.testing.assert_allclose([[data_list2[i] * 10 + 1]], d) class _StatefulTransform(Transform, ThreadUnsafe):