From 6f139b2e6c0aa3ee4a504cff20f4f500c1d73faa Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 15 Jan 2021 17:57:17 +0000 Subject: [PATCH 01/25] learning rate finder Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/__init__.py | 1 + monai/optimizers/lr_finder.py | 639 ++++++++++++++++++++++++++++++++++ tests/test_lr_finder.py | 116 ++++++ 3 files changed, 756 insertions(+) create mode 100644 monai/optimizers/lr_finder.py create mode 100644 tests/test_lr_finder.py diff --git a/monai/optimizers/__init__.py b/monai/optimizers/__init__.py index 850627d588..44883de21b 100644 --- a/monai/optimizers/__init__.py +++ b/monai/optimizers/__init__.py @@ -9,5 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .lr_finder import LRFinder from .novograd import Novograd from .utils import generate_param_groups diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py new file mode 100644 index 0000000000..7291023437 --- /dev/null +++ b/monai/optimizers/lr_finder.py @@ -0,0 +1,639 @@ +import copy +import os +import torch +import numpy as np +from torch.optim.lr_scheduler import _LRScheduler +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from functools import partial + + +try: + from tqdm import trange + + trange = partial(trange, desc="Computing optimal learning rate") +except (ImportError, AttributeError): + trange = range + +__all__ = ["LRFinder"] + + +class DataLoaderIter(object): + def __init__(self, data_loader): + self.data_loader = data_loader + self._iterator = iter(data_loader) + + @property + def dataset(self): + return self.data_loader.dataset + + def inputs_labels_from_batch(self, batch_data): + if not isinstance(batch_data, list) and not isinstance(batch_data, tuple): + raise ValueError( + "Your batch type is not supported: {}. Please inherit from " + "`TrainDataLoaderIter` or `ValDataLoaderIter` and override the " + "`inputs_labels_from_batch` method.".format(type(batch_data)) + ) + + inputs, labels, *_ = batch_data + + return inputs, labels + + def __iter__(self): + return self + + def __next__(self): + batch = next(self._iterator) + return self.inputs_labels_from_batch(batch) + + +class TrainDataLoaderIter(DataLoaderIter): + def __init__(self, data_loader, auto_reset=True): + super().__init__(data_loader) + self.auto_reset = auto_reset + + def __next__(self): + try: + batch = next(self._iterator) + inputs, labels = self.inputs_labels_from_batch(batch) + except StopIteration: + if not self.auto_reset: + raise + self._iterator = iter(self.data_loader) + batch = next(self._iterator) + inputs, labels = self.inputs_labels_from_batch(batch) + + return inputs, labels + + +class ValDataLoaderIter(DataLoaderIter): + """This iterator will reset itself **only** when it is acquired by + the syntax of normal `iterator`. That is, this iterator just works + like a `torch.data.DataLoader`. If you want to restart it, you + should use it like: + + ``` + loader_iter = ValDataLoaderIter(data_loader) + for batch in loader_iter: + ... + + # `loader_iter` should run out of values now, you can restart it by: + # 1. the way we use a `torch.data.DataLoader` + for batch in loader_iter: # __iter__ is called implicitly + ... + + # 2. 
passing it into `iter()` manually + loader_iter = iter(loader_iter) # __iter__ is called by `iter()` + ``` + """ + + def __init__(self, data_loader): + super().__init__(data_loader) + self.run_limit = len(self.data_loader) + self.run_counter = 0 + + def __iter__(self): + if self.run_counter >= self.run_limit: + self._iterator = iter(self.data_loader) + self.run_counter = 0 + return self + + def __next__(self): + self.run_counter += 1 + return super(ValDataLoaderIter, self).__next__() + + +class LRFinder(object): + """Learning rate range test. + + The learning rate range test increases the learning rate in a pre-training run + between two boundaries in a linear or exponential manner. It provides valuable + information on how well the network can be trained over a range of learning rates + and what is the optimal learning rate. + + Arguments: + model (torch.nn.Module): wrapped model. + optimizer (torch.optim.Optimizer): wrapped optimizer where the defined learning + is assumed to be the lower boundary of the range test. + criterion (torch.nn.Module): wrapped loss function. + device (str or torch.device, optional): a string ("cpu" or "cuda") with an + optional ordinal for the device type (e.g. "cuda:X", where is the ordinal). + Alternatively, can be an object representing the device on which the + computation will take place. Default: None, uses the same device as `model`. + memory_cache (boolean, optional): if this flag is set to True, `state_dict` of + model and optimizer will be cached in memory. Otherwise, they will be saved + to files under the `cache_dir`. + cache_dir (string, optional): path for storing temporary files. If no path is + specified, system-wide temporary directory is used. Notice that this + parameter will be ignored if `memory_cache` is True. 
+ + Example: + >>> lr_finder = LRFinder(net, optimizer, criterion, device="cuda") + >>> lr_finder.range_test(dataloader, end_lr=100, num_iter=100) + >>> lr_finder.plot() # to inspect the loss-learning rate graph + >>> lr_finder.reset() # to reset the model and optimizer to their initial state + + Reference: + Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 + fastai/lr_find: https://github.com/fastai/fastai + """ + + def __init__( + self, + model, + optimizer, + criterion, + device=None, + memory_cache=True, + cache_dir=None, + amp: bool = False, + ): + # Check if the optimizer is already attached to a scheduler + self.optimizer = optimizer + self._check_for_scheduler() + + self.model = model + self.criterion = criterion + self.history = {"lr": [], "loss": []} + self.best_loss = None + self.memory_cache = memory_cache + self.cache_dir = cache_dir + self.amp = amp + + # Save the original state of the model and optimizer so they can be restored if + # needed + self.model_device = next(self.model.parameters()).device + self.state_cacher = StateCacher(memory_cache, cache_dir=cache_dir) + self.state_cacher.store("model", self.model.state_dict()) + self.state_cacher.store("optimizer", self.optimizer.state_dict()) + + # If device is None, use the same as the model + if device: + self.device = device + else: + self.device = self.model_device + + def reset(self): + """Restores the model and optimizer to their initial states.""" + + self.model.load_state_dict(self.state_cacher.retrieve("model")) + self.optimizer.load_state_dict(self.state_cacher.retrieve("optimizer")) + self.model.to(self.model_device) + + def range_test( + self, + train_loader, + val_loader=None, + start_lr=None, + end_lr=10, + num_iter=100, + step_mode="exp", + smooth_f=0.05, + diverge_th=5, + accumulation_steps=1, + non_blocking_transfer=True, + ): + """Performs the learning rate range test. + + Arguments: + train_loader (`torch.utils.data.DataLoader` + or child of `TrainDataLoaderIter`, optional): + the training set data loader. + If your dataset (data loader) returns a tuple (inputs, labels,*) then + Pytorch data loader object can be provided. However, if a dataset + returns different outputs e.g. dicts, then you should inherit + from `TrainDataLoaderIter` class and redefine `inputs_labels_from_batch` + method so that it outputs (inputs, labels). + val_loader (`torch.utils.data.DataLoader` + or child of `ValDataLoaderIter`, optional): if `None` the range test + will only use the training loss. When given a data loader, the model is + evaluated after each iteration on that dataset and the evaluation loss + is used. Note that in this mode the test takes significantly longer but + generally produces more precise results. + Similarly to `train_loader`, if your dataset outputs are not standard + you should inherit from `ValDataLoaderIter` class and + redefine method `inputs_labels_from_batch` so that + it outputs (inputs, labels). Default: None. + start_lr (float, optional): the starting learning rate for the range test. + Default: None (uses the learning rate from the optimizer). + end_lr (float, optional): the maximum learning rate to test. Default: 10. + num_iter (int, optional): the number of iterations over which the test + occurs. Default: 100. + step_mode (str, optional): one of the available learning rate policies, + linear or exponential ("linear", "exp"). Default: "exp". + smooth_f (float, optional): the loss smoothing factor within the [0, 1[ + interval. 
Disabled if set to 0, otherwise the loss is smoothed using + exponential smoothing. Default: 0.05. + diverge_th (int, optional): the test is stopped when the loss surpasses the + threshold: diverge_th * best_loss. Default: 5. + accumulation_steps (int, optional): steps for gradient accumulation. If it + is 1, gradients are not accumulated. Default: 1. + non_blocking_transfer (bool, optional): when non_blocking_transfer is set, + tries to convert/move data to the device asynchronously if possible, + e.g., moving CPU Tensors with pinned memory to CUDA devices. Default: True. + + Example (fastai approach): + >>> lr_finder = LRFinder(net, optimizer, criterion, device="cuda") + >>> lr_finder.range_test(dataloader, end_lr=100, num_iter=100) + + Example (Leslie Smith's approach): + >>> lr_finder = LRFinder(net, optimizer, criterion, device="cuda") + >>> lr_finder.range_test(trainloader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode="linear") + + Gradient accumulation is supported; example: + >>> train_data = ... # prepared dataset + >>> desired_bs, real_bs = 32, 4 # batch size + >>> accumulation_steps = desired_bs // real_bs # required steps for accumulation + >>> dataloader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True) + >>> acc_lr_finder = LRFinder(net, optimizer, criterion, device="cuda") + >>> acc_lr_finder.range_test(dataloader, end_lr=10, num_iter=100, accumulation_steps=accumulation_steps) + + If your DataLoader returns e.g. dict, or other non standard output, intehit from TrainDataLoaderIter, + redefine method `inputs_labels_from_batch` so that it outputs (inputs, lables) data: + >>> import torch_lr_finder + >>> class TrainIter(torch_lr_finder.TrainDataLoaderIter): + >>> def inputs_labels_from_batch(self, batch_data): + >>> return (batch_data['user_features'], batch_data['user_history']), batch_data['y_labels'] + >>> train_data_iter = TrainIter(train_dl) + >>> finder = torch_lr_finder.LRFinder(model, optimizer, partial(model._train_loss, need_one_hot=False)) + >>> finder.range_test(train_data_iter, end_lr=10, num_iter=300, diverge_th=10) + + Reference: + [Training Neural Nets on Larger Batches: Practical Tips for 1-GPU, Multi-GPU & Distributed setups]( + https://medium.com/huggingface/ec88c3e51255) + [thomwolf/gradient_accumulation](https://gist.github.com/thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3) + """ + + # Reset test results + self.history = {"lr": [], "loss": []} + self.best_loss = None + + # Move the model to the proper device + self.model.to(self.device) + + # Check if the optimizer is already attached to a scheduler + self._check_for_scheduler() + + # Set the starting learning rate + if start_lr: + self._set_learning_rate(start_lr) + + # Initialize the proper learning rate policy + if step_mode.lower() == "exp": + lr_schedule = ExponentialLR(self.optimizer, end_lr, num_iter) + elif step_mode.lower() == "linear": + lr_schedule = LinearLR(self.optimizer, end_lr, num_iter) + else: + raise ValueError("expected one of (exp, linear), got {}".format(step_mode)) + + if smooth_f < 0 or smooth_f >= 1: + raise ValueError("smooth_f is outside the range [0, 1[") + + # Create an iterator to get data batch by batch + if isinstance(train_loader, DataLoader): + train_iter = TrainDataLoaderIter(train_loader) + elif isinstance(train_loader, TrainDataLoaderIter): + train_iter = train_loader + else: + raise ValueError( + "`train_loader` has unsupported type: {}." 
+ "Expected types are `torch.utils.data.DataLoader`" + "or child of `TrainDataLoaderIter`.".format(type(train_loader)) + ) + + if val_loader: + if isinstance(val_loader, DataLoader): + val_iter = ValDataLoaderIter(val_loader) + elif isinstance(val_loader, ValDataLoaderIter): + val_iter = val_loader + else: + raise ValueError( + "`val_loader` has unsupported type: {}." + "Expected types are `torch.utils.data.DataLoader`" + "or child of `ValDataLoaderIter`.".format(type(val_loader)) + ) + + for iteration in trange(num_iter): + # Train on batch and retrieve loss + loss = self._train_batch( + train_iter, + accumulation_steps, + non_blocking_transfer=non_blocking_transfer, + ) + if val_loader: + loss = self._validate( + val_iter, non_blocking_transfer=non_blocking_transfer + ) + + # Update the learning rate + self.history["lr"].append(lr_schedule.get_lr()[0]) + lr_schedule.step() + + # Track the best loss and smooth it if smooth_f is specified + if iteration == 0: + self.best_loss = loss + else: + if smooth_f > 0: + loss = smooth_f * loss + (1 - smooth_f) * self.history["loss"][-1] + if loss < self.best_loss: + self.best_loss = loss + + # Check if the loss has diverged; if it has, stop the test + self.history["loss"].append(loss) + if loss > diverge_th * self.best_loss: + print("Stopping early, the loss has diverged") + break + + print("Learning rate search finished. See the graph with {finder_name}.plot()") + + def _set_learning_rate(self, new_lrs): + if not isinstance(new_lrs, list): + new_lrs = [new_lrs] * len(self.optimizer.param_groups) + if len(new_lrs) != len(self.optimizer.param_groups): + raise ValueError( + "Length of `new_lrs` is not equal to the number of parameter groups " + + "in the given optimizer" + ) + + for param_group, new_lr in zip(self.optimizer.param_groups, new_lrs): + param_group["lr"] = new_lr + + def _check_for_scheduler(self): + for param_group in self.optimizer.param_groups: + if "initial_lr" in param_group: + raise RuntimeError("Optimizer already has a scheduler attached to it") + + def _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer=True): + self.model.train() + total_loss = None # for late initialization + + self.optimizer.zero_grad() + for i in range(accumulation_steps): + inputs, labels = next(train_iter) + inputs, labels = self._move_to_device( + inputs, labels, non_blocking=non_blocking_transfer + ) + + # Forward pass + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + + # Loss should be averaged in each step + loss /= accumulation_steps + + # Backward pass + if self.amp and hasattr(self.optimizer, "_amp_stash"): + # For minor performance optimization, see also: + # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations + delay_unscale = ((i + 1) % accumulation_steps) != 0 + + with torch.cuda.amp.scale_loss( + loss, self.optimizer, delay_unscale=delay_unscale + ) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + if total_loss is None: + total_loss = loss + else: + total_loss += loss + + self.optimizer.step() + + return total_loss.item() + + def _move_to_device(self, inputs, labels, non_blocking=True): + def move(obj, device, non_blocking=True): + if hasattr(obj, "to"): + return obj.to(device, non_blocking=non_blocking) + elif isinstance(obj, tuple): + return tuple(move(o, device, non_blocking) for o in obj) + elif isinstance(obj, list): + return [move(o, device, non_blocking) for o in obj] + elif isinstance(obj, dict): + return {k: move(o, device, non_blocking) 
for k, o in obj.items()} + else: + return obj + + inputs = move(inputs, self.device, non_blocking=non_blocking) + labels = move(labels, self.device, non_blocking=non_blocking) + return inputs, labels + + def _validate(self, val_iter, non_blocking_transfer=True): + # Set model to evaluation mode and disable gradient computation + running_loss = 0 + self.model.eval() + with torch.no_grad(): + for inputs, labels in val_iter: + # Move data to the correct device + inputs, labels = self._move_to_device( + inputs, labels, non_blocking=non_blocking_transfer + ) + + # Forward pass and loss computation + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + running_loss += loss.item() * len(labels) + + return running_loss / len(val_iter.dataset) + + def plot( + self, + skip_start=10, + skip_end=5, + log_lr=True, + show_lr=None, + ax=None, + suggest_lr=True, + ): + """Plots the learning rate range test. + + Arguments: + skip_start (int, optional): number of batches to trim from the start. + Default: 10. + skip_end (int, optional): number of batches to trim from the start. + Default: 5. + log_lr (bool, optional): True to plot the learning rate in a logarithmic + scale; otherwise, plotted in a linear scale. Default: True. + show_lr (float, optional): if set, adds a vertical line to visualize the + specified learning rate. Default: None. + ax (matplotlib.axes.Axes, optional): the plot is created in the specified + matplotlib axes object and the figure is not be shown. If `None`, then + the figure and axes object are created in this method and the figure is + shown . Default: None. + suggest_lr (bool, optional): suggest a learning rate by + - 'steepest': the point with steepest gradient (minimal gradient) + you can use that point as a first guess for an LR. Default: True. + + Returns: + The matplotlib.axes.Axes object that contains the plot, + and the suggested learning rate (if set suggest_lr=True). + """ + + if skip_start < 0: + raise ValueError("skip_start cannot be negative") + if skip_end < 0: + raise ValueError("skip_end cannot be negative") + if show_lr is not None and not isinstance(show_lr, float): + raise ValueError("show_lr must be float") + + # Get the data to plot from the history dictionary. Also, handle skip_end=0 + # properly so the behaviour is the expected + lrs = self.history["lr"] + losses = self.history["loss"] + if skip_end == 0: + lrs = lrs[skip_start:] + losses = losses[skip_start:] + else: + lrs = lrs[skip_start:-skip_end] + losses = losses[skip_start:-skip_end] + + # Create the figure and axes object if axes was not already given + fig = None + if ax is None: + fig, ax = plt.subplots() + + # Plot loss as a function of the learning rate + ax.plot(lrs, losses) + + # Plot the suggested LR + if suggest_lr: + # 'steepest': the point with steepest gradient (minimal gradient) + print("LR suggestion: steepest gradient") + min_grad_idx = None + try: + min_grad_idx = (np.gradient(np.array(losses))).argmin() + except ValueError: + print( + "Failed to compute the gradients, there might not be enough points." 
+ ) + if min_grad_idx is not None: + print("Suggested LR: {:.2E}".format(lrs[min_grad_idx])) + ax.scatter( + lrs[min_grad_idx], + losses[min_grad_idx], + s=75, + marker="o", + color="red", + zorder=3, + label="steepest gradient", + ) + ax.legend() + + if log_lr: + ax.set_xscale("log") + ax.set_xlabel("Learning rate") + ax.set_ylabel("Loss") + + if show_lr is not None: + ax.axvline(x=show_lr, color="red") + + # Show only if the figure was created internally + if fig is not None: + plt.show() + + if suggest_lr and min_grad_idx is not None: + return ax, lrs[min_grad_idx] + else: + return ax + + +class LinearLR(_LRScheduler): + """Linearly increases the learning rate between two boundaries over a number of + iterations. + + Arguments: + optimizer (torch.optim.Optimizer): wrapped optimizer. + end_lr (float): the final learning rate. + num_iter (int): the number of iterations over which the test occurs. + last_epoch (int, optional): the index of last epoch. Default: -1. + """ + + def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1): + self.end_lr = end_lr + + if num_iter <= 1: + raise ValueError("`num_iter` must be larger than 1") + self.num_iter = num_iter + + super(LinearLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + r = self.last_epoch / (self.num_iter - 1) + return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] + + +class ExponentialLR(_LRScheduler): + """Exponentially increases the learning rate between two boundaries over a number of + iterations. + + Arguments: + optimizer (torch.optim.Optimizer): wrapped optimizer. + end_lr (float): the final learning rate. + num_iter (int): the number of iterations over which the test occurs. + last_epoch (int, optional): the index of last epoch. Default: -1. + """ + + def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1): + self.end_lr = end_lr + + if num_iter <= 1: + raise ValueError("`num_iter` must be larger than 1") + self.num_iter = num_iter + + super(ExponentialLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + r = self.last_epoch / (self.num_iter - 1) + return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] + + +class StateCacher(object): + def __init__(self, in_memory, cache_dir=None): + self.in_memory = in_memory + self.cache_dir = cache_dir + + if self.cache_dir is None: + import tempfile + + self.cache_dir = tempfile.gettempdir() + else: + if not os.path.isdir(self.cache_dir): + raise ValueError("Given `cache_dir` is not a valid directory.") + + self.cached = {} + + def store(self, key, state_dict): + if self.in_memory: + self.cached.update({key: copy.deepcopy(state_dict)}) + else: + fn = os.path.join(self.cache_dir, "state_{}_{}.pt".format(key, id(self))) + self.cached.update({key: fn}) + torch.save(state_dict, fn) + + def retrieve(self, key): + if key not in self.cached: + raise KeyError("Target {} was not cached.".format(key)) + + if self.in_memory: + return self.cached.get(key) + else: + fn = self.cached.get(key) + if not os.path.exists(fn): + raise RuntimeError( + "Failed to load state in {}. 
File doesn't exist anymore.".format(fn) + ) + state_dict = torch.load(fn, map_location=lambda storage, location: storage) + return state_dict + + def __del__(self): + """Check whether there are unused cached files existing in `cache_dir` before + this instance being destroyed.""" + + if self.in_memory: + return + + for k in self.cached: + if os.path.exists(self.cached[k]): + os.remove(self.cached[k]) \ No newline at end of file diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py new file mode 100644 index 0000000000..41dee28338 --- /dev/null +++ b/tests/test_lr_finder.py @@ -0,0 +1,116 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import numpy as np +import torch +import random +from torch.utils.data import DataLoader, Dataset + +from monai.optimizers import LRFinder +from monai.networks.nets import densenet121 +from tests.utils import skip_if_quick +from monai.utils import set_determinism +from monai.apps import download_and_extract +from urllib.error import ContentTooShortError, HTTPError +from monai.transforms import AddChannel, Compose, LoadImage, RandFlip, RandRotate, RandZoom, ScaleIntensity, ToTensor + +TEST_DATA_URL = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1" +MD5_VALUE = "0bc7306e7427e00ad1c5526a6677552d" +TASK = "integration_classification_2d" + +RAND_SEED = 42 +random.seed(RAND_SEED) + +class MedNISTDataset(Dataset): + def __init__(self, image_files, labels, transforms): + self.image_files = image_files + self.labels = labels + self.transforms = transforms + + def __len__(self): + return len(self.image_files) + + def __getitem__(self, index): + return self.transforms(self.image_files[index]), self.labels[index] + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +@skip_if_quick +class TestLRFinder(unittest.TestCase): + + def setUp(self): + set_determinism(seed=0) + + base_data_dir = os.environ.get("MONAI_DATA_DIRECTORY") + if not base_data_dir: + base_data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") + data_dir = os.path.join(base_data_dir, "MedNIST") + dataset_file = os.path.join(base_data_dir, "MedNIST.tar.gz") + + if not os.path.exists(data_dir): + download_and_extract(TEST_DATA_URL, dataset_file, base_data_dir, MD5_VALUE) + self.assertTrue(os.path.exists(data_dir)) + + class_names = sorted((x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x)))) + image_files_list_list = [ + [os.path.join(data_dir, class_name, x) for x in sorted(os.listdir(os.path.join(data_dir, class_name)))] + for class_name in class_names + ] + self.image_files, self.image_classes = [], [] + for i, _ in enumerate(class_names): + self.image_files.extend(image_files_list_list[i]) + self.image_classes.extend([i] * len(image_files_list_list[i])) + + num_to_keep = 100 + c = list(zip(self.image_files, self.image_classes)) + random.shuffle(c) + self.image_files, self.image_classes = zip(*c[:num_to_keep]) + self.num_classes = 
len(np.unique(self.image_classes)) + + self.train_transforms = Compose( + [ + LoadImage(image_only=True), + AddChannel(), + ScaleIntensity(), + RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True), + RandFlip(spatial_axis=0, prob=0.5), + RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5), + ToTensor(), + ] + ) + self.train_transforms.set_random_state(RAND_SEED) + + def test_lr_finder(self): + + model = densenet121(spatial_dims=2, in_channels=1, out_channels=self.num_classes).to(device) + loss_function = torch.nn.CrossEntropyLoss() + learning_rate = 1e-5 + optimizer = torch.optim.Adam(model.parameters(), learning_rate) + + train_ds = MedNISTDataset(self.image_files, self.image_classes, self.train_transforms) + train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10) + + print("start") + lr_finder = LRFinder(model, optimizer, loss_function, device=device) + lr_finder.range_test(train_loader, end_lr=100, num_iter=100) + lr_finder.plot() # to inspect the loss-learning rate graph + lr_finder.reset() # to reset the model and optimizer to their initial state + + +if __name__ == "__main__": + # unittest.main() + a = TestLRFinder() + a.setUp() + a.test_lr_finder() From 5c06a4b9f143f8be2691ec6ff570e13a0b867e3d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 15 Jan 2021 19:09:10 +0000 Subject: [PATCH 02/25] tidying Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 116 ++++++++++++++++++++-------------- tests/test_lr_finder.py | 5 +- 2 files changed, 69 insertions(+), 52 deletions(-) diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 7291023437..226cfd2a25 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -6,14 +6,10 @@ import matplotlib.pyplot as plt from torch.utils.data import DataLoader from functools import partial +from monai.utils import optional_import +from monai.networks.utils import eval_mode - -try: - from tqdm import trange - - trange = partial(trange, desc="Computing optimal learning rate") -except (ImportError, AttributeError): - trange = range +tqdm, has_tqdm = optional_import("tqdm") __all__ = ["LRFinder"] @@ -104,7 +100,7 @@ def __next__(self): class LRFinder(object): - """Learning rate range test. + """Learning rate range test, modified from https://github.com/davidtvs/pytorch-lr-finder. The learning rate range test increases the learning rate in a pre-training run between two boundaries in a linear or exponential manner. 
It provides valuable @@ -147,6 +143,7 @@ def __init__( memory_cache=True, cache_dir=None, amp: bool = False, + verbose: bool = True, ): # Check if the optimizer is already attached to a scheduler self.optimizer = optimizer @@ -159,6 +156,7 @@ def __init__( self.memory_cache = memory_cache self.cache_dir = cache_dir self.amp = amp + self.verbose = verbose # Save the original state of the model and optimizer so they can be restored if # needed @@ -168,10 +166,7 @@ def __init__( self.state_cacher.store("optimizer", self.optimizer.state_dict()) # If device is None, use the same as the model - if device: - self.device = device - else: - self.device = self.model_device + self.device = device if device else self.model_device def reset(self): """Restores the model and optimizer to their initial states.""" @@ -313,7 +308,15 @@ def range_test( "or child of `ValDataLoaderIter`.".format(type(val_loader)) ) + if self.verbose and has_tqdm: + trange = partial(tqdm.trange, desc="Computing optimal learning rate") + else: + trange = range + for iteration in trange(num_iter): + if self.verbose and not has_tqdm: + print(f"Computing optimal learning rate, iteration {iteration + 1}/{num_iter}") + # Train on batch and retrieve loss loss = self._train_batch( train_iter, @@ -394,10 +397,7 @@ def _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer=Tru else: loss.backward() - if total_loss is None: - total_loss = loss - else: - total_loss += loss + total_loss = total_loss + loss if total_loss else loss self.optimizer.step() @@ -423,8 +423,7 @@ def move(obj, device, non_blocking=True): def _validate(self, val_iter, non_blocking_transfer=True): # Set model to evaluation mode and disable gradient computation running_loss = 0 - self.model.eval() - with torch.no_grad(): + with eval_mode(self.model): for inputs, labels in val_iter: # Move data to the correct device inputs, labels = self._move_to_device( @@ -438,14 +437,52 @@ def _validate(self, val_iter, non_blocking_transfer=True): return running_loss / len(val_iter.dataset) + def get_steepest_gradient( + self, + skip_start=10, + skip_end=5, + ): + """Get steepest gradient. + + Arguments: + skip_start (int, optional): number of batches to trim from the start. + Default: 10. + skip_end (int, optional): number of batches to trim from the start. + Default: 5. + + Returns: + Learning rate which has steepest gradient and its corresponding loss + """ + if skip_start < 0: + raise ValueError("skip_start cannot be negative") + if skip_end < 0: + raise ValueError("skip_end cannot be negative") + + # Get the data to plot from the history dictionary. Also, handle skip_end=0 + # properly so the behaviour is the expected + lrs = self.history["lr"] + losses = self.history["loss"] + if skip_end == 0: + lrs = lrs[skip_start:] + losses = losses[skip_start:] + else: + lrs = lrs[skip_start:-skip_end] + losses = losses[skip_start:-skip_end] + + try: + min_grad_idx = np.gradient(np.array(losses)).argmin() + return lrs[min_grad_idx], losses[min_grad_idx] + except ValueError: + print("Failed to compute the gradients, there might not be enough points.") + return None, None + def plot( self, skip_start=10, skip_end=5, log_lr=True, - show_lr=None, ax=None, - suggest_lr=True, + steepest_lr=True, ): """Plots the learning rate range test. @@ -456,15 +493,12 @@ def plot( Default: 5. log_lr (bool, optional): True to plot the learning rate in a logarithmic scale; otherwise, plotted in a linear scale. Default: True. 
- show_lr (float, optional): if set, adds a vertical line to visualize the - specified learning rate. Default: None. ax (matplotlib.axes.Axes, optional): the plot is created in the specified matplotlib axes object and the figure is not be shown. If `None`, then the figure and axes object are created in this method and the figure is shown . Default: None. - suggest_lr (bool, optional): suggest a learning rate by - - 'steepest': the point with steepest gradient (minimal gradient) - you can use that point as a first guess for an LR. Default: True. + steepest_lr (bool, optional): plot the learning rate which had the steepest + gradient. Default: True. Returns: The matplotlib.axes.Axes object that contains the plot, @@ -475,8 +509,6 @@ def plot( raise ValueError("skip_start cannot be negative") if skip_end < 0: raise ValueError("skip_end cannot be negative") - if show_lr is not None and not isinstance(show_lr, float): - raise ValueError("show_lr must be float") # Get the data to plot from the history dictionary. Also, handle skip_end=0 # properly so the behaviour is the expected @@ -497,22 +529,14 @@ def plot( # Plot loss as a function of the learning rate ax.plot(lrs, losses) - # Plot the suggested LR - if suggest_lr: - # 'steepest': the point with steepest gradient (minimal gradient) - print("LR suggestion: steepest gradient") - min_grad_idx = None - try: - min_grad_idx = (np.gradient(np.array(losses))).argmin() - except ValueError: - print( - "Failed to compute the gradients, there might not be enough points." - ) - if min_grad_idx is not None: - print("Suggested LR: {:.2E}".format(lrs[min_grad_idx])) + # Plot the LR with steepest gradient + if steepest_lr: + lr_at_steepest_grad, loss_at_steepest_grad = \ + self.get_steepest_gradient(skip_start, skip_end) + if lr_at_steepest_grad is not None: ax.scatter( - lrs[min_grad_idx], - losses[min_grad_idx], + lr_at_steepest_grad, + loss_at_steepest_grad, s=75, marker="o", color="red", @@ -526,17 +550,11 @@ def plot( ax.set_xlabel("Learning rate") ax.set_ylabel("Loss") - if show_lr is not None: - ax.axvline(x=show_lr, color="red") - # Show only if the figure was created internally if fig is not None: plt.show() - if suggest_lr and min_grad_idx is not None: - return ax, lrs[min_grad_idx] - else: - return ax + return ax class LinearLR(_LRScheduler): diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index 41dee28338..e6c9688672 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -47,7 +47,6 @@ def __getitem__(self, index): device = "cuda" if torch.cuda.is_available() else "cpu" -@skip_if_quick class TestLRFinder(unittest.TestCase): def setUp(self): @@ -73,7 +72,7 @@ def setUp(self): self.image_files.extend(image_files_list_list[i]) self.image_classes.extend([i] * len(image_files_list_list[i])) - num_to_keep = 100 + num_to_keep = 20 c = list(zip(self.image_files, self.image_classes)) random.shuffle(c) self.image_files, self.image_classes = zip(*c[:num_to_keep]) @@ -102,9 +101,9 @@ def test_lr_finder(self): train_ds = MedNISTDataset(self.image_files, self.image_classes, self.train_transforms) train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10) - print("start") lr_finder = LRFinder(model, optimizer, loss_function, device=device) lr_finder.range_test(train_loader, end_lr=100, num_iter=100) + print(lr_finder.get_steepest_gradient()[0]) lr_finder.plot() # to inspect the loss-learning rate graph lr_finder.reset() # to reset the model and optimizer to their initial state From 
b5c0248ed168b8c8bd6d7d101ec0a227494e158a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 18 Jan 2021 16:09:55 +0000 Subject: [PATCH 03/25] working Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/dataset.py | 2 +- monai/optimizers/__init__.py | 2 +- monai/optimizers/lr_finder.py | 434 ++++++++++++++++------------------ tests/test_lr_finder.py | 88 ++----- 4 files changed, 228 insertions(+), 298 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 047587119f..99ad338a3d 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -498,7 +498,7 @@ def _fill_cache(self) -> List: warnings.warn("tqdm is not installed, will not show the caching progress bar.") with ThreadPool(self.num_workers) as p: if has_tqdm: - return list(tqdm(p.imap(self._load_cache_item, range(self.cache_num)), total=self.cache_num)) + return list(tqdm(p.imap(self._load_cache_item, range(self.cache_num)), total=self.cache_num, desc="Loading dataset")) return list(p.imap(self._load_cache_item, range(self.cache_num))) def _load_cache_item(self, idx: int): diff --git a/monai/optimizers/__init__.py b/monai/optimizers/__init__.py index 44883de21b..e53aa8d468 100644 --- a/monai/optimizers/__init__.py +++ b/monai/optimizers/__init__.py @@ -9,6 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .lr_finder import LRFinder +from .lr_finder import LearningRateFinder from .novograd import Novograd from .utils import generate_param_groups diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 226cfd2a25..02b9c12acd 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -1,39 +1,50 @@ import copy import os +from numpy.core.arrayprint import _none_or_positive_arg import torch +import torch.nn as nn import numpy as np +from typing import Any, Tuple, Optional, Union, Callable from torch.optim.lr_scheduler import _LRScheduler -import matplotlib.pyplot as plt from torch.utils.data import DataLoader from functools import partial from monai.utils import optional_import from monai.networks.utils import eval_mode + tqdm, has_tqdm = optional_import("tqdm") +try: + import matplotlib.pyplot as plt + has_matplotlib = True +except ImportError: + has_matplotlib = False -__all__ = ["LRFinder"] +__all__ = ["LearningRateFinder"] class DataLoaderIter(object): - def __init__(self, data_loader): + def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None: + # If already correct type, nothing to do + if isinstance(data_loader, DataLoaderIter): + return self + if not isinstance(data_loader, DataLoader): + raise ValueError( + f"Loader has unsupported type: {type(data_loader)}." + "Expected type was `torch.utils.data.DataLoader`" + ) self.data_loader = data_loader self._iterator = iter(data_loader) + self.image_extractor = image_extractor + self.label_extractor = label_extractor @property def dataset(self): return self.data_loader.dataset def inputs_labels_from_batch(self, batch_data): - if not isinstance(batch_data, list) and not isinstance(batch_data, tuple): - raise ValueError( - "Your batch type is not supported: {}. 
Please inherit from " - "`TrainDataLoaderIter` or `ValDataLoaderIter` and override the " - "`inputs_labels_from_batch` method.".format(type(batch_data)) - ) - - inputs, labels, *_ = batch_data - - return inputs, labels + images = self.image_extractor(batch_data) + labels = self.label_extractor(batch_data) + return images, labels def __iter__(self): return self @@ -44,8 +55,8 @@ def __next__(self): class TrainDataLoaderIter(DataLoaderIter): - def __init__(self, data_loader, auto_reset=True): - super().__init__(data_loader) + def __init__(self, data_loader:DataLoader, image_extractor: Callable, label_extractor: Callable, auto_reset: bool=True) -> None: + super().__init__(data_loader, image_extractor, label_extractor) self.auto_reset = auto_reset def __next__(self): @@ -83,8 +94,8 @@ class ValDataLoaderIter(DataLoaderIter): ``` """ - def __init__(self, data_loader): - super().__init__(data_loader) + def __init__(self, data_loader:DataLoader, image_extractor: Callable, label_extractor: Callable) -> None: + super().__init__(data_loader, image_extractor, label_extractor) self.run_limit = len(self.data_loader) self.run_counter = 0 @@ -98,53 +109,77 @@ def __next__(self): self.run_counter += 1 return super(ValDataLoaderIter, self).__next__() - -class LRFinder(object): - """Learning rate range test, modified from https://github.com/davidtvs/pytorch-lr-finder. +class LearningRateFinder(object): + """Learning rate range test. The learning rate range test increases the learning rate in a pre-training run between two boundaries in a linear or exponential manner. It provides valuable information on how well the network can be trained over a range of learning rates and what is the optimal learning rate. - Arguments: - model (torch.nn.Module): wrapped model. - optimizer (torch.optim.Optimizer): wrapped optimizer where the defined learning - is assumed to be the lower boundary of the range test. - criterion (torch.nn.Module): wrapped loss function. - device (str or torch.device, optional): a string ("cpu" or "cuda") with an - optional ordinal for the device type (e.g. "cuda:X", where is the ordinal). - Alternatively, can be an object representing the device on which the - computation will take place. Default: None, uses the same device as `model`. - memory_cache (boolean, optional): if this flag is set to True, `state_dict` of - model and optimizer will be cached in memory. Otherwise, they will be saved - to files under the `cache_dir`. - cache_dir (string, optional): path for storing temporary files. If no path is - specified, system-wide temporary directory is used. Notice that this - parameter will be ignored if `memory_cache` is True. - - Example: - >>> lr_finder = LRFinder(net, optimizer, criterion, device="cuda") - >>> lr_finder.range_test(dataloader, end_lr=100, num_iter=100) + Example (fastai approach): + >>> lr_finder = LearningRateFinder(net, optimizer, criterion) + >>> lr_finder.range_test(data_loader, end_lr=100, num_iter=100) + >>> lr_finder.get_steepest_gradient() >>> lr_finder.plot() # to inspect the loss-learning rate graph - >>> lr_finder.reset() # to reset the model and optimizer to their initial state - Reference: + Example (Leslie Smith's approach): + >>> lr_finder = LearningRateFinder(net, optimizer, criterion) + >>> lr_finder.range_test(train_loader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode="linear") + + Gradient accumulation is supported; example: + >>> train_data = ... 
# prepared dataset + >>> desired_bs, real_bs = 32, 4 # batch size + >>> accumulation_steps = desired_bs // real_bs # required steps for accumulation + >>> data_loader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True) + >>> acc_lr_finder = LearningRateFinder(net, optimizer, criterion) + >>> acc_lr_finder.range_test(data_loader, end_lr=10, num_iter=100, accumulation_steps=accumulation_steps) + + By default, image will be extracted from data loader with x["image"] and x[0], depending on whether + batch data is a dictionary or not (and similar behaviour for extracting the label). If your data loader + returns something other than this, pass a callable function to extract it, e.g.: + >>> image_extractor = lambda x: x["input"] + >>> label_extractor = lambda x: x[100] + >>> lr_finder = LearningRateFinder(net, optimizer, criterion) + >>> lr_finder.range_test(train_loader, val_loader, image_extractor, label_extractor) + + References: + Modified from: https://github.com/davidtvs/pytorch-lr-finder. Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 - fastai/lr_find: https://github.com/fastai/fastai """ def __init__( self, - model, - optimizer, - criterion, - device=None, - memory_cache=True, - cache_dir=None, + model: nn.Module, + optimizer: torch.optim.Optimizer, + criterion: torch.nn.Module, + device: Optional[Union[str, torch.device]] = None, + memory_cache:bool=True, + cache_dir:Optional[str]=None, amp: bool = False, verbose: bool = True, - ): + ) -> None: + """Constructor. + + Args: + model: wrapped model. + optimizer: wrapped optimizer. + criterion: wrapped loss function. + device: device on which to test. run a string ("cpu" or "cuda") with an + optional ordinal for the device type (e.g. "cuda:X", where is the ordinal). + Alternatively, can be an object representing the device on which the + computation will take place. Default: None, uses the same device as `model`. + memory_cache: if this flag is set to True, `state_dict` of + model and optimizer will be cached in memory. Otherwise, they will be saved + to files under the `cache_dir`. + cache_dir: path for storing temporary files. If no path is + specified, system-wide temporary directory is used. Notice that this + parameter will be ignored if `memory_cache` is True. 
+ amp: use Automatic Mixed Precision + verbose: verbose output + Returns: + None + """ # Check if the optimizer is already attached to a scheduler self.optimizer = optimizer self._check_for_scheduler() @@ -168,95 +203,64 @@ def __init__( # If device is None, use the same as the model self.device = device if device else self.model_device - def reset(self): + def reset(self) -> None: """Restores the model and optimizer to their initial states.""" self.model.load_state_dict(self.state_cacher.retrieve("model")) self.optimizer.load_state_dict(self.state_cacher.retrieve("optimizer")) self.model.to(self.model_device) + def default_image_extractor(x: Any) -> torch.Tensor: + """Default callable for getting image from batch data.""" + return x["image"] if isinstance(x, dict) else x[0] + + def default_label_extractor(x: Any) -> torch.Tensor: + """Default callable for getting label from batch data.""" + return x["label"] if isinstance(x, dict) else x[1] + def range_test( self, - train_loader, - val_loader=None, - start_lr=None, - end_lr=10, - num_iter=100, - step_mode="exp", - smooth_f=0.05, - diverge_th=5, - accumulation_steps=1, - non_blocking_transfer=True, - ): + train_loader: Union[DataLoader,TrainDataLoaderIter], + val_loader: Optional[Union[DataLoader, ValDataLoaderIter]]=None, + image_extractor: Callable = default_image_extractor, + label_extractor: Callable = default_label_extractor, + start_lr:Optional[float]=None, + end_lr:int=10, + num_iter:int=100, + step_mode:str="exp", + smooth_f:float=0.05, + diverge_th:int=5, + accumulation_steps:int=1, + non_blocking_transfer:bool=True, + auto_reset: bool = True, + ) -> None: """Performs the learning rate range test. - Arguments: - train_loader (`torch.utils.data.DataLoader` - or child of `TrainDataLoaderIter`, optional): - the training set data loader. - If your dataset (data loader) returns a tuple (inputs, labels,*) then - Pytorch data loader object can be provided. However, if a dataset - returns different outputs e.g. dicts, then you should inherit - from `TrainDataLoaderIter` class and redefine `inputs_labels_from_batch` - method so that it outputs (inputs, labels). - val_loader (`torch.utils.data.DataLoader` - or child of `ValDataLoaderIter`, optional): if `None` the range test - will only use the training loss. When given a data loader, the model is - evaluated after each iteration on that dataset and the evaluation loss - is used. Note that in this mode the test takes significantly longer but - generally produces more precise results. - Similarly to `train_loader`, if your dataset outputs are not standard - you should inherit from `ValDataLoaderIter` class and - redefine method `inputs_labels_from_batch` so that - it outputs (inputs, labels). Default: None. - start_lr (float, optional): the starting learning rate for the range test. - Default: None (uses the learning rate from the optimizer). - end_lr (float, optional): the maximum learning rate to test. Default: 10. - num_iter (int, optional): the number of iterations over which the test - occurs. Default: 100. - step_mode (str, optional): one of the available learning rate policies, - linear or exponential ("linear", "exp"). Default: "exp". - smooth_f (float, optional): the loss smoothing factor within the [0, 1[ - interval. Disabled if set to 0, otherwise the loss is smoothed using - exponential smoothing. Default: 0.05. - diverge_th (int, optional): the test is stopped when the loss surpasses the - threshold: diverge_th * best_loss. Default: 5. 
- accumulation_steps (int, optional): steps for gradient accumulation. If it - is 1, gradients are not accumulated. Default: 1. - non_blocking_transfer (bool, optional): when non_blocking_transfer is set, - tries to convert/move data to the device asynchronously if possible, - e.g., moving CPU Tensors with pinned memory to CUDA devices. Default: True. - - Example (fastai approach): - >>> lr_finder = LRFinder(net, optimizer, criterion, device="cuda") - >>> lr_finder.range_test(dataloader, end_lr=100, num_iter=100) - - Example (Leslie Smith's approach): - >>> lr_finder = LRFinder(net, optimizer, criterion, device="cuda") - >>> lr_finder.range_test(trainloader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode="linear") - - Gradient accumulation is supported; example: - >>> train_data = ... # prepared dataset - >>> desired_bs, real_bs = 32, 4 # batch size - >>> accumulation_steps = desired_bs // real_bs # required steps for accumulation - >>> dataloader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True) - >>> acc_lr_finder = LRFinder(net, optimizer, criterion, device="cuda") - >>> acc_lr_finder.range_test(dataloader, end_lr=10, num_iter=100, accumulation_steps=accumulation_steps) - - If your DataLoader returns e.g. dict, or other non standard output, intehit from TrainDataLoaderIter, - redefine method `inputs_labels_from_batch` so that it outputs (inputs, lables) data: - >>> import torch_lr_finder - >>> class TrainIter(torch_lr_finder.TrainDataLoaderIter): - >>> def inputs_labels_from_batch(self, batch_data): - >>> return (batch_data['user_features'], batch_data['user_history']), batch_data['y_labels'] - >>> train_data_iter = TrainIter(train_dl) - >>> finder = torch_lr_finder.LRFinder(model, optimizer, partial(model._train_loss, need_one_hot=False)) - >>> finder.range_test(train_data_iter, end_lr=10, num_iter=300, diverge_th=10) - - Reference: - [Training Neural Nets on Larger Batches: Practical Tips for 1-GPU, Multi-GPU & Distributed setups]( - https://medium.com/huggingface/ec88c3e51255) - [thomwolf/gradient_accumulation](https://gist.github.com/thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3) + Args: + train_loader: training set data loader. + val_loader: validation data loader (if desired). + image_extractor: callable function to get the image from a batch of data. + Default: `x["image"] if isinstance(x, dict) else x[0]`. + label_extractor: callable function to get the label from a batch of data. + Default: `x["label"] if isinstance(x, dict) else x[1]`. + start_lr : the starting learning rate for the range test. + The default is the optimizer's learning rate. + end_lr: the maximum learning rate to test. The test may stop earlier than + this if the result starts diverging. + num_iter: the max number of iterations for test. + step_mode: schedule for increasing learning rate: (`linear` or `exp`). + smooth_f: the loss smoothing factor within the `[0, 1[` interval. Disabled + if set to `0`, otherwise loss is smoothed using exponential smoothing. + diverge_th: test is stopped when loss surpasses threshold: + `diverge_th * best_loss`. + accumulation_steps: steps for gradient accumulation. If set to `1`, + gradients are not accumulated. + non_blocking_transfer: when `True`, moves data to device asynchronously if + possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. + auto_reset: if `True`, returns model and optimizer to original states at end + of test. 
+ Returns: + None """ # Reset test results @@ -273,45 +277,32 @@ def range_test( if start_lr: self._set_learning_rate(start_lr) + # Check number of iterations + if num_iter <= 1: + raise ValueError("`num_iter` must be larger than 1") + # Initialize the proper learning rate policy if step_mode.lower() == "exp": lr_schedule = ExponentialLR(self.optimizer, end_lr, num_iter) elif step_mode.lower() == "linear": lr_schedule = LinearLR(self.optimizer, end_lr, num_iter) else: - raise ValueError("expected one of (exp, linear), got {}".format(step_mode)) + raise ValueError(f"expected one of (exp, linear), got {step_mode}") if smooth_f < 0 or smooth_f >= 1: raise ValueError("smooth_f is outside the range [0, 1[") # Create an iterator to get data batch by batch - if isinstance(train_loader, DataLoader): - train_iter = TrainDataLoaderIter(train_loader) - elif isinstance(train_loader, TrainDataLoaderIter): - train_iter = train_loader - else: - raise ValueError( - "`train_loader` has unsupported type: {}." - "Expected types are `torch.utils.data.DataLoader`" - "or child of `TrainDataLoaderIter`.".format(type(train_loader)) - ) - + train_iter = TrainDataLoaderIter(train_loader, image_extractor, label_extractor) if val_loader: - if isinstance(val_loader, DataLoader): - val_iter = ValDataLoaderIter(val_loader) - elif isinstance(val_loader, ValDataLoaderIter): - val_iter = val_loader - else: - raise ValueError( - "`val_loader` has unsupported type: {}." - "Expected types are `torch.utils.data.DataLoader`" - "or child of `ValDataLoaderIter`.".format(type(val_loader)) - ) + val_iter = ValDataLoaderIter(val_loader, image_extractor, label_extractor) if self.verbose and has_tqdm: trange = partial(tqdm.trange, desc="Computing optimal learning rate") + tprint = tqdm.tqdm.write else: trange = range + tprint = print for iteration in trange(num_iter): if self.verbose and not has_tqdm: @@ -344,12 +335,17 @@ def range_test( # Check if the loss has diverged; if it has, stop the test self.history["loss"].append(loss) if loss > diverge_th * self.best_loss: - print("Stopping early, the loss has diverged") + if self.verbose: + tprint("Stopping early, the loss has diverged") break - print("Learning rate search finished. 
See the graph with {finder_name}.plot()") + if auto_reset: + if self.verbose: + print("Resetting model and optimizer") + self.reset() - def _set_learning_rate(self, new_lrs): + def _set_learning_rate(self, new_lrs: Union[float, list]) -> None: + """Set learning rate(s) for optimizer.""" if not isinstance(new_lrs, list): new_lrs = [new_lrs] * len(self.optimizer.param_groups) if len(new_lrs) != len(self.optimizer.param_groups): @@ -361,12 +357,13 @@ def _set_learning_rate(self, new_lrs): for param_group, new_lr in zip(self.optimizer.param_groups, new_lrs): param_group["lr"] = new_lr - def _check_for_scheduler(self): + def _check_for_scheduler(self) -> _none_or_positive_arg: + """Check optimizer doesn't already have scheduler.""" for param_group in self.optimizer.param_groups: if "initial_lr" in param_group: raise RuntimeError("Optimizer already has a scheduler attached to it") - def _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer=True): + def _train_batch(self, train_iter, accumulation_steps:int, non_blocking_transfer:bool=True) -> float: self.model.train() total_loss = None # for late initialization @@ -403,7 +400,7 @@ def _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer=Tru return total_loss.item() - def _move_to_device(self, inputs, labels, non_blocking=True): + def _move_to_device(self, inputs: torch.Tensor, labels: torch.Tensor, non_blocking:bool=True) -> Tuple[torch.Tensor, torch.Tensor]: def move(obj, device, non_blocking=True): if hasattr(obj, "to"): return obj.to(device, non_blocking=non_blocking) @@ -420,7 +417,7 @@ def move(obj, device, non_blocking=True): labels = move(labels, self.device, non_blocking=non_blocking) return inputs, labels - def _validate(self, val_iter, non_blocking_transfer=True): + def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer:bool=True)->float: # Set model to evaluation mode and disable gradient computation running_loss = 0 with eval_mode(self.model): @@ -439,19 +436,17 @@ def _validate(self, val_iter, non_blocking_transfer=True): def get_steepest_gradient( self, - skip_start=10, - skip_end=5, - ): - """Get steepest gradient. + skip_start:int=10, + skip_end:int=5, + ) -> Tuple[float,float]: + """Get learning rate which has steepest gradient and its corresponding loss - Arguments: - skip_start (int, optional): number of batches to trim from the start. - Default: 10. - skip_end (int, optional): number of batches to trim from the start. - Default: 5. + Args: + skip_start: number of batches to trim from the start. + skip_end: number of batches to trim from the end. - Returns: - Learning rate which has steepest gradient and its corresponding loss + Returns: + Learning rate which has steepest gradient and its corresponding loss """ if skip_start < 0: raise ValueError("skip_start cannot be negative") @@ -478,32 +473,29 @@ def get_steepest_gradient( def plot( self, - skip_start=10, - skip_end=5, - log_lr=True, - ax=None, - steepest_lr=True, + skip_start:int=10, + skip_end:int=5, + log_lr:bool=True, + ax:Optional[Any]=None, + steepest_lr:bool=True, ): """Plots the learning rate range test. - Arguments: - skip_start (int, optional): number of batches to trim from the start. - Default: 10. - skip_end (int, optional): number of batches to trim from the start. - Default: 5. - log_lr (bool, optional): True to plot the learning rate in a logarithmic - scale; otherwise, plotted in a linear scale. Default: True. 
- ax (matplotlib.axes.Axes, optional): the plot is created in the specified - matplotlib axes object and the figure is not be shown. If `None`, then - the figure and axes object are created in this method and the figure is - shown . Default: None. - steepest_lr (bool, optional): plot the learning rate which had the steepest - gradient. Default: True. + Args: + skip_start: number of batches to trim from the start. + skip_end: number of batches to trim from the start. + log_lr: True to plot the learning rate in a logarithmic + scale; otherwise, plotted in a linear scale. + ax: the plot is created in the specified matplotlib axes object and the + figure is not be shown. If `None`, then the figure and axes object are + created in this method and the figure is shown. + steepest_lr: plot the learning rate which had the steepest gradient. Returns: - The matplotlib.axes.Axes object that contains the plot, - and the suggested learning rate (if set suggest_lr=True). + The matplotlib.axes.Axes object that contains the plot """ + if not has_matplotlib: + raise RuntimeError("Matplotlib is missing, can't plot result") if skip_start < 0: raise ValueError("skip_start cannot be negative") @@ -556,59 +548,42 @@ def plot( return ax - -class LinearLR(_LRScheduler): - """Linearly increases the learning rate between two boundaries over a number of - iterations. - - Arguments: - optimizer (torch.optim.Optimizer): wrapped optimizer. - end_lr (float): the final learning rate. - num_iter (int): the number of iterations over which the test occurs. - last_epoch (int, optional): the index of last epoch. Default: -1. - """ - - def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1): +class _LRSchedulerMONAI(_LRScheduler): + def __init__(self, optimizer: torch.optim.Optimizer, end_lr:float, num_iter:int, last_epoch:int=-1) -> None: + """ + Args: + optimizer: wrapped optimizer. + end_lr: the final learning rate. + num_iter: the number of iterations over which the test occurs. + last_epoch: the index of last epoch. + Returns: + None + """ self.end_lr = end_lr - - if num_iter <= 1: - raise ValueError("`num_iter` must be larger than 1") self.num_iter = num_iter + super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) - super(LinearLR, self).__init__(optimizer, last_epoch) +class LinearLR(_LRSchedulerMONAI): + """Linearly increases the learning rate between two boundaries over a number of + iterations. + """ def get_lr(self): r = self.last_epoch / (self.num_iter - 1) return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] -class ExponentialLR(_LRScheduler): +class ExponentialLR(_LRSchedulerMONAI): """Exponentially increases the learning rate between two boundaries over a number of iterations. - - Arguments: - optimizer (torch.optim.Optimizer): wrapped optimizer. - end_lr (float): the final learning rate. - num_iter (int): the number of iterations over which the test occurs. - last_epoch (int, optional): the index of last epoch. Default: -1. 
""" - - def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1): - self.end_lr = end_lr - - if num_iter <= 1: - raise ValueError("`num_iter` must be larger than 1") - self.num_iter = num_iter - - super(ExponentialLR, self).__init__(optimizer, last_epoch) - def get_lr(self): r = self.last_epoch / (self.num_iter - 1) return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] class StateCacher(object): - def __init__(self, in_memory, cache_dir=None): + def __init__(self, in_memory:bool, cache_dir:Optional[str]=None) -> None: self.in_memory = in_memory self.cache_dir = cache_dir @@ -626,13 +601,13 @@ def store(self, key, state_dict): if self.in_memory: self.cached.update({key: copy.deepcopy(state_dict)}) else: - fn = os.path.join(self.cache_dir, "state_{}_{}.pt".format(key, id(self))) + fn = os.path.join(self.cache_dir, f"state_{key}_{id(self)}.pt") self.cached.update({key: fn}) torch.save(state_dict, fn) def retrieve(self, key): if key not in self.cached: - raise KeyError("Target {} was not cached.".format(key)) + raise KeyError(f"Target {key} was not cached.") if self.in_memory: return self.cached.get(key) @@ -640,18 +615,15 @@ def retrieve(self, key): fn = self.cached.get(key) if not os.path.exists(fn): raise RuntimeError( - "Failed to load state in {}. File doesn't exist anymore.".format(fn) + f"Failed to load state in {fn}. File doesn't exist anymore." ) state_dict = torch.load(fn, map_location=lambda storage, location: storage) return state_dict def __del__(self): - """Check whether there are unused cached files existing in `cache_dir` before - this instance being destroyed.""" - + """If necessary, delete any cached files existing in `cache_dir`.""" if self.in_memory: return - for k in self.cached: if os.path.exists(self.cached[k]): - os.remove(self.cached[k]) \ No newline at end of file + os.remove(self.cached[k]) diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index e6c9688672..ffc843beb6 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -11,19 +11,15 @@ import os import unittest - -import numpy as np import torch import random -from torch.utils.data import DataLoader, Dataset -from monai.optimizers import LRFinder -from monai.networks.nets import densenet121 -from tests.utils import skip_if_quick +from torch.utils.data import DataLoader +from monai.optimizers import LearningRateFinder +from monai.networks.nets import DenseNet +from monai.apps import MedNISTDataset from monai.utils import set_determinism -from monai.apps import download_and_extract -from urllib.error import ContentTooShortError, HTTPError -from monai.transforms import AddChannel, Compose, LoadImage, RandFlip, RandRotate, RandZoom, ScaleIntensity, ToTensor +from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord TEST_DATA_URL = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1" MD5_VALUE = "0bc7306e7427e00ad1c5526a6677552d" @@ -31,80 +27,42 @@ RAND_SEED = 42 random.seed(RAND_SEED) - -class MedNISTDataset(Dataset): - def __init__(self, image_files, labels, transforms): - self.image_files = image_files - self.labels = labels - self.transforms = transforms - - def __len__(self): - return len(self.image_files) - - def __getitem__(self, index): - return self.transforms(self.image_files[index]), self.labels[index] +set_determinism(seed=RAND_SEED) device = "cuda" if torch.cuda.is_available() else "cpu" - class TestLRFinder(unittest.TestCase): def setUp(self): - set_determinism(seed=0) - - base_data_dir = 
os.environ.get("MONAI_DATA_DIRECTORY") - if not base_data_dir: - base_data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") - data_dir = os.path.join(base_data_dir, "MedNIST") - dataset_file = os.path.join(base_data_dir, "MedNIST.tar.gz") - if not os.path.exists(data_dir): - download_and_extract(TEST_DATA_URL, dataset_file, base_data_dir, MD5_VALUE) - self.assertTrue(os.path.exists(data_dir)) + self.root_dir = os.environ.get("MONAI_DATA_DIRECTORY") + if not self.root_dir: + self.root_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data") - class_names = sorted((x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x)))) - image_files_list_list = [ - [os.path.join(data_dir, class_name, x) for x in sorted(os.listdir(os.path.join(data_dir, class_name)))] - for class_name in class_names - ] - self.image_files, self.image_classes = [], [] - for i, _ in enumerate(class_names): - self.image_files.extend(image_files_list_list[i]) - self.image_classes.extend([i] * len(image_files_list_list[i])) - - num_to_keep = 20 - c = list(zip(self.image_files, self.image_classes)) - random.shuffle(c) - self.image_files, self.image_classes = zip(*c[:num_to_keep]) - self.num_classes = len(np.unique(self.image_classes)) - - self.train_transforms = Compose( + self.transforms = Compose( [ - LoadImage(image_only=True), - AddChannel(), - ScaleIntensity(), - RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True), - RandFlip(spatial_axis=0, prob=0.5), - RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5), - ToTensor(), + LoadImaged(keys="image"), + AddChanneld(keys="image"), + ScaleIntensityd(keys="image"), + ToTensord(keys="image"), ] ) - self.train_transforms.set_random_state(RAND_SEED) def test_lr_finder(self): + # 0.001 gives 54 examples + train_ds = MedNISTDataset(root_dir=self.root_dir, transform=self.transforms, section="validation", val_frac=0.001, download=True, num_workers=10) + train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10) + num_classes = len(set([i['label'] for i in train_ds])) - model = densenet121(spatial_dims=2, in_channels=1, out_channels=self.num_classes).to(device) + model = DenseNet(spatial_dims=2, in_channels=1, out_channels=num_classes, init_features=2, growth_rate=2, block_config=(2,)) loss_function = torch.nn.CrossEntropyLoss() learning_rate = 1e-5 optimizer = torch.optim.Adam(model.parameters(), learning_rate) - train_ds = MedNISTDataset(self.image_files, self.image_classes, self.train_transforms) - train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10) - - lr_finder = LRFinder(model, optimizer, loss_function, device=device) - lr_finder.range_test(train_loader, end_lr=100, num_iter=100) - print(lr_finder.get_steepest_gradient()[0]) - lr_finder.plot() # to inspect the loss-learning rate graph + lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device) + lr_finder.range_test(train_loader, end_lr=10, num_iter=5) + print(lr_finder.get_steepest_gradient(0, 0)[0]) + lr_finder.plot(0, 0) # to inspect the loss-learning rate graph lr_finder.reset() # to reset the model and optimizer to their initial state From 8ba4230b293d7a0e8ae754862c69e6c95c73214e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 18 Jan 2021 17:19:58 +0000 Subject: [PATCH 04/25] autofixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/dataset.py | 8 +- monai/optimizers/lr_finder.py | 165 
+++++++++++++++++----------------- tests/test_lr_finder.py | 37 ++++---- 3 files changed, 114 insertions(+), 96 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 99ad338a3d..e67c7a2954 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -498,7 +498,13 @@ def _fill_cache(self) -> List: warnings.warn("tqdm is not installed, will not show the caching progress bar.") with ThreadPool(self.num_workers) as p: if has_tqdm: - return list(tqdm(p.imap(self._load_cache_item, range(self.cache_num)), total=self.cache_num, desc="Loading dataset")) + return list( + tqdm( + p.imap(self._load_cache_item, range(self.cache_num)), + total=self.cache_num, + desc="Loading dataset", + ) + ) return list(p.imap(self._load_cache_item, range(self.cache_num))) def _load_cache_item(self, idx: int): diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 02b9c12acd..beea4a7d7a 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -1,20 +1,28 @@ import copy import os -from numpy.core.arrayprint import _none_or_positive_arg +from functools import partial +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union + +import numpy as np import torch import torch.nn as nn -import numpy as np -from typing import Any, Tuple, Optional, Union, Callable +from numpy.core.arrayprint import _none_or_positive_arg +from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from functools import partial -from monai.utils import optional_import + from monai.networks.utils import eval_mode +try: + import tqdm -tqdm, has_tqdm = optional_import("tqdm") + has_tqdm = True +except ImportError: + has_tqdm = False try: import matplotlib.pyplot as plt + from matplotlib.axes import Axes + has_matplotlib = True except ImportError: has_matplotlib = False @@ -24,13 +32,9 @@ class DataLoaderIter(object): def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None: - # If already correct type, nothing to do - if isinstance(data_loader, DataLoaderIter): - return self if not isinstance(data_loader, DataLoader): raise ValueError( - f"Loader has unsupported type: {type(data_loader)}." - "Expected type was `torch.utils.data.DataLoader`" + f"Loader has unsupported type: {type(data_loader)}. 
Expected type was `torch.utils.data.DataLoader`" ) self.data_loader = data_loader self._iterator = iter(data_loader) @@ -55,7 +59,9 @@ def __next__(self): class TrainDataLoaderIter(DataLoaderIter): - def __init__(self, data_loader:DataLoader, image_extractor: Callable, label_extractor: Callable, auto_reset: bool=True) -> None: + def __init__( + self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable, auto_reset: bool = True + ) -> None: super().__init__(data_loader, image_extractor, label_extractor) self.auto_reset = auto_reset @@ -94,7 +100,7 @@ class ValDataLoaderIter(DataLoaderIter): ``` """ - def __init__(self, data_loader:DataLoader, image_extractor: Callable, label_extractor: Callable) -> None: + def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None: super().__init__(data_loader, image_extractor, label_extractor) self.run_limit = len(self.data_loader) self.run_counter = 0 @@ -109,6 +115,19 @@ def __next__(self): self.run_counter += 1 return super(ValDataLoaderIter, self).__next__() + +def default_image_extractor(x: Any) -> torch.Tensor: + """Default callable for getting image from batch data.""" + out: torch.Tensor = x["image"] if isinstance(x, dict) else x[0] + return out + + +def default_label_extractor(x: Any) -> torch.Tensor: + """Default callable for getting label from batch data.""" + out: torch.Tensor = x["label"] if isinstance(x, dict) else x[1] + return out + + class LearningRateFinder(object): """Learning rate range test. @@ -151,11 +170,11 @@ class LearningRateFinder(object): def __init__( self, model: nn.Module, - optimizer: torch.optim.Optimizer, + optimizer: Optimizer, criterion: torch.nn.Module, device: Optional[Union[str, torch.device]] = None, - memory_cache:bool=True, - cache_dir:Optional[str]=None, + memory_cache: bool = True, + cache_dir: Optional[str] = None, amp: bool = False, verbose: bool = True, ) -> None: @@ -186,8 +205,7 @@ def __init__( self.model = model self.criterion = criterion - self.history = {"lr": [], "loss": []} - self.best_loss = None + self.history: Dict[str, list] = {"lr": [], "loss": []} self.memory_cache = memory_cache self.cache_dir = cache_dir self.amp = amp @@ -210,28 +228,20 @@ def reset(self) -> None: self.optimizer.load_state_dict(self.state_cacher.retrieve("optimizer")) self.model.to(self.model_device) - def default_image_extractor(x: Any) -> torch.Tensor: - """Default callable for getting image from batch data.""" - return x["image"] if isinstance(x, dict) else x[0] - - def default_label_extractor(x: Any) -> torch.Tensor: - """Default callable for getting label from batch data.""" - return x["label"] if isinstance(x, dict) else x[1] - def range_test( self, - train_loader: Union[DataLoader,TrainDataLoaderIter], - val_loader: Optional[Union[DataLoader, ValDataLoaderIter]]=None, + train_loader: DataLoader, + val_loader: Optional[DataLoader] = None, image_extractor: Callable = default_image_extractor, label_extractor: Callable = default_label_extractor, - start_lr:Optional[float]=None, - end_lr:int=10, - num_iter:int=100, - step_mode:str="exp", - smooth_f:float=0.05, - diverge_th:int=5, - accumulation_steps:int=1, - non_blocking_transfer:bool=True, + start_lr: Optional[float] = None, + end_lr: int = 10, + num_iter: int = 100, + step_mode: str = "exp", + smooth_f: float = 0.05, + diverge_th: int = 5, + accumulation_steps: int = 1, + non_blocking_transfer: bool = True, auto_reset: bool = True, ) -> None: """Performs the learning rate range test. 
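# A minimal sketch of how the `range_test` defined above is meant to be driven,
# assuming a `model`, `optimizer`, `criterion` and `train_loader` have already
# been created (placeholder names, not part of this patch); the keyword values
# simply restate the defaults shown in the signature.
finder = LearningRateFinder(model, optimizer, criterion, device="cuda")
finder.range_test(
    train_loader,
    val_loader=None,   # if given, the recorded loss comes from validation rather than training
    start_lr=1e-7,     # lower bound of the sweep; when None the optimizer's own lr is used
    end_lr=10,         # upper bound of the sweep
    num_iter=100,      # number of learning-rate steps between the two bounds
    step_mode="exp",   # "exp" -> ExponentialLR schedule, "linear" -> LinearLR
    smooth_f=0.05,     # exponential smoothing applied to the recorded loss
    diverge_th=5,      # stop early once loss > diverge_th * best_loss
)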
@@ -265,7 +275,7 @@ def range_test( # Reset test results self.history = {"lr": [], "loss": []} - self.best_loss = None + best_loss = -float("inf") # Move the model to the proper device self.model.to(self.device) @@ -282,6 +292,7 @@ def range_test( raise ValueError("`num_iter` must be larger than 1") # Initialize the proper learning rate policy + lr_schedule: Union[ExponentialLR, LinearLR] if step_mode.lower() == "exp": lr_schedule = ExponentialLR(self.optimizer, end_lr, num_iter) elif step_mode.lower() == "linear": @@ -297,6 +308,7 @@ def range_test( if val_loader: val_iter = ValDataLoaderIter(val_loader, image_extractor, label_extractor) + trange: Union[partial[tqdm.trange], Type[range]] if self.verbose and has_tqdm: trange = partial(tqdm.trange, desc="Computing optimal learning rate") tprint = tqdm.tqdm.write @@ -315,9 +327,7 @@ def range_test( non_blocking_transfer=non_blocking_transfer, ) if val_loader: - loss = self._validate( - val_iter, non_blocking_transfer=non_blocking_transfer - ) + loss = self._validate(val_iter, non_blocking_transfer=non_blocking_transfer) # Update the learning rate self.history["lr"].append(lr_schedule.get_lr()[0]) @@ -325,16 +335,16 @@ def range_test( # Track the best loss and smooth it if smooth_f is specified if iteration == 0: - self.best_loss = loss + best_loss = loss else: if smooth_f > 0: loss = smooth_f * loss + (1 - smooth_f) * self.history["loss"][-1] - if loss < self.best_loss: - self.best_loss = loss + if loss < best_loss: + best_loss = loss # Check if the loss has diverged; if it has, stop the test self.history["loss"].append(loss) - if loss > diverge_th * self.best_loss: + if loss > diverge_th * best_loss: if self.verbose: tprint("Stopping early, the loss has diverged") break @@ -350,8 +360,7 @@ def _set_learning_rate(self, new_lrs: Union[float, list]) -> None: new_lrs = [new_lrs] * len(self.optimizer.param_groups) if len(new_lrs) != len(self.optimizer.param_groups): raise ValueError( - "Length of `new_lrs` is not equal to the number of parameter groups " - + "in the given optimizer" + "Length of `new_lrs` is not equal to the number of parameter groups " + "in the given optimizer" ) for param_group, new_lr in zip(self.optimizer.param_groups, new_lrs): @@ -363,16 +372,14 @@ def _check_for_scheduler(self) -> _none_or_positive_arg: if "initial_lr" in param_group: raise RuntimeError("Optimizer already has a scheduler attached to it") - def _train_batch(self, train_iter, accumulation_steps:int, non_blocking_transfer:bool=True) -> float: + def _train_batch(self, train_iter, accumulation_steps: int, non_blocking_transfer: bool = True) -> float: self.model.train() - total_loss = None # for late initialization + total_loss = 0 self.optimizer.zero_grad() for i in range(accumulation_steps): inputs, labels = next(train_iter) - inputs, labels = self._move_to_device( - inputs, labels, non_blocking=non_blocking_transfer - ) + inputs, labels = self._move_to_device(inputs, labels, non_blocking=non_blocking_transfer) # Forward pass outputs = self.model(inputs) @@ -387,20 +394,20 @@ def _train_batch(self, train_iter, accumulation_steps:int, non_blocking_transfer # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations delay_unscale = ((i + 1) % accumulation_steps) != 0 - with torch.cuda.amp.scale_loss( - loss, self.optimizer, delay_unscale=delay_unscale - ) as scaled_loss: + with torch.cuda.amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss: # type: ignore scaled_loss.backward() else: loss.backward() - 
total_loss = total_loss + loss if total_loss else loss + total_loss += loss.item() self.optimizer.step() - return total_loss.item() + return total_loss - def _move_to_device(self, inputs: torch.Tensor, labels: torch.Tensor, non_blocking:bool=True) -> Tuple[torch.Tensor, torch.Tensor]: + def _move_to_device( + self, inputs: torch.Tensor, labels: torch.Tensor, non_blocking: bool = True + ) -> Tuple[torch.Tensor, torch.Tensor]: def move(obj, device, non_blocking=True): if hasattr(obj, "to"): return obj.to(device, non_blocking=non_blocking) @@ -417,15 +424,13 @@ def move(obj, device, non_blocking=True): labels = move(labels, self.device, non_blocking=non_blocking) return inputs, labels - def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer:bool=True)->float: + def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer: bool = True) -> float: # Set model to evaluation mode and disable gradient computation running_loss = 0 with eval_mode(self.model): for inputs, labels in val_iter: # Move data to the correct device - inputs, labels = self._move_to_device( - inputs, labels, non_blocking=non_blocking_transfer - ) + inputs, labels = self._move_to_device(inputs, labels, non_blocking=non_blocking_transfer) # Forward pass and loss computation outputs = self.model(inputs) @@ -436,9 +441,9 @@ def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer:bool=True def get_steepest_gradient( self, - skip_start:int=10, - skip_end:int=5, - ) -> Tuple[float,float]: + skip_start: int = 10, + skip_end: int = 5, + ) -> Union[Tuple[float, float], Tuple[None, None]]: """Get learning rate which has steepest gradient and its corresponding loss Args: @@ -473,12 +478,12 @@ def get_steepest_gradient( def plot( self, - skip_start:int=10, - skip_end:int=5, - log_lr:bool=True, - ax:Optional[Any]=None, - steepest_lr:bool=True, - ): + skip_start: int = 10, + skip_end: int = 5, + log_lr: bool = True, + ax: Optional[Axes] = None, + steepest_lr: bool = True, + ) -> Axes: """Plots the learning rate range test. Args: @@ -523,8 +528,7 @@ def plot( # Plot the LR with steepest gradient if steepest_lr: - lr_at_steepest_grad, loss_at_steepest_grad = \ - self.get_steepest_gradient(skip_start, skip_end) + lr_at_steepest_grad, loss_at_steepest_grad = self.get_steepest_gradient(skip_start, skip_end) if lr_at_steepest_grad is not None: ax.scatter( lr_at_steepest_grad, @@ -548,8 +552,9 @@ def plot( return ax + class _LRSchedulerMONAI(_LRScheduler): - def __init__(self, optimizer: torch.optim.Optimizer, end_lr:float, num_iter:int, last_epoch:int=-1) -> None: + def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: """ Args: optimizer: wrapped optimizer. @@ -568,6 +573,7 @@ class LinearLR(_LRSchedulerMONAI): """Linearly increases the learning rate between two boundaries over a number of iterations. """ + def get_lr(self): r = self.last_epoch / (self.num_iter - 1) return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] @@ -577,13 +583,14 @@ class ExponentialLR(_LRSchedulerMONAI): """Exponentially increases the learning rate between two boundaries over a number of iterations. 
""" + def get_lr(self): r = self.last_epoch / (self.num_iter - 1) return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] class StateCacher(object): - def __init__(self, in_memory:bool, cache_dir:Optional[str]=None) -> None: + def __init__(self, in_memory: bool, cache_dir: Optional[str] = None) -> None: self.in_memory = in_memory self.cache_dir = cache_dir @@ -595,7 +602,7 @@ def __init__(self, in_memory:bool, cache_dir:Optional[str]=None) -> None: if not os.path.isdir(self.cache_dir): raise ValueError("Given `cache_dir` is not a valid directory.") - self.cached = {} + self.cached: Dict[str, str] = {} def store(self, key, state_dict): if self.in_memory: @@ -612,11 +619,9 @@ def retrieve(self, key): if self.in_memory: return self.cached.get(key) else: - fn = self.cached.get(key) - if not os.path.exists(fn): - raise RuntimeError( - f"Failed to load state in {fn}. File doesn't exist anymore." - ) + fn = self.cached.get(key) # pytype: disable=attribute-error + if not os.path.exists(fn): # pytype: disable=wrong-arg-types + raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.") state_dict = torch.load(fn, map_location=lambda storage, location: storage) return state_dict diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index ffc843beb6..eab07caae9 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -10,16 +10,17 @@ # limitations under the License. import os -import unittest -import torch import random +import unittest +import torch from torch.utils.data import DataLoader -from monai.optimizers import LearningRateFinder -from monai.networks.nets import DenseNet + from monai.apps import MedNISTDataset -from monai.utils import set_determinism +from monai.networks.nets import DenseNet +from monai.optimizers import LearningRateFinder from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord +from monai.utils import set_determinism TEST_DATA_URL = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1" MD5_VALUE = "0bc7306e7427e00ad1c5526a6677552d" @@ -31,8 +32,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu" -class TestLRFinder(unittest.TestCase): +class TestLRFinder(unittest.TestCase): def setUp(self): self.root_dir = os.environ.get("MONAI_DATA_DIRECTORY") @@ -50,11 +51,20 @@ def setUp(self): def test_lr_finder(self): # 0.001 gives 54 examples - train_ds = MedNISTDataset(root_dir=self.root_dir, transform=self.transforms, section="validation", val_frac=0.001, download=True, num_workers=10) + train_ds = MedNISTDataset( + root_dir=self.root_dir, + transform=self.transforms, + section="validation", + val_frac=0.001, + download=True, + num_workers=10, + ) train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10) - num_classes = len(set([i['label'] for i in train_ds])) + num_classes = len({i["label"] for i in train_ds}) - model = DenseNet(spatial_dims=2, in_channels=1, out_channels=num_classes, init_features=2, growth_rate=2, block_config=(2,)) + model = DenseNet( + spatial_dims=2, in_channels=1, out_channels=num_classes, init_features=2, growth_rate=2, block_config=(2,) + ) loss_function = torch.nn.CrossEntropyLoss() learning_rate = 1e-5 optimizer = torch.optim.Adam(model.parameters(), learning_rate) @@ -62,12 +72,9 @@ def test_lr_finder(self): lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device) lr_finder.range_test(train_loader, end_lr=10, num_iter=5) print(lr_finder.get_steepest_gradient(0, 0)[0]) - lr_finder.plot(0, 0) # to 
inspect the loss-learning rate graph - lr_finder.reset() # to reset the model and optimizer to their initial state + lr_finder.plot(0, 0) # to inspect the loss-learning rate graph + lr_finder.reset() # to reset the model and optimizer to their initial state if __name__ == "__main__": - # unittest.main() - a = TestLRFinder() - a.setUp() - a.test_lr_finder() + unittest.main() From cc63970d4d24b05919ee764b5ff4181847551e64 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 19 Jan 2021 09:35:48 +0000 Subject: [PATCH 05/25] fix type checking problem Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index beea4a7d7a..d56a7f6bfc 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -1,7 +1,7 @@ import copy import os from functools import partial -from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type, TYPE_CHECKING, Union import numpy as np import torch @@ -13,20 +13,23 @@ from monai.networks.utils import eval_mode -try: +if TYPE_CHECKING: + try: + import tqdm + has_tqdm = True + except ImportError: + has_tqdm = False + try: + import matplotlib.pyplot as plt + from matplotlib.axes import Axes + has_matplotlib = True + except ImportError: + has_matplotlib = False +else: import tqdm - - has_tqdm = True -except ImportError: - has_tqdm = False -try: import matplotlib.pyplot as plt from matplotlib.axes import Axes - has_matplotlib = True -except ImportError: - has_matplotlib = False - __all__ = ["LearningRateFinder"] From dbe04092a447549b934ae1b58eaa1ff30e3b3ad6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 19 Jan 2021 09:47:06 +0000 Subject: [PATCH 06/25] undo typing changes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index d56a7f6bfc..0fff4ea509 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -1,7 +1,7 @@ import copy import os from functools import partial -from typing import Any, Callable, Dict, Optional, Tuple, Type, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union import numpy as np import torch @@ -13,23 +13,20 @@ from monai.networks.utils import eval_mode -if TYPE_CHECKING: - try: - import tqdm - has_tqdm = True - except ImportError: - has_tqdm = False - try: - import matplotlib.pyplot as plt - from matplotlib.axes import Axes - has_matplotlib = True - except ImportError: - has_matplotlib = False -else: +try: import tqdm + + has_tqdm = True +except ImportError: + has_tqdm = False +try: import matplotlib.pyplot as plt from matplotlib.axes import Axes + has_matplotlib = True +except ImportError: + has_matplotlib = False + __all__ = ["LearningRateFinder"] From fdd0c2dd09d0227576dd7e37f8709e407b01619a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 19 Jan 2021 09:53:08 +0000 Subject: [PATCH 07/25] more undo Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 0fff4ea509..beea4a7d7a 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -1,7 +1,7 @@ import copy import os from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union import numpy as np import torch From 883946c5f3c372ce8402d338faa0751d2bb78cdc Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 19 Jan 2021 10:04:18 +0000 Subject: [PATCH 08/25] try to fix CI/CD Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index beea4a7d7a..32b030e0e5 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -25,6 +25,7 @@ has_matplotlib = True except ImportError: + Axes = None has_matplotlib = False __all__ = ["LearningRateFinder"] From 39250172bfd3a4fabefecb76502f220ee429f4e0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 19 Jan 2021 10:05:40 +0000 Subject: [PATCH 09/25] try to fix CI/CD Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 32b030e0e5..83f5918f48 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -21,11 +21,8 @@ has_tqdm = False try: import matplotlib.pyplot as plt - from matplotlib.axes import Axes - has_matplotlib = True except ImportError: - Axes = None has_matplotlib = False __all__ = ["LearningRateFinder"] @@ -482,9 +479,9 @@ def plot( skip_start: int = 10, skip_end: int = 5, log_lr: bool = True, - ax: Optional[Axes] = None, + ax = None, steepest_lr: bool = True, - ) -> Axes: + ): """Plots the learning rate range test. Args: From 94d0d5e621eefa7f7b6569d3e9d32d59d7a9c687 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 19 Jan 2021 10:10:59 +0000 Subject: [PATCH 10/25] more fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 83f5918f48..50b02f1242 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -21,6 +21,7 @@ has_tqdm = False try: import matplotlib.pyplot as plt + has_matplotlib = True except ImportError: has_matplotlib = False @@ -479,7 +480,7 @@ def plot( skip_start: int = 10, skip_end: int = 5, log_lr: bool = True, - ax = None, + ax=None, steepest_lr: bool = True, ): """Plots the learning rate range test. 
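Taken together, the patches up to this point define the finder's public workflow. A minimal sketch, assuming `model`, `optimizer`, `loss_function`, `train_loader` and `device` have already been constructed; it mirrors the call sequence exercised in tests/test_lr_finder.py and is illustrative only:

    lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device)
    lr_finder.range_test(train_loader, end_lr=10, num_iter=100)
    lr, _ = lr_finder.get_steepest_gradient(skip_start=0, skip_end=0)
    print(f"steepest-gradient learning rate: {lr}")
    lr_finder.plot(skip_start=0, skip_end=0)  # inspect the loss vs. learning-rate curve
    lr_finder.reset()                         # restore model and optimizer to their initial state

Note that get_steepest_gradient returns (None, None) when the gradient cannot be computed, so the suggested rate should be checked before being used to configure training.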
From 53a1b383c34e8201205282f0fae04c1cbde2bf89 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 19 Jan 2021 11:20:54 +0000 Subject: [PATCH 11/25] only do unit test if PIL is installed Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 62 ++++++++++++++++------------------- tests/test_lr_finder.py | 7 ++-- 2 files changed, 32 insertions(+), 37 deletions(-) diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 50b02f1242..6fa81ad10e 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -438,35 +438,45 @@ def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer: bool = T return running_loss / len(val_iter.dataset) - def get_steepest_gradient( + def get_lrs_and_losses( self, - skip_start: int = 10, - skip_end: int = 5, - ) -> Union[Tuple[float, float], Tuple[None, None]]: - """Get learning rate which has steepest gradient and its corresponding loss + skip_start: int = 0, + skip_end: int = 0, + ) -> Tuple[list,list]: + """Get learning rates and their corresponding losses Args: skip_start: number of batches to trim from the start. skip_end: number of batches to trim from the end. - - Returns: - Learning rate which has steepest gradient and its corresponding loss """ if skip_start < 0: raise ValueError("skip_start cannot be negative") if skip_end < 0: raise ValueError("skip_end cannot be negative") - # Get the data to plot from the history dictionary. Also, handle skip_end=0 - # properly so the behaviour is the expected lrs = self.history["lr"] losses = self.history["loss"] - if skip_end == 0: - lrs = lrs[skip_start:] - losses = losses[skip_start:] - else: - lrs = lrs[skip_start:-skip_end] - losses = losses[skip_start:-skip_end] + end_idx = len(lrs) - skip_end - 1 + lrs = lrs[skip_start:end_idx] + losses = losses[skip_start:end_idx] + + return lrs, losses + + def get_steepest_gradient( + self, + skip_start: int = 0, + skip_end: int = 0, + ) -> Union[Tuple[float, float], Tuple[None, None]]: + """Get learning rate which has steepest gradient and its corresponding loss + + Args: + skip_start: number of batches to trim from the start. + skip_end: number of batches to trim from the end. + + Returns: + Learning rate which has steepest gradient and its corresponding loss + """ + lrs, losses = self.get_lrs_and_losses(skip_start, skip_end) try: min_grad_idx = np.gradient(np.array(losses)).argmin() @@ -477,8 +487,8 @@ def get_steepest_gradient( def plot( self, - skip_start: int = 10, - skip_end: int = 5, + skip_start: int = 0, + skip_end: int = 0, log_lr: bool = True, ax=None, steepest_lr: bool = True, @@ -501,21 +511,7 @@ def plot( if not has_matplotlib: raise RuntimeError("Matplotlib is missing, can't plot result") - if skip_start < 0: - raise ValueError("skip_start cannot be negative") - if skip_end < 0: - raise ValueError("skip_end cannot be negative") - - # Get the data to plot from the history dictionary. 
Also, handle skip_end=0 - # properly so the behaviour is the expected - lrs = self.history["lr"] - losses = self.history["loss"] - if skip_end == 0: - lrs = lrs[skip_start:] - losses = losses[skip_start:] - else: - lrs = lrs[skip_start:-skip_end] - losses = losses[skip_start:-skip_end] + lrs, losses = self.get_lrs_and_losses(skip_start, skip_end) # Create the figure and axes object if axes was not already given fig = None diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index eab07caae9..e67290e282 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -20,11 +20,9 @@ from monai.networks.nets import DenseNet from monai.optimizers import LearningRateFinder from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord -from monai.utils import set_determinism +from monai.utils import optional_import, set_determinism -TEST_DATA_URL = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1" -MD5_VALUE = "0bc7306e7427e00ad1c5526a6677552d" -TASK = "integration_classification_2d" +PILImage, has_pil = optional_import("PIL.Image") RAND_SEED = 42 random.seed(RAND_SEED) @@ -33,6 +31,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu" +@unittest.skipUnless(has_pil, "requires PIL") class TestLRFinder(unittest.TestCase): def setUp(self): From 9ec1beb30479063671fc03d4ccb9f6d3afaad687 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 19 Jan 2021 11:24:21 +0000 Subject: [PATCH 12/25] autofix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 6fa81ad10e..2913be3aa2 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -442,7 +442,7 @@ def get_lrs_and_losses( self, skip_start: int = 0, skip_end: int = 0, - ) -> Tuple[list,list]: + ) -> Tuple[list, list]: """Get learning rates and their corresponding losses Args: From 3bcdddd064959bd3446211b7ad3aee5b4e009fcc Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 19 Jan 2021 11:40:18 +0000 Subject: [PATCH 13/25] enhance mednist dataset Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/apps/datasets.py | 17 ++++++++++++----- tests/test_mednistdataset.py | 1 + 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index 1291dac25a..d8fd815ce9 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -83,6 +83,7 @@ def __init__( self.set_random_state(seed=seed) tarfile_name = os.path.join(root_dir, self.compressed_file_name) dataset_dir = os.path.join(root_dir, self.dataset_folder_name) + self.num_class = 0 if download: download_and_extract(self.resource, tarfile_name, root_dir, self.md5) @@ -98,6 +99,10 @@ def __init__( def randomize(self, data: Optional[Any] = None) -> None: self.rann = self.R.random() + def get_num_classes(self) -> int: + """Get number of classes.""" + return self.num_class + def _generate_data_list(self, dataset_dir: str) -> List[Dict]: """ Raises: @@ -105,20 +110,22 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: """ class_names = sorted((x for x in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, x)))) - num_class = len(class_names) + self.num_class = len(class_names) image_files = [ [ os.path.join(dataset_dir, class_names[i], x) 
for x in os.listdir(os.path.join(dataset_dir, class_names[i])) ] - for i in range(num_class) + for i in range(self.num_class) ] - num_each = [len(image_files[i]) for i in range(num_class)] + num_each = [len(image_files[i]) for i in range(self.num_class)] image_files_list = [] image_class = [] - for i in range(num_class): + class_name = [] + for i in range(self.num_class): image_files_list.extend(image_files[i]) image_class.extend([i] * num_each[i]) + class_name.extend([class_names[i]] * num_each[i]) num_total = len(image_class) data = [] @@ -138,7 +145,7 @@ def _generate_data_list(self, dataset_dir: str) -> List[Dict]: raise ValueError( f'Unsupported section: {self.section}, available options are ["training", "validation", "test"].' ) - data.append({"image": image_files_list[i], "label": image_class[i]}) + data.append({"image": image_files_list[i], "label": image_class[i], "class_name": class_name[i]}) return data diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 28263e0722..0887734a7c 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -52,6 +52,7 @@ def _test_dataset(dataset): # testing from data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False) + data.get_num_classes() _test_dataset(data) data = MedNISTDataset(root_dir=testing_dir, section="test", download=False) self.assertTupleEqual(data[0]["image"].shape, (64, 64)) From 7d615cfb696ce89fd2efeb904417dce650849800 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Jan 2021 18:01:11 +0000 Subject: [PATCH 14/25] only plot if on linux Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_lr_finder.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index e67290e282..c45e6b2879 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -11,6 +11,7 @@ import os import random +import sys import unittest import torch @@ -59,7 +60,7 @@ def test_lr_finder(self): num_workers=10, ) train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10) - num_classes = len({i["label"] for i in train_ds}) + num_classes = train_ds.get_num_classes() model = DenseNet( spatial_dims=2, in_channels=1, out_channels=num_classes, init_features=2, growth_rate=2, block_config=(2,) @@ -69,9 +70,10 @@ def test_lr_finder(self): optimizer = torch.optim.Adam(model.parameters(), learning_rate) lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device) - lr_finder.range_test(train_loader, end_lr=10, num_iter=5) + lr_finder.range_test(train_loader, val_loader=train_loader, end_lr=10, num_iter=5) print(lr_finder.get_steepest_gradient(0, 0)[0]) - lr_finder.plot(0, 0) # to inspect the loss-learning rate graph + if sys.platform == "linux": + lr_finder.plot(0, 0) # to inspect the loss-learning rate graph lr_finder.reset() # to reset the model and optimizer to their initial state From 229fdf045ef8cb1e9ce6a03eb1dd9caef7ad4e5b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Jan 2021 21:14:52 +0000 Subject: [PATCH 15/25] only test if linux Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_lr_finder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index c45e6b2879..4d7b44bd12 100644 --- a/tests/test_lr_finder.py +++ 
b/tests/test_lr_finder.py @@ -32,6 +32,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu" +@unittest.skipUnless(sys.platform == "linux") @unittest.skipUnless(has_pil, "requires PIL") class TestLRFinder(unittest.TestCase): def setUp(self): @@ -72,8 +73,7 @@ def test_lr_finder(self): lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device) lr_finder.range_test(train_loader, val_loader=train_loader, end_lr=10, num_iter=5) print(lr_finder.get_steepest_gradient(0, 0)[0]) - if sys.platform == "linux": - lr_finder.plot(0, 0) # to inspect the loss-learning rate graph + lr_finder.plot(0, 0) # to inspect the loss-learning rate graph lr_finder.reset() # to reset the model and optimizer to their initial state From c8d4689a89e5c42bb82019aa4dc56bd0104035af Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Jan 2021 21:20:57 +0000 Subject: [PATCH 16/25] reason Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_lr_finder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index 4d7b44bd12..9ee9c8a4d0 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -32,7 +32,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu" -@unittest.skipUnless(sys.platform == "linux") +@unittest.skipUnless(sys.platform == "linux", "requires linux") @unittest.skipUnless(has_pil, "requires PIL") class TestLRFinder(unittest.TestCase): def setUp(self): From d27172cf0d3ff8ec44c536dbf79c3979a5ccc86f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 22 Jan 2021 13:32:39 +0000 Subject: [PATCH 17/25] refactor move_to_device code Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 28 +++++------------------ monai/utils/__init__.py | 1 + monai/utils/misc.py | 43 ++++++++++++++++++++++++++++++++++- 3 files changed, 49 insertions(+), 23 deletions(-) diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 2913be3aa2..004580c27c 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -12,6 +12,7 @@ from torch.utils.data import DataLoader from monai.networks.utils import eval_mode +from monai.utils import copy_to_device try: import tqdm @@ -378,7 +379,7 @@ def _train_batch(self, train_iter, accumulation_steps: int, non_blocking_transfe self.optimizer.zero_grad() for i in range(accumulation_steps): inputs, labels = next(train_iter) - inputs, labels = self._move_to_device(inputs, labels, non_blocking=non_blocking_transfer) + inputs, labels = copy_to_device([inputs, labels], device=self.device, non_blocking=non_blocking_transfer) # Forward pass outputs = self.model(inputs) @@ -404,32 +405,15 @@ def _train_batch(self, train_iter, accumulation_steps: int, non_blocking_transfe return total_loss - def _move_to_device( - self, inputs: torch.Tensor, labels: torch.Tensor, non_blocking: bool = True - ) -> Tuple[torch.Tensor, torch.Tensor]: - def move(obj, device, non_blocking=True): - if hasattr(obj, "to"): - return obj.to(device, non_blocking=non_blocking) - elif isinstance(obj, tuple): - return tuple(move(o, device, non_blocking) for o in obj) - elif isinstance(obj, list): - return [move(o, device, non_blocking) for o in obj] - elif isinstance(obj, dict): - return {k: move(o, device, non_blocking) for k, o in obj.items()} - else: - return obj - - inputs = move(inputs, self.device, 
non_blocking=non_blocking) - labels = move(labels, self.device, non_blocking=non_blocking) - return inputs, labels - def _validate(self, val_iter: ValDataLoaderIter, non_blocking_transfer: bool = True) -> float: # Set model to evaluation mode and disable gradient computation running_loss = 0 with eval_mode(self.model): for inputs, labels in val_iter: - # Move data to the correct device - inputs, labels = self._move_to_device(inputs, labels, non_blocking=non_blocking_transfer) + # Copy data to the correct device + inputs, labels = copy_to_device( + [inputs, labels], device=self.device, non_blocking=non_blocking_transfer + ) # Forward pass and loss computation outputs = self.model(inputs) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 9bb25d723a..6430fae75a 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -32,6 +32,7 @@ ) from .misc import ( MAX_SEED, + copy_to_device, dtype_numpy_to_torch, dtype_torch_to_numpy, ensure_tuple, diff --git a/monai/utils/misc.py b/monai/utils/misc.py index bf1ff60cbc..2b31392a46 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -10,11 +10,14 @@ # limitations under the License. import collections.abc +import inspect import itertools import random +import types +import warnings from ast import literal_eval from distutils.util import strtobool -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Union, cast import numpy as np import torch @@ -37,6 +40,7 @@ "dtype_torch_to_numpy", "dtype_numpy_to_torch", "MAX_SEED", + "copy_to_device", ] _seed = None @@ -306,3 +310,40 @@ def dtype_torch_to_numpy(dtype): def dtype_numpy_to_torch(dtype): """Convert a numpy dtype to its torch equivalent.""" return _np_to_torch_dtype[dtype] + + +def copy_to_device( + obj: Any, + device: Optional[Union[str, torch.device]], + non_blocking: bool = True, + verbose: bool = False, +) -> Any: + """ + Copy object or tuple/list/dictionary of objects to ``device``. + + Args: + obj: object or tuple/list/dictionary of objects to move to ``device``. + device: move ``obj`` to this device. Can be a string (e.g., ``cpu``, ``cuda``, + ``cuda:0``, etc.) or of type ``torch.device``. + non_blocking_transfer: when `True`, moves data to device asynchronously if + possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. + verbose: when `True`, will print a warning for any elements of incompatible type + not copied to ``device``. + Returns: + Same as input, copied to ``device`` where possible. Original input will be + unchanged. + """ + + if hasattr(obj, "to"): + return obj.to(device, non_blocking=non_blocking) + elif isinstance(obj, tuple): + return tuple(copy_to_device(o, device, non_blocking) for o in obj) + elif isinstance(obj, list): + return [copy_to_device(o, device, non_blocking) for o in obj] + elif isinstance(obj, dict): + return {k: copy_to_device(o, device, non_blocking) for k, o in obj.items()} + elif verbose: + fn_name = cast(types.FrameType, inspect.currentframe()).f_code.co_name + warnings.warn(f"{fn_name} called with incompatible type: " + f"{type(obj)}. 
Data will be returned unchanged.") + + return obj From 478081b2a4a7442aff2eff304ef346a50b971143 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 22 Jan 2021 15:47:03 +0000 Subject: [PATCH 18/25] optional_import and no error if no matplotlib Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 35 ++++++++++++++--------------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 004580c27c..f59f9f852e 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -2,6 +2,7 @@ import os from functools import partial from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +import warnings import numpy as np import torch @@ -12,20 +13,10 @@ from torch.utils.data import DataLoader from monai.networks.utils import eval_mode -from monai.utils import copy_to_device +from monai.utils import copy_to_device, optional_import -try: - import tqdm - - has_tqdm = True -except ImportError: - has_tqdm = False -try: - import matplotlib.pyplot as plt - - has_matplotlib = True -except ImportError: - has_matplotlib = False +plt, has_matplotlib = optional_import("matplotlib.pyplot") +tqdm, has_tqdm = optional_import("tqdm") __all__ = ["LearningRateFinder"] @@ -490,10 +481,12 @@ def plot( steepest_lr: plot the learning rate which had the steepest gradient. Returns: - The matplotlib.axes.Axes object that contains the plot + The `matplotlib.axes.Axes` object that contains the plot. Returns `None` if + `matplotlib` is not installed. """ if not has_matplotlib: - raise RuntimeError("Matplotlib is missing, can't plot result") + warnings.warn("Matplotlib is missing, can't plot result") + return None lrs, losses = self.get_lrs_and_losses(skip_start, skip_end) @@ -597,12 +590,12 @@ def retrieve(self, key): if self.in_memory: return self.cached.get(key) - else: - fn = self.cached.get(key) # pytype: disable=attribute-error - if not os.path.exists(fn): # pytype: disable=wrong-arg-types - raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.") - state_dict = torch.load(fn, map_location=lambda storage, location: storage) - return state_dict + + fn = self.cached.get(key) # pytype: disable=attribute-error + if not os.path.exists(fn): # pytype: disable=wrong-arg-types + raise RuntimeError(f"Failed to load state in {fn}. 
File doesn't exist anymore.") + state_dict = torch.load(fn, map_location=lambda storage, location: storage) + return state_dict def __del__(self): """If necessary, delete any cached files existing in `cache_dir`.""" From 5cc0a5b68d0465eeeb6bbd0bc6639b8230e7cc59 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 22 Jan 2021 16:02:54 +0000 Subject: [PATCH 19/25] move schedulers to own file Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 38 +---------------------------- monai/optimizers/lr_scheduler.py | 41 ++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 37 deletions(-) create mode 100644 monai/optimizers/lr_scheduler.py diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index f59f9f852e..abaf0386c6 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -9,11 +9,11 @@ import torch.nn as nn from numpy.core.arrayprint import _none_or_positive_arg from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader from monai.networks.utils import eval_mode from monai.utils import copy_to_device, optional_import +from monai.optimizers.lr_scheduler import ExponentialLR, LinearLR plt, has_matplotlib = optional_import("matplotlib.pyplot") tqdm, has_tqdm = optional_import("tqdm") @@ -525,42 +525,6 @@ def plot( return ax -class _LRSchedulerMONAI(_LRScheduler): - def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: - """ - Args: - optimizer: wrapped optimizer. - end_lr: the final learning rate. - num_iter: the number of iterations over which the test occurs. - last_epoch: the index of last epoch. - Returns: - None - """ - self.end_lr = end_lr - self.num_iter = num_iter - super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) - - -class LinearLR(_LRSchedulerMONAI): - """Linearly increases the learning rate between two boundaries over a number of - iterations. - """ - - def get_lr(self): - r = self.last_epoch / (self.num_iter - 1) - return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] - - -class ExponentialLR(_LRSchedulerMONAI): - """Exponentially increases the learning rate between two boundaries over a number of - iterations. - """ - - def get_lr(self): - r = self.last_epoch / (self.num_iter - 1) - return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] - - class StateCacher(object): def __init__(self, in_memory: bool, cache_dir: Optional[str] = None) -> None: self.in_memory = in_memory diff --git a/monai/optimizers/lr_scheduler.py b/monai/optimizers/lr_scheduler.py new file mode 100644 index 0000000000..2daafa6933 --- /dev/null +++ b/monai/optimizers/lr_scheduler.py @@ -0,0 +1,41 @@ +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + + +__all__ = ["LinearLR", "ExponentialLR"] + + +class _LRSchedulerMONAI(_LRScheduler): + def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: + """ + Args: + optimizer: wrapped optimizer. + end_lr: the final learning rate. + num_iter: the number of iterations over which the test occurs. + last_epoch: the index of last epoch. 
+ Returns: + None + """ + self.end_lr = end_lr + self.num_iter = num_iter + super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) + + +class LinearLR(_LRSchedulerMONAI): + """Linearly increases the learning rate between two boundaries over a number of + iterations. + """ + + def get_lr(self): + r = self.last_epoch / (self.num_iter - 1) + return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] + + +class ExponentialLR(_LRSchedulerMONAI): + """Exponentially increases the learning rate between two boundaries over a number of + iterations. + """ + + def get_lr(self): + r = self.last_epoch / (self.num_iter - 1) + return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] From b287a6eee8d95e33dc65f596944b22b0232d3725 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 22 Jan 2021 16:04:52 +0000 Subject: [PATCH 20/25] add scheduler docstring Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_scheduler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/optimizers/lr_scheduler.py b/monai/optimizers/lr_scheduler.py index 2daafa6933..e90b185698 100644 --- a/monai/optimizers/lr_scheduler.py +++ b/monai/optimizers/lr_scheduler.py @@ -6,6 +6,8 @@ class _LRSchedulerMONAI(_LRScheduler): + """Base class for increasing the learning rate between two boundaries over a number + of iterations""" def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: """ Args: From 89f4d61d448cc385d68830e5707a973e563cd7b7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 22 Jan 2021 16:06:41 +0000 Subject: [PATCH 21/25] move import to top of file Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index abaf0386c6..658b66320c 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -1,6 +1,7 @@ import copy import os from functools import partial +import tempfile from typing import Any, Callable, Dict, Optional, Tuple, Type, Union import warnings @@ -531,8 +532,6 @@ def __init__(self, in_memory: bool, cache_dir: Optional[str] = None) -> None: self.cache_dir = cache_dir if self.cache_dir is None: - import tempfile - self.cache_dir = tempfile.gettempdir() else: if not os.path.isdir(self.cache_dir): From d26311ed012c65a58b908c1b11952b0ae15a609e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 22 Jan 2021 17:57:10 +0000 Subject: [PATCH 22/25] isolate state_cacher Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/optimizers/lr_finder.py | 45 +---------------- monai/utils/__init__.py | 1 + monai/utils/state_cacher.py | 91 +++++++++++++++++++++++++++++++++++ tests/test_state_cacher.py | 68 ++++++++++++++++++++++++++ 4 files changed, 161 insertions(+), 44 deletions(-) create mode 100644 monai/utils/state_cacher.py create mode 100644 tests/test_state_cacher.py diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py index 658b66320c..fa1f1a3991 100644 --- a/monai/optimizers/lr_finder.py +++ b/monai/optimizers/lr_finder.py @@ -13,7 +13,7 @@ from torch.utils.data import DataLoader from monai.networks.utils import eval_mode -from monai.utils import copy_to_device, optional_import +from 
monai.utils import copy_to_device, optional_import, StateCacher from monai.optimizers.lr_scheduler import ExponentialLR, LinearLR plt, has_matplotlib = optional_import("matplotlib.pyplot") @@ -524,46 +524,3 @@ def plot( plt.show() return ax - - -class StateCacher(object): - def __init__(self, in_memory: bool, cache_dir: Optional[str] = None) -> None: - self.in_memory = in_memory - self.cache_dir = cache_dir - - if self.cache_dir is None: - self.cache_dir = tempfile.gettempdir() - else: - if not os.path.isdir(self.cache_dir): - raise ValueError("Given `cache_dir` is not a valid directory.") - - self.cached: Dict[str, str] = {} - - def store(self, key, state_dict): - if self.in_memory: - self.cached.update({key: copy.deepcopy(state_dict)}) - else: - fn = os.path.join(self.cache_dir, f"state_{key}_{id(self)}.pt") - self.cached.update({key: fn}) - torch.save(state_dict, fn) - - def retrieve(self, key): - if key not in self.cached: - raise KeyError(f"Target {key} was not cached.") - - if self.in_memory: - return self.cached.get(key) - - fn = self.cached.get(key) # pytype: disable=attribute-error - if not os.path.exists(fn): # pytype: disable=wrong-arg-types - raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.") - state_dict = torch.load(fn, map_location=lambda storage, location: storage) - return state_dict - - def __del__(self): - """If necessary, delete any cached files existing in `cache_dir`.""" - if self.in_memory: - return - for k in self.cached: - if os.path.exists(self.cached[k]): - os.remove(self.cached[k]) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 6430fae75a..e5567f9f16 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -65,3 +65,4 @@ optional_import, ) from .profiling import PerfContext, torch_profiler_full, torch_profiler_time_cpu_gpu, torch_profiler_time_end_to_end +from .state_cacher import StateCacher diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py new file mode 100644 index 0000000000..62aac940aa --- /dev/null +++ b/monai/utils/state_cacher.py @@ -0,0 +1,91 @@ +import copy +import os + +import tempfile +from typing import Dict, Optional +import torch + +__all__ = ["StateCacher"] + + +class StateCacher(object): + """Class to cache and retrieve the state of an object. + + Objects can either be stored in memory or on disk. If stored on disk, they can be + stored in a given directory, or alternatively a temporary location will be used. + + If necessary, restored objects will be returned to their original device. + + Example: + + >>> state_cacher = StateCacher(memory_cache, cache_dir=cache_dir) + >>> state_cacher.store("model", model.state_dict()) + >>> model.load_state_dict(state_cacher.retrieve("model")) + """ + def __init__( + self, + in_memory: bool, + cache_dir: Optional[str] = None, + allow_overwrite: bool = True, + ) -> None: + """Constructor. + + Args: + in_memory: boolean to determine if the object will be cached in memory or on + disk. + cache_dir: directory for data to be cached if `in_memory==False`. Defaults + to using a temporary directory. Any created files will be deleted during + the `StateCacher`'s destructor. + allow_overwrite: allow the cache to be overwritten. If set to `False`, an + error will be thrown if a matching already exists in the list of cached + objects. 
+        """
+        self.in_memory = in_memory
+        self.cache_dir = cache_dir
+        self.allow_overwrite = allow_overwrite
+
+        if self.cache_dir is None:
+            self.cache_dir = tempfile.gettempdir()
+        else:
+            if not os.path.isdir(self.cache_dir):
+                raise ValueError("Given `cache_dir` is not a valid directory.")
+
+        self.cached: Dict[str, str] = {}
+
+    def store(self, key, data_obj):
+        """Store a given object with the given key name."""
+        if key in self.cached and not self.allow_overwrite:
+            raise RuntimeError("Cached key already exists and overwriting is disabled.")
+        if self.in_memory:
+            self.cached.update({key: {"obj": copy.deepcopy(data_obj)}})
+        else:
+            fn = os.path.join(self.cache_dir, f"state_{key}_{id(self)}.pt")
+            self.cached.update({key: {"obj": fn}})
+            torch.save(data_obj, fn)
+        # store object's device if relevant
+        if hasattr(data_obj, "device"):
+            self.cached[key]["device"] = data_obj.device
+
+    def retrieve(self, key):
+        """Retrieve the object stored under a given key name."""
+        if key not in self.cached:
+            raise KeyError(f"Target {key} was not cached.")
+
+        if self.in_memory:
+            return self.cached[key]["obj"]
+
+        fn = self.cached[key]["obj"]  # pytype: disable=attribute-error
+        if not os.path.exists(fn):  # pytype: disable=wrong-arg-types
+            raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.")
+        data_obj = torch.load(fn, map_location=lambda storage, location: storage)
+        # copy back to device if necessary
+        if "device" in self.cached[key]:
+            data_obj = data_obj.to(self.cached[key]["device"])
+        return data_obj
+
+    def __del__(self):
+        """If necessary, delete any cached files existing in `cache_dir`."""
+        if not self.in_memory:
+            for k in self.cached:
+                if os.path.exists(self.cached[k]["obj"]):
+                    os.remove(self.cached[k]["obj"])
diff --git a/tests/test_state_cacher.py b/tests/test_state_cacher.py
new file mode 100644
index 0000000000..7793aed6ee
--- /dev/null
+++ b/tests/test_state_cacher.py
@@ -0,0 +1,68 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from os.path import join, exists
+from parameterized import parameterized
+from tempfile import gettempdir
+import unittest
+
+import torch
+
+from monai.utils import StateCacher
+
+DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+TEST_CASE_0 = [
+    torch.Tensor([1]).to(DEVICE),
+    {"in_memory": True},
+]
+TEST_CASE_1 = [
+    torch.Tensor([1]).to(DEVICE),
+    {"in_memory": False, "cache_dir": gettempdir()},
+]
+TEST_CASE_2 = [
+    torch.Tensor([1]).to(DEVICE),
+    {"in_memory": False, "allow_overwrite": False},
+]
+
+TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]
+
+class TestStateCacher(unittest.TestCase):
+    @parameterized.expand(TEST_CASES)
+    def test_state_cacher(self, data_obj, params):
+
+        key = "data_obj"
+
+        state_cacher = StateCacher(**params)
+        # store it
+        state_cacher.store(key, data_obj)
+        # create clone then modify original
+        data_obj_orig = data_obj.clone()
+        data_obj += 1
+        # Restore and check nothing has changed
+        data_obj_restored = state_cacher.retrieve(key)
+        self.assertEqual(data_obj_orig, data_obj_restored)
+
+        # If overwriting is not allowed, check that a second store attempt raises an exception
+        if "allow_overwrite" in params and not params["allow_overwrite"]:
+            with self.assertRaises(RuntimeError):
+                state_cacher.store(key, data_obj)
+
+        # If using a cache dir, check file has been deleted at the end
+        if "cache_dir" in params:
+            i = id(state_cacher)
+            del(state_cacher)
+            self.assertFalse(exists(join(params["cache_dir"], f"state_{key}_{i}.pt")))
+
+
+
+if __name__ == "__main__":
+    unittest.main()

From c38fe77511955249cbfe19d70270823ad255386d Mon Sep 17 00:00:00 2001
From: Richard Brown <33289025+rijobro@users.noreply.github.com>
Date: Fri, 22 Jan 2021 18:05:48 +0000
Subject: [PATCH 23/25] more info

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
---
 monai/utils/state_cacher.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py
index 62aac940aa..8876518935 100644
--- a/monai/utils/state_cacher.py
+++ b/monai/utils/state_cacher.py
@@ -14,7 +14,7 @@ class StateCacher(object):
     Objects can either be stored in memory or on disk. If stored on disk, they can be
     stored in a given directory, or alternatively a temporary location will be used.
 
-    If necessary, restored objects will be returned to their original device.
+    If necessary/possible, restored objects will be returned to their original device.
 
     Example:
 

From df0c0c2ab715e0be7100c16c17ca1002293dc752 Mon Sep 17 00:00:00 2001
From: Richard Brown <33289025+rijobro@users.noreply.github.com>
Date: Mon, 25 Jan 2021 13:48:28 +0000
Subject: [PATCH 24/25] autofix changes

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
---
 monai/optimizers/lr_finder.py | 21 +++++++++++++--------
 monai/optimizers/lr_scheduler.py | 2 +-
 monai/utils/state_cacher.py | 3 ++-
 tests/test_state_cacher.py | 10 +++++-----
 4 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py
index fa1f1a3991..6d200f3d8a 100644
--- a/monai/optimizers/lr_finder.py
+++ b/monai/optimizers/lr_finder.py
@@ -1,9 +1,6 @@
-import copy
-import os
-from functools import partial
-import tempfile
-from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
-import warnings
+import warnings
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -13,11 +10,19 @@
 from torch.utils.data import DataLoader
 
 from monai.networks.utils import eval_mode
-from monai.utils import copy_to_device, optional_import, StateCacher
 from monai.optimizers.lr_scheduler import ExponentialLR, LinearLR
+from monai.utils import StateCacher, copy_to_device, optional_import
+
+if TYPE_CHECKING:
+    import matplotlib.pyplot as plt
+
+    has_matplotlib = True
+    import tqdm
 
-plt, has_matplotlib = optional_import("matplotlib.pyplot")
-tqdm, has_tqdm = optional_import("tqdm")
+    has_tqdm = True
+else:
+    plt, has_matplotlib = optional_import("matplotlib.pyplot")
+    tqdm, has_tqdm = optional_import("tqdm")
 
 __all__ = ["LearningRateFinder"]
 
diff --git a/monai/optimizers/lr_scheduler.py b/monai/optimizers/lr_scheduler.py
index e90b185698..aa9bf2a89b 100644
--- a/monai/optimizers/lr_scheduler.py
+++ b/monai/optimizers/lr_scheduler.py
@@ -1,13 +1,13 @@
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 
-
 __all__ = ["LinearLR", "ExponentialLR"]
 
 
 class _LRSchedulerMONAI(_LRScheduler):
     """Base class for increasing the learning rate between two boundaries over a number
     of iterations"""
+
     def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None:
         """
         Args:
diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py
index 8876518935..16985b15ff 100644
--- a/monai/utils/state_cacher.py
+++ b/monai/utils/state_cacher.py
@@ -1,8 +1,8 @@
 import copy
 import os
-
 import tempfile
 from typing import Dict, Optional
+
 import torch
 
 __all__ = ["StateCacher"]
@@ -22,6 +22,7 @@ class StateCacher(object):
     >>> state_cacher.store("model", model.state_dict())
     >>> model.load_state_dict(state_cacher.retrieve("model"))
     """
+
     def __init__(
         self,
         in_memory: bool,
diff --git a/tests/test_state_cacher.py b/tests/test_state_cacher.py
index 7793aed6ee..139e7b8374 100644
--- a/tests/test_state_cacher.py
+++ b/tests/test_state_cacher.py
@@ -9,12 +9,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from os.path import join, exists
-from parameterized import parameterized
-from tempfile import gettempdir
 import unittest
+from os.path import exists, join
+from tempfile import gettempdir
 
 import torch
+from parameterized import parameterized
 
 from monai.utils import StateCacher
 
@@ -35,6 +35,7 @@
 
 TEST_CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2]
 
+
 class TestStateCacher(unittest.TestCase):
     @parameterized.expand(TEST_CASES)
     def test_state_cacher(self, data_obj, params):
@@ -59,10 +60,9 @@ def test_state_cacher(self, data_obj, params):
         # If using a cache dir, check file has been deleted at the end
         if "cache_dir" in params:
             i = id(state_cacher)
-            del(state_cacher)
+            del state_cacher
             self.assertFalse(exists(join(params["cache_dir"], f"state_{key}_{i}.pt")))
 
 
-
 if __name__ == "__main__":
     unittest.main()

From 4d2b97a166c64bc616b938644984f90435f8cadd Mon Sep 17 00:00:00 2001
From: Richard Brown <33289025+rijobro@users.noreply.github.com>
Date: Tue, 26 Jan 2021 09:25:53 +0000
Subject: [PATCH 25/25] remove object base class

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
---
 monai/handlers/stats_handler.py | 2 +-
 monai/optimizers/lr_finder.py | 4 ++--
 monai/utils/state_cacher.py | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py
index c1aef87df0..007fbed413 100644
--- a/monai/handlers/stats_handler.py
+++ b/monai/handlers/stats_handler.py
@@ -27,7 +27,7 @@
 DEFAULT_TAG = "Loss"
 
 
-class StatsHandler(object):
+class StatsHandler:
     """
     StatsHandler defines a set of Ignite Event-handlers for all the log printing logics.
     It's can be used for any Ignite Engine(trainer, validator and evaluator).
diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py
index 6d200f3d8a..6ad4132dd0 100644
--- a/monai/optimizers/lr_finder.py
+++ b/monai/optimizers/lr_finder.py
@@ -27,7 +27,7 @@
 __all__ = ["LearningRateFinder"]
 
 
-class DataLoaderIter(object):
+class DataLoaderIter:
     def __init__(self, data_loader: DataLoader, image_extractor: Callable, label_extractor: Callable) -> None:
         if not isinstance(data_loader, DataLoader):
             raise ValueError(
@@ -125,7 +125,7 @@ def default_label_extractor(x: Any) -> torch.Tensor:
     return out
 
 
-class LearningRateFinder(object):
+class LearningRateFinder:
     """Learning rate range test.
 
     The learning rate range test increases the learning rate in a pre-training run
diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py
index 16985b15ff..66e9080724 100644
--- a/monai/utils/state_cacher.py
+++ b/monai/utils/state_cacher.py
@@ -8,7 +8,7 @@
 __all__ = ["StateCacher"]
 
 
-class StateCacher(object):
+class StateCacher:
     """Class to cache and retrieve the state of an object.
 
     Objects can either be stored in memory or on disk. If stored on disk, they can be
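
For reference, a minimal sketch of how the pieces introduced in this series fit together: StateCacher snapshots the model weights, ExponentialLR ramps the learning rate between two boundaries, and the weights are restored once the sweep is done. The toy model, loop length and learning-rate bounds below are illustrative assumptions, not taken from the patches; only the StateCacher and ExponentialLR APIs shown above are relied upon.

    import torch

    from monai.optimizers.lr_scheduler import ExponentialLR
    from monai.utils import StateCacher

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)  # lower bound of the range test

    # Snapshot the starting weights so the sweep can be undone afterwards.
    state_cacher = StateCacher(in_memory=False)  # cache_dir=None -> system temp directory
    state_cacher.store("model", model.state_dict())

    # Ramp the learning rate geometrically from 1e-5 towards 1.0 over 100 iterations.
    scheduler = ExponentialLR(optimizer, end_lr=1.0, num_iter=100)
    for _ in range(100):
        loss = model(torch.randn(4, 10)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    # Restore the original weights once the sweep is finished.
    model.load_state_dict(state_cacher.retrieve("model"))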