diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py
index d8fd815ce9..f0416b8c4f 100644
--- a/monai/apps/datasets.py
+++ b/monai/apps/datasets.py
@@ -94,7 +94,9 @@ def __init__(
         data = self._generate_data_list(dataset_dir)
         if transform == ():
             transform = LoadImaged("image")
-        super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)
+        CacheDataset.__init__(
+            self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers
+        )

     def randomize(self, data: Optional[Any] = None) -> None:
         self.rann = self.R.random()
@@ -275,7 +277,9 @@ def __init__(
         self._properties = load_decathlon_properties(os.path.join(dataset_dir, "dataset.json"), property_keys)
         if transform == ():
             transform = LoadImaged(["image", "label"])
-        super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)
+        CacheDataset.__init__(
+            self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers
+        )

     def get_indices(self) -> np.ndarray:
         """
diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index d2a3ca4a53..b93f03151f 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -51,14 +51,16 @@ class Dataset(_TorchDataset):
         }, }, }]
     """

-    def __init__(self, data: Sequence, transform: Optional[Callable] = None) -> None:
+    def __init__(self, data: Sequence, transform: Optional[Callable] = None, progress: bool = True) -> None:
         """
         Args:
             data: input data to load and transform to generate dataset for model.
             transform: a callable data transform on input data.
+            progress: whether to display a progress bar.
         """
         self.data = data
         self.transform = transform
+        self.progress = progress

     def __len__(self) -> int:
         return len(self.data)
@@ -115,6 +117,7 @@ def __init__(
         transform: Union[Sequence[Callable], Callable],
         cache_dir: Optional[Union[Path, str]] = None,
         hash_func: Callable[..., bytes] = pickle_hashing,
+        progress: bool = True,
     ) -> None:
         """
         Args:
@@ -129,10 +132,11 @@ def __init__(
                 If the cache_dir doesn't exist, will automatically create it.
             hash_func: a callable to compute hash from data items to be cached.
                 defaults to `monai.data.utils.pickle_hashing`.
+            progress: whether to display a progress bar.
         """
         if not isinstance(transform, Compose):
             transform = Compose(transform)
-        super().__init__(data=data, transform=transform)
+        super().__init__(data=data, transform=transform, progress=progress)
         self.cache_dir = Path(cache_dir) if cache_dir is not None else None
         self.hash_func = hash_func
         if self.cache_dir is not None:
@@ -345,7 +349,7 @@ def __init__(
             lmdb_kwargs: additional keyword arguments to the lmdb environment.
                 for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class
         """
-        super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func)
+        super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func, progress=progress)
         if not self.cache_dir:
             raise ValueError("cache_dir must be specified.")
         self.db_file = self.cache_dir / f"{db_name}.lmdb"
@@ -354,14 +358,13 @@ def __init__(
         if not self.lmdb_kwargs.get("map_size", 0):
             self.lmdb_kwargs["map_size"] = 1024 ** 4  # default map_size
         self._read_env = None
-        self.progress = progress
         print(f"Accessing lmdb file: {self.db_file.absolute()}.")

     def _fill_cache_start_reader(self):
         # create cache
         self.lmdb_kwargs["readonly"] = False
         env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs)
-        if not has_tqdm:
+        if self.progress and not has_tqdm:
             warnings.warn("LMDBDataset: tqdm is not installed. not displaying the caching progress.")
         for item in tqdm(self.data) if has_tqdm and self.progress else self.data:
             key = self.hash_func(item)
@@ -470,6 +473,7 @@ def __init__(
         cache_num: int = sys.maxsize,
         cache_rate: float = 1.0,
         num_workers: Optional[int] = None,
+        progress: bool = True,
     ) -> None:
         """
         Args:
@@ -481,10 +485,11 @@ def __init__(
                 will take the minimum of (cache_num, data_length x cache_rate, data_length).
            num_workers: the number of worker processes to use.
                 If num_workers is None then the number returned by os.cpu_count() is used.
+            progress: whether to display a progress bar.
         """
         if not isinstance(transform, Compose):
             transform = Compose(transform)
-        super().__init__(data=data, transform=transform)
+        super().__init__(data=data, transform=transform, progress=progress)
         self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data))
         self.num_workers = num_workers
         if self.num_workers is not None:
@@ -494,10 +499,10 @@ def _fill_cache(self) -> List:
         if self.cache_num <= 0:
             return []
-        if not has_tqdm:
+        if self.progress and not has_tqdm:
             warnings.warn("tqdm is not installed, will not show the caching progress bar.")
         with ThreadPool(self.num_workers) as p:
-            if has_tqdm:
+            if self.progress and has_tqdm:
                 return list(
                     tqdm(
                         p.imap(self._load_cache_item, range(self.cache_num)),
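
Usage note (not part of the patch): a minimal sketch of how the `progress` flag introduced above might be exercised once this change is applied. The in-memory data list and the `Lambdad` transform are illustrative assumptions for the example, not code from this diff.

    # Illustrative only: exercises the `progress` argument added in this change.
    from monai.data import CacheDataset
    from monai.transforms import Compose, Lambdad

    # hypothetical in-memory items and a trivial transform, just for demonstration
    data = [{"image": i} for i in range(100)]
    transform = Compose([Lambdad(keys="image", func=lambda x: x + 1)])

    # progress=True (the default) keeps the previous behaviour of showing a tqdm
    # bar while the cache is filled; progress=False keeps the caching step silent,
    # e.g. in batch jobs or unit tests.
    ds = CacheDataset(data=data, transform=transform, cache_rate=1.0, progress=False)
    print(len(ds), ds[0])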