172 changes: 2 additions & 170 deletions monai/optimizers/lr_finder.py
@@ -26,105 +26,6 @@

__all__ = ["LearningRateFinder"]


class DataLoaderIter:
def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:
if not isinstance(data_loader, DataLoader):
raise ValueError(
f"Loader has unsupported type: {type(data_loader)}. Expected type was `torch.utils.data.DataLoader`"
)
self.data_loader = data_loader
self._iterator = iter(data_loader)
self.image_extractor = image_extractor
self.label_extractor = label_extractor

@property
def dataset(self):
return self.data_loader.dataset

def inputs_labels_from_batch(self, batch_data):
images = self.image_extractor(batch_data)
labels = self.label_extractor(batch_data)
return images, labels

def __iter__(self):
return self

def __next__(self):
batch = next(self._iterator)
return self.inputs_labels_from_batch(batch)


class TrainDataLoaderIter(DataLoaderIter):
def __init__(
self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable, auto_reset: bool = True
) -> None:
super().__init__(data_loader, image_extractor, label_extractor)
self.auto_reset = auto_reset

def __next__(self):
try:
batch = next(self._iterator)
inputs, labels = self.inputs_labels_from_batch(batch)
except StopIteration:
if not self.auto_reset:
raise
self._iterator = iter(self.data_loader)
batch = next(self._iterator)
inputs, labels = self.inputs_labels_from_batch(batch)

return inputs, labels


class ValDataLoaderIter(DataLoaderIter):
"""This iterator will reset itself **only** when it is acquired by
the syntax of normal `iterator`. That is, this iterator just works
like a `torch.data.DataLoader`. If you want to restart it, you
should use it like:

```
loader_iter = ValDataLoaderIter(data_loader)
for batch in loader_iter:
...

# `loader_iter` has now run out of values; you can restart it in two ways:
# 1. use it the way we use a `torch.utils.data.DataLoader`
for batch in loader_iter: # __iter__ is called implicitly
...

# 2. passing it into `iter()` manually
loader_iter = iter(loader_iter) # __iter__ is called by `iter()`
```
"""

def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:
super().__init__(data_loader, image_extractor, label_extractor)
self.run_limit = len(self.data_loader)
self.run_counter = 0

def __iter__(self):
if self.run_counter >= self.run_limit:
self._iterator = iter(self.data_loader)
self.run_counter = 0
return self

def __next__(self):
self.run_counter += 1
return super(ValDataLoaderIter, self).__next__()


def default_image_extractor(x: Any) -> torch.Tensor:
"""Default callable for getting image from batch data."""
out: torch.Tensor = x["image"] if isinstance(x, dict) else x[0]
return out


def default_label_extractor(x: Any) -> torch.Tensor:
"""Default callable for getting label from batch data."""
out: torch.Tensor = x["label"] if isinstance(x, dict) else x[1]
return out
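For context, the two default extractors above simply cover the two batch layouts MONAI loaders commonly produce: a dict with `image`/`label` keys, or an `(image, label)` sequence. A minimal sketch of their behaviour (the tensors are placeholders, not from the PR):

```python
import torch

dict_batch = {"image": torch.rand(2, 1, 32, 32), "label": torch.randint(0, 2, (2,))}
seq_batch = (torch.rand(2, 1, 32, 32), torch.randint(0, 2, (2,)))

# dict-style batches are indexed by key, sequence-style batches by position
assert default_image_extractor(dict_batch).shape == (2, 1, 32, 32)
assert default_label_extractor(seq_batch).shape == (2,)
```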


class LearningRateFinder:
"""Learning rate range test.

@@ -168,7 +69,6 @@ def __init__(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: torch.nn.Module,
device: Optional[Union[str, torch.device]] = None,
memory_cache: bool = True,
cache_dir: Optional[str] = None,
@@ -201,7 +101,6 @@ def __init__(
self._check_for_scheduler()

self.model = model
self.criterion = criterion
self.history: Dict[str, list] = {"lr": [], "loss": []}
self.memory_cache = memory_cache
self.cache_dir = cache_dir
@@ -227,18 +126,13 @@ def reset(self) -> None:

def range_test(
self,
train_loader: DataLoader,
val_loader: Optional[DataLoader] = None,
image_extractor: Callable = default_image_extractor,
label_extractor: Callable = default_label_extractor,
train_valid_loss_iter: Callable,
start_lr: Optional[float] = None,
end_lr: int = 10,
num_iter: int = 100,
step_mode: str = "exp",
smooth_f: float = 0.05,
diverge_th: int = 5,
accumulation_steps: int = 1,
non_blocking_transfer: bool = True,
auto_reset: bool = True,
) -> None:
"""Performs the learning rate range test.
@@ -300,11 +194,6 @@ def range_test(
if smooth_f < 0 or smooth_f >= 1:
raise ValueError("smooth_f is outside the range [0, 1[")

# Create an iterator to get data batch by batch
train_iter = TrainDataLoaderIter(train_loader, image_extractor, label_extractor)
if val_loader:
val_iter = ValDataLoaderIter(val_loader, image_extractor, label_extractor)

trange: Union[partial[tqdm.trange], Type[range]]
if self.verbose and has_tqdm:
trange = partial(tqdm.trange, desc="Computing optimal learning rate")
@@ -317,14 +206,7 @@
if self.verbose and not has_tqdm:
print(f"Computing optimal learning rate, iteration {iteration + 1}/{num_iter}")

# Train on batch and retrieve loss
loss = self._train_batch(
train_iter,
accumulation_steps,
non_blocking_transfer=non_blocking_transfer,
)
if val_loader:
loss = self._validate(val_iter, non_blocking_transfer=non_blocking_transfer)
loss = train_valid_loss_iter()

# Update the learning rate
self.history["lr"].append(lr_schedule.get_lr()[0])
@@ -369,56 +251,6 @@ def _check_for_scheduler(self) -> _none_or_positive_arg:
if "initial_lr" in param_group:
raise RuntimeError("Optimizer already has a scheduler attached to it")

def _train_batch(self, train_iter, accumulation_steps: int, non_blocking_transfer: bool = True) -> float:
self.model.train()
total_loss = 0

self.optimizer.zero_grad()
for i in range(accumulation_steps):
inputs, labels = next(train_iter)
inputs, labels = copy_to_device([inputs, labels], device=self.device, non_blocking=non_blocking_transfer)

# Forward pass
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)

# Loss should be averaged in each step
loss /= accumulation_steps

# Backward pass
if self.amp and hasattr(self.optimizer, "_amp_stash"):
# For minor performance optimization, see also:
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
delay_unscale = ((i + 1) % accumulation_steps) != 0

with torch.cuda.amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss: # type: ignore
scaled_loss.backward()
else:
loss.backward()

total_loss += loss.item()

self.optimizer.step()

return total_loss
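As the comment in the removed `_train_batch` notes, each micro-batch loss is divided by `accumulation_steps` so that the accumulated gradient equals the gradient of the average loss over the accumulated batches. A tiny numeric check of that equivalence (not part of the PR):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
micro_losses = [w * x for x in (1.0, 2.0, 3.0, 4.0)]  # four "micro-batch" losses

for loss in micro_losses:
    (loss / 4).backward(retain_graph=True)  # accumulate grads of loss / accumulation_steps
grad_accumulated = w.grad.clone()

w.grad = None
(sum(micro_losses) / 4).backward()  # single backward pass on the mean loss
assert torch.allclose(grad_accumulated, w.grad)
```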

def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer: bool = True) -> float:
# Set model to evaluation mode and disable gradient computation
running_loss = 0
with eval_mode(self.model):
for inputs, labels in val_iter:
# Copy data to the correct device
inputs, labels = copy_to_device(
[inputs, labels], device=self.device, non_blocking=non_blocking_transfer
)

# Forward pass and loss computation
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
running_loss += loss.item() * len(labels)

return running_loss / len(val_iter.dataset)

def get_lrs_and_losses(
self,
skip_start: int = 0,