Merged
8 changes: 6 additions & 2 deletions monai/apps/datasets.py
@@ -94,7 +94,9 @@ def __init__(
data = self._generate_data_list(dataset_dir)
if transform == ():
transform = LoadImaged("image")
-super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)
+CacheDataset.__init__(
+    self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers
+)

def randomize(self, data: Optional[Any] = None) -> None:
self.rann = self.R.random()
@@ -275,7 +277,9 @@ def __init__(
self._properties = load_decathlon_properties(os.path.join(dataset_dir, "dataset.json"), property_keys)
if transform == ():
transform = LoadImaged(["image", "label"])
-super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)
+CacheDataset.__init__(
+    self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers
+)

def get_indices(self) -> np.ndarray:
"""
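For context, a minimal sketch of the pattern the two hunks above switch to — calling a specific base class's __init__ directly instead of going through super(). All names below are made up for illustration; this is not code from MONAI:

```python
class Loader:
    def __init__(self, data, transform=None):
        self.data, self.transform = data, transform


class RandomizableMixin:
    """Mixin that deliberately defines no __init__ of its own."""


class DemoDataset(RandomizableMixin, Loader):
    def __init__(self, data, transform=None):
        # call the intended base initializer explicitly rather than relying on
        # super(), so the target of the call does not depend on the MRO
        Loader.__init__(self, data, transform)
```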
21 changes: 13 additions & 8 deletions monai/data/dataset.py
@@ -51,14 +51,16 @@ class Dataset(_TorchDataset):
}, }, }]
"""

-def __init__(self, data: Sequence, transform: Optional[Callable] = None) -> None:
+def __init__(self, data: Sequence, transform: Optional[Callable] = None, progress: bool = True) -> None:
wyli (Contributor) commented on Feb 26, 2021:
perhaps the progress is not a fundamental property of a Dataset, so it shouldn't be in this base class, what do you think? for a user of the Dataset API, it's difficult to understand this option without looking at the cache dataset and some of the concrete implementations... (progress is not needed in all subclasses)

@Nic-Ma @rijobro sorry I should have had this discussion earlier during my review, somehow I missed it...

Contributor (PR author) replied:
I think you're right, sorry about that. A couple of choices:

  1. New middleman class. DatasetWithProgress. PersistentDataset and CacheDataset would inherit from it.
  2. Update current documentation for base class: progress: whether to display a progress bar **(if relevant)**.
  3. Revert this PR.

Contributor replied:
thanks, I vote for option 1... @Nic-Ma thoughts?

Another Contributor replied:
I agree with you guys, progress should not be in the base class.

Thanks.
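For reference, option 1 from the list above might look roughly like the following. This is a hypothetical sketch of the proposal, not code from this PR; only the name DatasetWithProgress comes from the comment, everything else is assumed:

```python
from typing import Callable, Optional, Sequence

from monai.data import Dataset


class DatasetWithProgress(Dataset):
    """Hypothetical middleman class: only datasets that actually display a
    progress bar (e.g. PersistentDataset, CacheDataset) would inherit from it,
    so the plain Dataset base class keeps its original signature."""

    def __init__(self, data: Sequence, transform: Optional[Callable] = None, progress: bool = True) -> None:
        super().__init__(data=data, transform=transform)
        self.progress = progress
```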

"""
Args:
data: input data to load and transform to generate dataset for model.
transform: a callable data transform on input data.
+progress: whether to display a progress bar.
"""
self.data = data
self.transform = transform
+self.progress = progress

def __len__(self) -> int:
return len(self.data)
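For context, a minimal usage sketch of the base-class API documented above (illustrative values, not from the PR):

```python
from monai.data import Dataset

# any sequence plus a callable transform works for the base class
ds = Dataset(data=[1, 2, 3], transform=lambda x: x * 2)
print(len(ds))  # 3
print(ds[0])    # 2 -- the transform is applied when an item is indexed
```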
@@ -115,6 +117,7 @@ def __init__(
transform: Union[Sequence[Callable], Callable],
cache_dir: Optional[Union[Path, str]] = None,
hash_func: Callable[..., bytes] = pickle_hashing,
+progress: bool = True,
) -> None:
"""
Args:
@@ -129,10 +132,11 @@
If the cache_dir doesn't exist, will automatically create it.
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
+progress: whether to display a progress bar.
"""
if not isinstance(transform, Compose):
transform = Compose(transform)
-super().__init__(data=data, transform=transform)
+super().__init__(data=data, transform=transform, progress=progress)
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
self.hash_func = hash_func
if self.cache_dir is not None:
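A hypothetical usage sketch for the parameters documented above; the data, transform, and cache directory are placeholders:

```python
import tempfile

from monai.data import PersistentDataset

data = [{"value": i} for i in range(10)]  # placeholder items
cache_dir = tempfile.mkdtemp()            # intermediate transform results are persisted here

# hash_func defaults to pickle_hashing; progress is the option threaded through by this PR
ds = PersistentDataset(data=data, transform=lambda x: x, cache_dir=cache_dir, progress=False)
item = ds[0]  # cached to disk on first access, re-used afterwards
```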
@@ -345,7 +349,7 @@ def __init__(
lmdb_kwargs: additional keyword arguments to the lmdb environment.
for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class
"""
-super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func)
+super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func, progress=progress)
if not self.cache_dir:
raise ValueError("cache_dir must be specified.")
self.db_file = self.cache_dir / f"{db_name}.lmdb"
@@ -354,14 +358,13 @@ def __init__(
if not self.lmdb_kwargs.get("map_size", 0):
self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size
self._read_env = None
-self.progress = progress
print(f"Accessing lmdb file: {self.db_file.absolute()}.")

def _fill_cache_start_reader(self):
# create cache
self.lmdb_kwargs["readonly"] = False
env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs)
-if not has_tqdm:
+if self.progress and not has_tqdm:
warnings.warn("LMDBDataset: tqdm is not installed. not displaying the caching progress.")
for item in tqdm(self.data) if has_tqdm and self.progress else self.data:
key = self.hash_func(item)
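Likewise, a hypothetical usage sketch for the LMDB-backed cache shown above (placeholder data; the lmdb_kwargs value is illustrative only):

```python
import tempfile

from monai.data import LMDBDataset

data = [{"value": i} for i in range(10)]  # placeholder items

ds = LMDBDataset(
    data=data,
    transform=lambda x: x,
    cache_dir=tempfile.mkdtemp(),              # the <db_name>.lmdb file is created here
    progress=False,                            # no tqdm bar while the cache is filled
    lmdb_kwargs={"map_size": 10 * 1024 ** 2},  # override the 1024 ** 4 byte default map_size
)
item = ds[0]
```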
@@ -470,6 +473,7 @@ def __init__(
cache_num: int = sys.maxsize,
cache_rate: float = 1.0,
num_workers: Optional[int] = None,
+progress: bool = True,
) -> None:
"""
Args:
@@ -481,10 +485,11 @@
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_workers: the number of worker processes to use.
If num_workers is None then the number returned by os.cpu_count() is used.
+progress: whether to display a progress bar.
"""
if not isinstance(transform, Compose):
transform = Compose(transform)
-super().__init__(data=data, transform=transform)
+super().__init__(data=data, transform=transform, progress=progress)
self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data))
self.num_workers = num_workers
if self.num_workers is not None:
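A worked example of the cache_num computation above (numbers are illustrative):

```python
import sys

data_length = 1000
cache_num, cache_rate = sys.maxsize, 0.3

# "will take the minimum of (cache_num, data_length x cache_rate, data_length)"
effective_cache_num = min(int(cache_num), int(data_length * cache_rate), data_length)
print(effective_cache_num)  # 300 -- only the first 300 items are cached
```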
@@ -494,10 +499,10 @@
def _fill_cache(self) -> List:
if self.cache_num <= 0:
return []
-if not has_tqdm:
+if self.progress and not has_tqdm:
warnings.warn("tqdm is not installed, will not show the caching progress bar.")
with ThreadPool(self.num_workers) as p:
-if has_tqdm:
+if self.progress and has_tqdm:
return list(
tqdm(
p.imap(self._load_cache_item, range(self.cache_num)),
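Finally, a standalone sketch of the progress-reporting pattern used in _fill_cache above — an optional tqdm bar wrapped around a thread-pool imap. The function below is illustrative only; it mirrors the shape of the diff, not its exact code:

```python
import warnings
from multiprocessing.pool import ThreadPool

try:
    from tqdm import tqdm

    has_tqdm = True
except ImportError:
    has_tqdm = False


def fill_cache(load_item, num_items, num_workers=4, progress=True):
    """Load num_items items in a thread pool, optionally showing a tqdm bar."""
    if progress and not has_tqdm:
        warnings.warn("tqdm is not installed, will not show the caching progress bar.")
    with ThreadPool(num_workers) as p:
        if progress and has_tqdm:
            return list(tqdm(p.imap(load_item, range(num_items)), total=num_items, desc="Loading dataset"))
        return list(p.imap(load_item, range(num_items)))
```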