diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py
index 6ad4132dd0..4f71c940bb 100644
--- a/monai/optimizers/lr_finder.py
+++ b/monai/optimizers/lr_finder.py
@@ -26,105 +26,6 @@ __all__ = ["LearningRateFinder"]
 
-
-class DataLoaderIter:
-    def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:
-        if not isinstance(data_loader, DataLoader):
-            raise ValueError(
-                f"Loader has unsupported type: {type(data_loader)}. Expected type was `torch.utils.data.DataLoader`"
-            )
-        self.data_loader = data_loader
-        self._iterator = iter(data_loader)
-        self.image_extractor = image_extractor
-        self.label_extractor = label_extractor
-
-    @property
-    def dataset(self):
-        return self.data_loader.dataset
-
-    def inputs_labels_from_batch(self, batch_data):
-        images = self.image_extractor(batch_data)
-        labels = self.label_extractor(batch_data)
-        return images, labels
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        batch = next(self._iterator)
-        return self.inputs_labels_from_batch(batch)
-
-
-class TrainDataLoaderIter(DataLoaderIter):
-    def __init__(
-        self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable, auto_reset: bool = True
-    ) -> None:
-        super().__init__(data_loader, image_extractor, label_extractor)
-        self.auto_reset = auto_reset
-
-    def __next__(self):
-        try:
-            batch = next(self._iterator)
-            inputs, labels = self.inputs_labels_from_batch(batch)
-        except StopIteration:
-            if not self.auto_reset:
-                raise
-            self._iterator = iter(self.data_loader)
-            batch = next(self._iterator)
-            inputs, labels = self.inputs_labels_from_batch(batch)
-
-        return inputs, labels
-
-
-class ValDataLoaderIter(DataLoaderIter):
-    """This iterator will reset itself **only** when it is acquired by
-    the syntax of normal `iterator`. That is, this iterator just works
-    like a `torch.data.DataLoader`. If you want to restart it, you
-    should use it like:
-
-    ```
-    loader_iter = ValDataLoaderIter(data_loader)
-    for batch in loader_iter:
-        ...
-
-    # `loader_iter` should run out of values now, you can restart it by:
-    # 1. the way we use a `torch.data.DataLoader`
-    for batch in loader_iter:  # __iter__ is called implicitly
-        ...
-
-    # 2. passing it into `iter()` manually
-    loader_iter = iter(loader_iter)  # __iter__ is called by `iter()`
-    ```
-    """
-
-    def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:
-        super().__init__(data_loader, image_extractor, label_extractor)
-        self.run_limit = len(self.data_loader)
-        self.run_counter = 0
-
-    def __iter__(self):
-        if self.run_counter >= self.run_limit:
-            self._iterator = iter(self.data_loader)
-            self.run_counter = 0
-        return self
-
-    def __next__(self):
-        self.run_counter += 1
-        return super(ValDataLoaderIter, self).__next__()
-
-
-def default_image_extractor(x: Any) -> torch.Tensor:
-    """Default callable for getting image from batch data."""
-    out: torch.Tensor = x["image"] if isinstance(x, dict) else x[0]
-    return out
-
-
-def default_label_extractor(x: Any) -> torch.Tensor:
-    """Default callable for getting label from batch data."""
-    out: torch.Tensor = x["label"] if isinstance(x, dict) else x[1]
-    return out
-
-
 class LearningRateFinder:
     """Learning rate range test.
@@ -168,7 +69,6 @@ def __init__(
         self,
         model: nn.Module,
         optimizer: Optimizer,
-        criterion: torch.nn.Module,
         device: Optional[Union[str, torch.device]] = None,
         memory_cache: bool = True,
         cache_dir: Optional[str] = None,
@@ -201,7 +101,6 @@ def __init__(
         self._check_for_scheduler()
 
         self.model = model
-        self.criterion = criterion
         self.history: Dict[str, list] = {"lr": [], "loss": []}
         self.memory_cache = memory_cache
         self.cache_dir = cache_dir
@@ -227,18 +126,13 @@ def reset(self) -> None:
 
     def range_test(
         self,
-        train_loader: DataLoader,
-        val_loader: Optional[DataLoader] = None,
-        image_extractor: Callable = default_image_extractor,
-        label_extractor: Callable = default_label_extractor,
+        train_valid_loss_iter: Callable,
         start_lr: Optional[float] = None,
         end_lr: int = 10,
         num_iter: int = 100,
         step_mode: str = "exp",
         smooth_f: float = 0.05,
         diverge_th: int = 5,
-        accumulation_steps: int = 1,
-        non_blocking_transfer: bool = True,
         auto_reset: bool = True,
     ) -> None:
         """Performs the learning rate range test.
@@ -300,11 +194,6 @@ def range_test(
         if smooth_f < 0 or smooth_f >= 1:
             raise ValueError("smooth_f is outside the range [0, 1[")
 
-        # Create an iterator to get data batch by batch
-        train_iter = TrainDataLoaderIter(train_loader, image_extractor, label_extractor)
-        if val_loader:
-            val_iter = ValDataLoaderIter(val_loader, image_extractor, label_extractor)
-
         trange: Union[partial[tqdm.trange], Type[range]]
         if self.verbose and has_tqdm:
             trange = partial(tqdm.trange, desc="Computing optimal learning rate")
@@ -317,14 +206,7 @@ def range_test(
             if self.verbose and not has_tqdm:
                 print(f"Computing optimal learning rate, iteration {iteration + 1}/{num_iter}")
 
-            # Train on batch and retrieve loss
-            loss = self._train_batch(
-                train_iter,
-                accumulation_steps,
-                non_blocking_transfer=non_blocking_transfer,
-            )
-            if val_loader:
-                loss = self._validate(val_iter, non_blocking_transfer=non_blocking_transfer)
+            loss = train_valid_loss_iter()
 
             # Update the learning rate
             self.history["lr"].append(lr_schedule.get_lr()[0])
@@ -369,56 +251,6 @@ def _check_for_scheduler(self) -> _none_or_positive_arg:
             if "initial_lr" in param_group:
                 raise RuntimeError("Optimizer already has a scheduler attached to it")
 
-    def _train_batch(self, train_iter, accumulation_steps: int, non_blocking_transfer: bool = True) -> float:
-        self.model.train()
-        total_loss = 0
-
-        self.optimizer.zero_grad()
-        for i in range(accumulation_steps):
-            inputs, labels = next(train_iter)
-            inputs, labels = copy_to_device([inputs, labels], device=self.device, non_blocking=non_blocking_transfer)
-
-            # Forward pass
-            outputs = self.model(inputs)
-            loss = self.criterion(outputs, labels)
-
-            # Loss should be averaged in each step
-            loss /= accumulation_steps
-
-            # Backward pass
-            if self.amp and hasattr(self.optimizer, "_amp_stash"):
-                # For minor performance optimization, see also:
-                # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
-                delay_unscale = ((i + 1) % accumulation_steps) != 0
-
-                with torch.cuda.amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:  # type: ignore
-                    scaled_loss.backward()
-            else:
-                loss.backward()
-
-            total_loss += loss.item()
-
-        self.optimizer.step()
-
-        return total_loss
-
-    def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer: bool = True) -> float:
-        # Set model to evaluation mode and disable gradient computation
-        running_loss = 0
-        with eval_mode(self.model):
-            for inputs, labels in val_iter:
-                # Copy data to the correct device
-                inputs, labels = copy_to_device(
-                    [inputs, labels], device=self.device, non_blocking=non_blocking_transfer
-                )
-
-                # Forward pass and loss computation
-                outputs = self.model(inputs)
-                loss = self.criterion(outputs, labels)
-                running_loss += loss.item() * len(labels)
-
-        return running_loss / len(val_iter.dataset)
-
     def get_lrs_and_losses(
         self,
         skip_start: int = 0,
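With `criterion` and the data-loader plumbing removed above, `range_test` now only requires a zero-argument callable that performs one optimization step and returns the loss as a float (the validation loss, if the callable also runs a validation pass). A minimal sketch of that contract; `make_loss_iter` and its arguments are illustrative stand-ins, not part of this PR:

```python
def make_loss_iter(model, optimizer, criterion, data_iter, device="cpu"):
    """Build a zero-argument callable satisfying the `train_valid_loss_iter` contract."""

    def loss_iter() -> float:
        # One training step: fetch a batch, forward, backward, update.
        model.train()
        optimizer.zero_grad()
        inputs, labels = next(data_iter)  # assumed to yield an (inputs, labels) pair
        inputs, labels = inputs.to(device), labels.to(device)
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        return loss.item()

    return loss_iter
```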
diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py
index 9ee9c8a4d0..1b39f91190 100644
--- a/tests/test_lr_finder.py
+++ b/tests/test_lr_finder.py
@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import partial
 import os
 import random
 import sys
@@ -19,9 +20,12 @@
 from monai.apps import MedNISTDataset
 from monai.networks.nets import DenseNet
+from monai.networks.utils import eval_mode
 from monai.optimizers import LearningRateFinder
 from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
-from monai.utils import optional_import, set_determinism
+from monai.utils import copy_to_device, optional_import, set_determinism
+
+from typing import Any, Callable
 
 PILImage, has_pil = optional_import("PIL.Image")
@@ -32,6 +36,165 @@
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
+class DataLoaderIter:
+    def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:
+        if not isinstance(data_loader, DataLoader):
+            raise ValueError(
+                f"Loader has unsupported type: {type(data_loader)}. Expected type was `torch.utils.data.DataLoader`"
+            )
+        self.data_loader = data_loader
+        self._iterator = iter(data_loader)
+        self.image_extractor = image_extractor
+        self.label_extractor = label_extractor
+
+    @property
+    def dataset(self):
+        return self.data_loader.dataset
+
+    def inputs_labels_from_batch(self, batch_data):
+        images = self.image_extractor(batch_data)
+        labels = self.label_extractor(batch_data)
+        return images, labels
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        batch = next(self._iterator)
+        return self.inputs_labels_from_batch(batch)
+
+
+class TrainDataLoaderIter(DataLoaderIter):
+    def __init__(
+        self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable, auto_reset: bool = True
+    ) -> None:
+        super().__init__(data_loader, image_extractor, label_extractor)
+        self.auto_reset = auto_reset
+
+    def __next__(self):
+        try:
+            batch = next(self._iterator)
+            inputs, labels = self.inputs_labels_from_batch(batch)
+        except StopIteration:
+            if not self.auto_reset:
+                raise
+            self._iterator = iter(self.data_loader)
+            batch = next(self._iterator)
+            inputs, labels = self.inputs_labels_from_batch(batch)
+
+        return inputs, labels
+
+
+class ValDataLoaderIter(DataLoaderIter):
+    """This iterator resets itself **only** when acquired via the normal
+    iterator protocol; that is, it behaves like a `torch.utils.data.DataLoader`.
+    If you want to restart it, use it like:
+
+    ```
+    loader_iter = ValDataLoaderIter(data_loader)
+    for batch in loader_iter:
+        ...
+
+    # `loader_iter` should run out of values now; you can restart it by:
+    # 1. using it the way we use a `torch.utils.data.DataLoader`
+    for batch in loader_iter:  # __iter__ is called implicitly
+        ...
+
+    # 2. passing it into `iter()` manually
+    loader_iter = iter(loader_iter)  # __iter__ is called by `iter()`
+    ```
+    """
+
+    def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:
+        super().__init__(data_loader, image_extractor, label_extractor)
+        self.run_limit = len(self.data_loader)
+        self.run_counter = 0
+
+    def __iter__(self):
+        if self.run_counter >= self.run_limit:
+            self._iterator = iter(self.data_loader)
+            self.run_counter = 0
+        return self
+
+    def __next__(self):
+        self.run_counter += 1
+        return super(ValDataLoaderIter, self).__next__()
+
+
+def default_image_extractor(x: Any) -> torch.Tensor:
+    """Default callable for getting image from batch data."""
+    out: torch.Tensor = x["image"] if isinstance(x, dict) else x[0]
+    return out
+
+
+def default_label_extractor(x: Any) -> torch.Tensor:
+    """Default callable for getting label from batch data."""
+    out: torch.Tensor = x["label"] if isinstance(x, dict) else x[1]
+    return out
+
+
+def train_valid_loss(
+    train_iter,
+    val_iter,
+    model,
+    optimizer,
+    criterion,
+    accumulation_steps,
+    device,
+    non_blocking_transfer,
+    amp,
+):
+    """Run one (optionally accumulated) training step and return the training
+    loss; if `val_iter` is given, return the validation loss instead."""
+    model.train()
+    total_loss = 0
+
+    optimizer.zero_grad()
+    for i in range(accumulation_steps):
+        inputs, labels = next(train_iter)
+        inputs, labels = copy_to_device([inputs, labels], device=device, non_blocking=non_blocking_transfer)
+
+        # Forward pass
+        outputs = model(inputs)
+        loss = criterion(outputs, labels)
+
+        # Loss should be averaged in each step
+        loss /= accumulation_steps
+
+        # Backward pass
+        if amp and hasattr(optimizer, "_amp_stash"):
+            # For minor performance optimization, see also:
+            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
+            delay_unscale = ((i + 1) % accumulation_steps) != 0
+
+            with torch.cuda.amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale) as scaled_loss:  # type: ignore
+                scaled_loss.backward()
+        else:
+            loss.backward()
+
+        total_loss += loss.item()
+
+    optimizer.step()
+
+    if not val_iter:
+        return total_loss
+
+    # Set model to evaluation mode and disable gradient computation
+    running_loss = 0
+    with eval_mode(model):
+        for inputs, labels in val_iter:
+            # Copy data to the correct device
+            inputs, labels = copy_to_device(
+                [inputs, labels], device=device, non_blocking=non_blocking_transfer
+            )
+
+            # Forward pass and loss computation
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+            running_loss += loss.item() * len(labels)
+
+    return running_loss / len(val_iter.dataset)
+
+
 @unittest.skipUnless(sys.platform == "linux", "requires linux")
 @unittest.skipUnless(has_pil, "requires PIL")
 class TestLRFinder(unittest.TestCase):
@@ -70,12 +233,31 @@ def test_lr_finder(self):
         learning_rate = 1e-5
         optimizer = torch.optim.Adam(model.parameters(), learning_rate)
 
-        lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device)
-        lr_finder.range_test(train_loader, val_loader=train_loader, end_lr=10, num_iter=5)
+        train_iter = TrainDataLoaderIter(train_loader, default_image_extractor, default_label_extractor)
+        val_iter = ValDataLoaderIter(train_loader, default_image_extractor, default_label_extractor)
+
+        train_valid_loss_iter = partial(
+            train_valid_loss,
+            train_iter,
+            val_iter,
+            model=model,
+            optimizer=optimizer,
+            criterion=loss_function,
+            accumulation_steps=1,
+            device=device,
+            non_blocking_transfer=True,
+            amp=False,
+        )
+
+        lr_finder = LearningRateFinder(model, optimizer, device=device)
+        lr_finder.range_test(train_valid_loss_iter=train_valid_loss_iter, end_lr=10, num_iter=5)
         print(lr_finder.get_steepest_gradient(0, 0)[0])
         lr_finder.plot(0, 0)  # to inspect the loss-learning rate graph
         lr_finder.reset()  # to reset the model and optimizer to their initial state
 
 
 if __name__ == "__main__":
     unittest.main()
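For reference, a hedged end-to-end sketch of the new interface on a toy problem (the model, data, and `loss_iter` below are illustrative assumptions, not code from this PR):

```python
import torch
import torch.nn as nn

from monai.optimizers import LearningRateFinder

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
inputs, labels = torch.randn(8, 10), torch.randint(0, 2, (8,))


def loss_iter() -> float:
    # One optimization step on the toy batch; the finder drives the LR schedule.
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()
    return loss.item()


lr_finder = LearningRateFinder(model, optimizer)
lr_finder.range_test(train_valid_loss_iter=loss_iter, end_lr=10, num_iter=5)
print(lr_finder.get_steepest_gradient(0, 0)[0])
```

Keeping the loss computation outside the finder means gradient accumulation, AMP, and validation policy now live entirely in user code, which is the point of this refactor.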