From a82d9f597a39dbe37cc84f747a3fafb4a94946c8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 11 Dec 2020 11:30:00 +0000 Subject: [PATCH] fixes #1353 Signed-off-by: Wenqi Li --- monai/data/dataset.py | 7 +++++-- tests/test_lmdbdataset.py | 8 ++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index bf2d22f838..892546b2a4 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -321,6 +321,7 @@ def __init__( cache_dir: Union[Path, str] = "cache", hash_func: Callable[..., bytes] = pickle_hashing, db_name: str = "monai_cache", + progress: bool = True, pickle_protocol=pickle.HIGHEST_PROTOCOL, lmdb_kwargs: Optional[dict] = None, ) -> None: @@ -338,6 +339,7 @@ def __init__( hash_func: a callable to compute hash from data items to be cached. defaults to `monai.data.utils.pickle_hashing`. db_name: lmdb database file name. Defaults to "monai_cache". + progress: whether to display a progress bar. pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL. https://docs.python.org/3/library/pickle.html#pickle-protocols lmdb_kwargs: additional keyword arguments to the lmdb environment. @@ -352,15 +354,16 @@ def __init__( if not self.lmdb_kwargs.get("map_size", 0): self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size self._read_env = None + self.progress = progress + print(f"Accessing lmdb file: {self.db_file.absolute()}.") def _fill_cache_start_reader(self): # create cache - print(f"Accessing lmdb file: {self.db_file.absolute()}.") self.lmdb_kwargs["readonly"] = False env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs) if not has_tqdm: warnings.warn("LMDBDataset: tqdm is not installed. not displaying the caching progress.") - for item in tqdm(self.data) if has_tqdm else self.data: + for item in tqdm(self.data) if has_tqdm and self.progress else self.data: key = self.hash_func(item) done, retry, val = False, 5, None while not done and retry > 0: diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index b867e31e20..e4d79ad4bd 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -144,11 +144,15 @@ def test_shape(self, transform, expected_shape, kwargs=None): ] cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") - dataset_precached = LMDBDataset(data=test_data, transform=transform, cache_dir=cache_dir, **kwargs) + dataset_precached = LMDBDataset( + data=test_data, transform=transform, progress=False, cache_dir=cache_dir, **kwargs + ) data1_precached = dataset_precached[0] data2_precached = dataset_precached[1] - dataset_postcached = LMDBDataset(data=test_data, transform=transform, cache_dir=cache_dir, **kwargs) + dataset_postcached = LMDBDataset( + data=test_data, transform=transform, progress=False, cache_dir=cache_dir, **kwargs + ) data1_postcached = dataset_postcached[0] data2_postcached = dataset_postcached[1]