Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def __init__(
cache_dir: Union[Path, str] = "cache",
hash_func: Callable[..., bytes] = pickle_hashing,
db_name: str = "monai_cache",
progress: bool = True,
pickle_protocol=pickle.HIGHEST_PROTOCOL,
lmdb_kwargs: Optional[dict] = None,
) -> None:
Expand All @@ -338,6 +339,7 @@ def __init__(
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
db_name: lmdb database file name. Defaults to "monai_cache".
progress: whether to display a progress bar.
pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL.
https://docs.python.org/3/library/pickle.html#pickle-protocols
lmdb_kwargs: additional keyword arguments to the lmdb environment.
Expand All @@ -352,15 +354,16 @@ def __init__(
if not self.lmdb_kwargs.get("map_size", 0):
self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size
self._read_env = None
self.progress = progress
print(f"Accessing lmdb file: {self.db_file.absolute()}.")

def _fill_cache_start_reader(self):
# create cache
print(f"Accessing lmdb file: {self.db_file.absolute()}.")
self.lmdb_kwargs["readonly"] = False
env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs)
if not has_tqdm:
warnings.warn("LMDBDataset: tqdm is not installed. not displaying the caching progress.")
for item in tqdm(self.data) if has_tqdm else self.data:
for item in tqdm(self.data) if has_tqdm and self.progress else self.data:
key = self.hash_func(item)
done, retry, val = False, 5, None
while not done and retry > 0:
Expand Down
8 changes: 6 additions & 2 deletions tests/test_lmdbdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,15 @@ def test_shape(self, transform, expected_shape, kwargs=None):
]

cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data")
dataset_precached = LMDBDataset(data=test_data, transform=transform, cache_dir=cache_dir, **kwargs)
dataset_precached = LMDBDataset(
data=test_data, transform=transform, progress=False, cache_dir=cache_dir, **kwargs
)
data1_precached = dataset_precached[0]
data2_precached = dataset_precached[1]

dataset_postcached = LMDBDataset(data=test_data, transform=transform, cache_dir=cache_dir, **kwargs)
dataset_postcached = LMDBDataset(
data=test_data, transform=transform, progress=False, cache_dir=cache_dir, **kwargs
)
data1_postcached = dataset_postcached[0]
data2_postcached = dataset_postcached[1]

Expand Down