From 88abbe768abbadeacf63aa321cc809c0e46106b9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 17 Dec 2020 17:05:52 +0000 Subject: [PATCH 01/14] unify visualisation Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- docs/source/conf.py | 2 +- docs/source/highlights.md | 10 +- docs/source/index.rst | 2 +- docs/source/metrics.rst | 4 - docs/source/{visualize.rst => visualise.rst} | 10 +- monai/README.md | 2 +- monai/handlers/tensorboard_handlers.py | 2 +- monai/metrics/__init__.py | 1 - monai/metrics/occlusion_sensitivity.py | 225 -------------- monai/{visualize => visualise}/__init__.py | 2 + .../class_activation_maps.py | 166 ++-------- .../img2tensorboard.py | 0 monai/visualise/occlusion_sensitivity.py | 294 ++++++++++++++++++ monai/visualise/visualiser.py | 156 ++++++++++ tests/min_tests.py | 2 +- tests/test_img2tensorboard.py | 2 +- tests/test_integration_segmentation_3d.py | 2 +- ...ivity.py => test_occlusion_sensitivity.py} | 49 ++- tests/test_plot_2d_or_3d_image.py | 2 +- tests/test_vis_cam.py | 2 +- tests/test_vis_gradcam.py | 2 +- tests/test_vis_gradcampp.py | 2 +- 22 files changed, 535 insertions(+), 404 deletions(-) rename docs/source/{visualize.rst => visualise.rst} (51%) delete mode 100644 monai/metrics/occlusion_sensitivity.py rename monai/{visualize => visualise}/__init__.py (79%) rename monai/{visualize => visualise}/class_activation_maps.py (63%) rename monai/{visualize => visualise}/img2tensorboard.py (100%) create mode 100644 monai/visualise/occlusion_sensitivity.py create mode 100644 monai/visualise/visualiser.py rename tests/{test_compute_occlusion_sensitivity.py => test_occlusion_sensitivity.py} (58%) diff --git a/docs/source/conf.py b/docs/source/conf.py index 534193c936..7c2527a569 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -43,7 +43,7 @@ "config", "handlers", "losses", - "visualize", + "visualise", "utils", "inferers", "optimizers", diff --git a/docs/source/highlights.md b/docs/source/highlights.md index d8fe5c2ff9..0ea43ab639 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -21,7 +21,7 @@ The rest of this page provides more details for each module. * [Optimizers](#optimizers) * [Network architectures](#network-architectures) * [Evaluation](#evaluation) -* [Visualization](#visualization) +* [Visualisation](#visualisation) * [Result writing](#result-writing) * [Workflows](#workflows) * [Research](#research) @@ -111,7 +111,7 @@ MONAI also provides post-processing transforms for handling the model outputs. C - Removing segmentation noise based on Connected Component Analysis, as below figure (c). - Extracting contour of segmentation result, which can be used to map to original image and evaluate the model, as below figure (d) and (e). -After applying the post-processing transforms, it's easier to compute metrics, save model output into files or visualize data in the TensorBoard. [Post transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/post_transforms.ipynb) shows an example with several main post transforms. +After applying the post-processing transforms, it's easier to compute metrics, save model output into files or visualise data in the TensorBoard. [Post transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/post_transforms.ipynb) shows an example with several main post transforms. ![image](../images/post_transforms.png) ### 9. Integrate third-party transforms @@ -237,10 +237,10 @@ Various useful evaluation metrics have been used to measure the quality of medic For example, `Mean Dice` score can be used for segmentation tasks, and the area under the ROC curve(`ROCAUC`) for classification tasks. We continue to integrate more options. -## Visualization -Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). +## Visualisation +Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualise multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualising, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualisation is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). -And to visualize the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models: +And to visualise the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models: ![image](../images/cam.png) diff --git a/docs/source/index.rst b/docs/source/index.rst index ea21428e6e..d2b8c7970c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -61,7 +61,7 @@ Technical documentation is available at `docs.monai.io `_ engines inferers handlers - visualize + visualise utils .. toctree:: diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 0bcfbd4240..32a3faf380 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -37,7 +37,3 @@ Metrics .. autoclass:: SurfaceDistanceMetric :members: - -`Occlusion sensitivity` ------------------------ -.. autofunction:: compute_occlusion_sensitivity \ No newline at end of file diff --git a/docs/source/visualize.rst b/docs/source/visualise.rst similarity index 51% rename from docs/source/visualize.rst rename to docs/source/visualise.rst index 9668d48114..d61c251710 100644 --- a/docs/source/visualize.rst +++ b/docs/source/visualise.rst @@ -1,20 +1,20 @@ :github_url: https://github.com/Project-MONAI/MONAI -.. _visualize: +.. _visualise: -Visualizations +Visualisations ============== -.. currentmodule:: monai.visualize +.. currentmodule:: monai.visualise Tensorboard visuals ------------------- -.. automodule:: monai.visualize.img2tensorboard +.. automodule:: monai.visualise.img2tensorboard :members: Class activation map -------------------- -.. automodule:: monai.visualize.class_activation_maps +.. automodule:: monai.visualise.class_activation_maps :members: \ No newline at end of file diff --git a/monai/README.md b/monai/README.md index 89c1fa3653..18714a8d77 100644 --- a/monai/README.md +++ b/monai/README.md @@ -27,4 +27,4 @@ * **utils**: generic utilities intended to be implemented in pure Python or using Numpy, and not with Pytorch, such as namespace aliasing, auto module loading. -* **visualize**: utilities for data visualization. \ No newline at end of file +* **visualise**: utilities for data visualisation. \ No newline at end of file diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index a9d7d661ec..5d567bd4d1 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -16,7 +16,7 @@ import torch from monai.utils import exact_version, is_scalar, optional_import -from monai.visualize import plot_2d_or_3d_image +from monai.visualise import plot_2d_or_3d_image Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") if TYPE_CHECKING: diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index a0d626f45b..f43fb444db 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -12,7 +12,6 @@ from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix from .hausdorff_distance import * from .meandice import DiceMetric, compute_meandice -from .occlusion_sensitivity import compute_occlusion_sensitivity from .rocauc import compute_roc_auc from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance from .utils import * diff --git a/monai/metrics/occlusion_sensitivity.py b/monai/metrics/occlusion_sensitivity.py deleted file mode 100644 index 9879f472a9..0000000000 --- a/monai/metrics/occlusion_sensitivity.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections.abc import Sequence -from functools import partial -from typing import Optional, Union - -import numpy as np -import torch -import torch.nn as nn - -try: - from tqdm import trange - - trange = partial(trange, desc="Computing occlusion sensitivity") -except (ImportError, AttributeError): - trange = range - - -def _check_input_image(image): - """Check that the input image is as expected.""" - # Only accept batch size of 1 - if image.shape[0] > 1: - raise RuntimeError("Expected batch size of 1.") - return image - - -def _check_input_label(label, image): - """Check that the input label is as expected.""" - # If necessary turn the label into a 1-element tensor - if isinstance(label, int): - label = torch.tensor([[label]], dtype=torch.int64).to(image.device) - # If the label is a tensor, make sure there's only 1 element - elif label.numel() != image.shape[0]: - raise RuntimeError("Expected as many labels as batches.") - return label - - -def _check_input_bounding_box(b_box, im_shape): - """Check that the bounding box (if supplied) is as expected.""" - # If no bounding box has been supplied, set min and max to None - if b_box is None: - b_box_min = b_box_max = None - - # Bounding box has been supplied - else: - # Should be twice as many elements in `b_box` as `im_shape` - if len(b_box) != 2 * len(im_shape): - raise ValueError("Bounding box should contain upper and lower for all dimensions (except batch number)") - - # If any min's or max's are -ve, set them to 0 and im_shape-1, respectively. - b_box_min = np.array(b_box[::2]) - b_box_max = np.array(b_box[1::2]) - b_box_min[b_box_min < 0] = 0 - b_box_max[b_box_max < 0] = im_shape[b_box_max < 0] - 1 - # Check all max's are < im_shape - if np.any(b_box_max >= im_shape): - raise ValueError("Max bounding box should be < image size for all values") - # Check all min's are <= max's - if np.any(b_box_min > b_box_max): - raise ValueError("Min bounding box should be <= max for all values") - - return b_box_min, b_box_max - - -def _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im): - """For given number of images, get probability of predicting - a given label. Append to previous evaluations.""" - batch_images = torch.cat(batch_images, dim=0) - batch_ids = torch.LongTensor(batch_ids).unsqueeze(1).to(sensitivity_im.device) - scores = model(batch_images).detach().gather(1, batch_ids) - return torch.cat((sensitivity_im, scores)) - - -def compute_occlusion_sensitivity( - model: nn.Module, - image: torch.Tensor, - label: Union[int, torch.Tensor], - pad_val: float = 0.0, - margin: Union[int, Sequence] = 2, - n_batch: int = 128, - b_box: Optional[Sequence] = None, - stride: Union[int, Sequence] = 1, - upsample_mode: str = "nearest", -) -> np.ndarray: - """ - This function computes the occlusion sensitivity for a model's prediction - of a given image. By occlusion sensitivity, we mean how the probability of a given - prediction changes as the occluded section of an image changes. This can - be useful to understand why a network is making certain decisions. - - The result is given as ``baseline`` (the probability of - a certain output) minus the probability of the output with the occluded - area. - - Therefore, higher values in the output image mean there was a - greater the drop in certainty, indicating the occluded region was more - important in the decision process. - - See: R. R. Selvaraju et al. Grad-CAM: Visual Explanations from Deep Networks via - Gradient-based Localization. https://doi.org/10.1109/ICCV.2017.74 - - Args: - model: classification model to use for inference - image: image to test. Should be tensor consisting of 1 batch, can be 2- or 3D. - label: classification label to check for changes (normally the true - label, but doesn't have to be) - pad_val: when occluding part of the image, which values should we put - in the image? - margin: we'll create a cuboid/cube around the voxel to be occluded. if - ``margin==2``, then we'll create a cube that is +/- 2 voxels in - all directions (i.e., a cube of 5 x 5 x 5 voxels). A ``Sequence`` - can be supplied to have a margin of different sizes (i.e., create - a cuboid). - n_batch: number of images in a batch before inference. - b_box: Bounding box on which to perform the analysis. The output image - will also match in size. There should be a minimum and maximum for - all dimensions except batch: ``[min1, max1, min2, max2,...]``. - * By default, the whole image will be used. Decreasing the size will - speed the analysis up, which might be useful for larger images. - * Min and max are inclusive, so [0, 63, ...] will have size (64, ...). - * Use -ve to use 0 for min values and im.shape[x]-1 for xth dimension. - stride: Stride for performing occlusions. Can be single value or sequence - (for varying stride in the different directions). Should be >= 1. - upsample_mode: If stride != 1 is used, we'll upsample such that the size - of the voxels in the output image match the input. Upsampling is done with - ``torch.nn.Upsample``, and mode can be set to: - * ``nearest``, ``linear``, ``bilinear``, ``bicubic`` and ``trilinear`` - * default is ``nearest``. - Returns: - Numpy array. If no bounding box is supplied, this will be the same size - as the input image. If a bounding box is used, the output image will be - cropped to this size. - """ - - # Check input arguments - image = _check_input_image(image) - label = _check_input_label(label, image) - im_shape = np.array(image.shape[1:]) - b_box_min, b_box_max = _check_input_bounding_box(b_box, im_shape) - - # Get baseline probability - baseline = model(image).detach()[0, label].item() - - # Create some lists - batch_images = [] - batch_ids = [] - - sensitivity_im = torch.empty(0, dtype=torch.float32, device=image.device) - - # If no bounding box supplied, output shape is same as input shape. - # If bounding box is present, shape is max - min + 1 - output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1 - - # Calculate the downsampled shape - if not isinstance(stride, Sequence): - stride_np = np.full_like(im_shape, stride, dtype=np.int32) - stride_np[0] = 1 # always do stride 1 in channel dimension - else: - # Convert to numpy array and check dimensions match - stride_np = np.array(stride, dtype=np.int32) - if stride_np.size != im_shape.size: - raise ValueError("Sizes of image shape and stride should match.") - - # Obviously if stride = 1, downsampled_im_shape == output_im_shape - downsampled_im_shape = np.floor(output_im_shape / stride_np).astype(np.int32) - downsampled_im_shape[downsampled_im_shape == 0] = 1 # make sure dimension sizes are >= 1 - num_required_predictions = np.prod(downsampled_im_shape) - - # Loop 1D over image - for i in trange(num_required_predictions): - # Get corresponding ND index - idx = np.unravel_index(i, downsampled_im_shape) - # Multiply by stride - idx *= stride_np - # If a bounding box is being used, we need to add on - # the min to shift to start of region of interest - if b_box_min is not None: - idx += b_box_min - - # Get min and max index of box to occlude - min_idx = [max(0, i - margin) for i in idx] - max_idx = [min(j, i + margin) for i, j in zip(idx, im_shape)] - - # Clone and replace target area with `pad_val` - occlu_im = image.clone() - occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = pad_val - - # Add to list - batch_images.append(occlu_im) - batch_ids.append(label) - - # Once the batch is complete (or on last iteration) - if len(batch_images) == n_batch or i == num_required_predictions - 1: - # Do the predictions and append to sensitivity map - sensitivity_im = _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im) - # Clear lists - batch_images = [] - batch_ids = [] - - # Subtract from baseline - sensitivity_im = baseline - sensitivity_im - - # Reshape to match downsampled image - sensitivity_im = sensitivity_im.reshape(tuple(downsampled_im_shape)) - - # If necessary, upsample - if np.any(stride_np != 1): - output_im_shape = tuple(output_im_shape[1:]) # needs to be given as 3D tuple - upsampler = nn.Upsample(size=output_im_shape, mode=upsample_mode) - sensitivity_im = upsampler(sensitivity_im.unsqueeze(0)) - - # Convert tensor to numpy - sensitivity_im = sensitivity_im.cpu().numpy() - - # Squeeze and return - return np.squeeze(sensitivity_im) diff --git a/monai/visualize/__init__.py b/monai/visualise/__init__.py similarity index 79% rename from monai/visualize/__init__.py rename to monai/visualise/__init__.py index 2fbd1dcf66..2b19a262fc 100644 --- a/monai/visualize/__init__.py +++ b/monai/visualise/__init__.py @@ -9,5 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .visualiser import ModelWithHooks, NetVisualiser, default_normalizer, default_upsampler # isort:skip from .class_activation_maps import * from .img2tensorboard import * +from .occlusion_sensitivity import OcclusionSensitivity diff --git a/monai/visualize/class_activation_maps.py b/monai/visualise/class_activation_maps.py similarity index 63% rename from monai/visualize/class_activation_maps.py rename to monai/visualise/class_activation_maps.py index 6fd29d1c96..78ef15d510 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualise/class_activation_maps.py @@ -9,138 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings -from typing import Callable, Dict, Sequence, Union +from typing import Callable, Union -import numpy as np import torch +import torch.nn as nn import torch.nn.functional as F -from monai.transforms import ScaleIntensity -from monai.utils import InterpolateMode, ensure_tuple +from monai.visualise import ModelWithHooks, NetVisualiser, default_normalizer, default_upsampler -__all__ = ["ModelWithHooks", "default_upsampler", "default_normalizer", "CAM", "GradCAM", "GradCAMpp"] +__all__ = ["CAM", "GradCAM", "GradCAMpp"] -class ModelWithHooks: - """ - A model wrapper to run model forward/backward steps and storing some intermediate feature/gradient information. - """ - - def __init__( - self, - nn_module, - target_layer_names: Union[str, Sequence[str]], - register_forward: bool = False, - register_backward: bool = False, - ): - """ - - Args: - nn_module: the model to be wrapped. - target_layer_names: the names of the layer to cache. - register_forward: whether to cache the forward pass output corresponding to `target_layer_names`. - register_backward: whether to cache the backward pass output corresponding to `target_layer_names`. - """ - self.model = nn_module - self.target_layers = ensure_tuple(target_layer_names) - - self.gradients: Dict[str, torch.Tensor] = {} - self.activations: Dict[str, torch.Tensor] = {} - self.score = None - self.class_idx = None - self.register_backward = register_backward - self.register_forward = register_forward - - _registered = [] - for name, mod in nn_module.named_modules(): - if name not in self.target_layers: - continue - _registered.append(name) - if self.register_backward: - mod.register_backward_hook(self.backward_hook(name)) - if self.register_forward: - mod.register_forward_hook(self.forward_hook(name)) - if len(_registered) != len(self.target_layers): - warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.") - - def backward_hook(self, name): - def _hook(_module, _grad_input, grad_output): - self.gradients[name] = grad_output[0] - - return _hook - - def forward_hook(self, name): - def _hook(_module, _input, output): - self.activations[name] = output - - return _hook - - def get_layer(self, layer_id: Union[str, Callable]): - """ - - Args: - layer_id: a layer name string or a callable. If it is a callable such as `lambda m: m.fc`, - this method will return the module `self.model.fc`. - - Returns: - a submodule from self.model. - """ - if callable(layer_id): - return layer_id(self.model) - if isinstance(layer_id, str): - for name, mod in self.model.named_modules(): - if name == layer_id: - return mod - raise NotImplementedError(f"Could not find {layer_id}.") - - def class_score(self, logits, class_idx=None): - if class_idx is not None: - return logits[:, class_idx].squeeze(), class_idx - class_idx = logits.max(1)[-1] - return logits[:, class_idx].squeeze(), class_idx - - def __call__(self, x, class_idx=None, retain_graph=False): - logits = self.model(x) - acti, grad = None, None - if self.register_forward: - acti = tuple(self.activations[layer] for layer in self.target_layers) - if self.register_backward: - score, class_idx = self.class_score(logits, class_idx) - self.model.zero_grad() - self.score, self.class_idx = score, class_idx - score.sum().backward(retain_graph=retain_graph) - grad = tuple(self.gradients[layer] for layer in self.target_layers) - return logits, acti, grad - - -def default_upsampler(spatial_size) -> Callable[[torch.Tensor], torch.Tensor]: - """ - A linear interpolation method for upsampling the feature map. - The output of this function is a callable `func`, - such that `func(activation_map)` returns an upsampled tensor. - """ - - def up(acti_map): - linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] - interp_mode = linear_mode[len(spatial_size) - 1] - return F.interpolate(acti_map, size=spatial_size, mode=str(interp_mode.value), align_corners=False) - - return up - - -def default_normalizer(acti_map) -> np.ndarray: - """ - A linear intensity scaling by mapping the (min, max) to (1, 0). - """ - if isinstance(acti_map, torch.Tensor): - acti_map = acti_map.detach().cpu().numpy() - scaler = ScaleIntensity(minv=1.0, maxv=0.0) - acti_map = [scaler(x) for x in acti_map] - return np.stack(acti_map, axis=0) - - -class CAM: +class CAM(NetVisualiser): """ Compute class activation map from the last fully-connected layers before the spatial pooling. @@ -150,7 +30,7 @@ class CAM: # densenet 2d from monai.networks.nets import densenet121 - from monai.visualize import CAM + from monai.visualise import CAM model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) cam = CAM(nn_module=model_2d, target_layers="class_layers.relu", fc_layers="class_layers.out") @@ -158,7 +38,7 @@ class CAM: # resnet 2d from monai.networks.nets import se_resnet50 - from monai.visualize import CAM + from monai.visualise import CAM model_2d = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4) cam = CAM(nn_module=model_2d, target_layers="layer4", fc_layers="last_linear") @@ -166,34 +46,40 @@ class CAM: See Also: - - :py:class:`monai.visualize.class_activation_maps.GradCAM` + - :py:class:`monai.visualise.class_activation_maps.GradCAM` """ def __init__( self, - nn_module, + nn_module: nn.Module, target_layers: str, fc_layers: Union[str, Callable] = "fc", - upsampler=default_upsampler, + upsampler: Callable = default_upsampler, postprocessing: Callable = default_normalizer, - ): + ) -> None: """ - Args: nn_module: the model to be visualised target_layers: name of the model layer to generate the feature map. fc_layers: a string or a callable used to get fully-connected weights to compute activation map from the target_layers (without pooling). and evaluate it at every spatial location. - upsampler: an upsampling method to upsample the feature map. - postprocessing: a callable that applies on the upsampled feature map. + upsampler: An upsampling method to upsample the output image. Default is + N dimensional linear (bilinear, trilinear, etc.) depending on num spatial + dimensions of input. + postprocessing: a callable that applies on the upsampled output image. + default is normalising between 0 and 1. """ if not isinstance(nn_module, ModelWithHooks): self.net = ModelWithHooks(nn_module, target_layers, register_forward=True) else: self.net = nn_module - self.upsampler = upsampler - self.postprocessing = postprocessing + + super().__init__( + nn_module=nn_module, + upsampler=upsampler, + postprocessing=postprocessing, + ) self.fc_layers = fc_layers def compute_map(self, x, class_idx=None, layer_idx=-1): @@ -262,7 +148,7 @@ class GradCAM: # densenet 2d from monai.networks.nets import densenet121 - from monai.visualize import GradCAM + from monai.visualise import GradCAM model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) cam = GradCAM(nn_module=model_2d, target_layers="class_layers.relu") @@ -270,7 +156,7 @@ class GradCAM: # resnet 2d from monai.networks.nets import se_resnet50 - from monai.visualize import GradCAM + from monai.visualise import GradCAM model_2d = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4) cam = GradCAM(nn_module=model_2d, target_layers="layer4") @@ -278,7 +164,7 @@ class GradCAM: See Also: - - :py:class:`monai.visualize.class_activation_maps.CAM` + - :py:class:`monai.visualise.class_activation_maps.CAM` """ @@ -357,7 +243,7 @@ class GradCAMpp(GradCAM): See Also: - - :py:class:`monai.visualize.class_activation_maps.GradCAM` + - :py:class:`monai.visualise.class_activation_maps.GradCAM` """ diff --git a/monai/visualize/img2tensorboard.py b/monai/visualise/img2tensorboard.py similarity index 100% rename from monai/visualize/img2tensorboard.py rename to monai/visualise/img2tensorboard.py diff --git a/monai/visualise/occlusion_sensitivity.py b/monai/visualise/occlusion_sensitivity.py new file mode 100644 index 0000000000..1ee4e9c089 --- /dev/null +++ b/monai/visualise/occlusion_sensitivity.py @@ -0,0 +1,294 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from functools import partial +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn + +from monai.visualise import NetVisualiser, default_normalizer, default_upsampler + +try: + from tqdm import trange + + trange = partial(trange, desc="Computing occlusion sensitivity") +except (ImportError, AttributeError): + trange = range + + +def _check_input_image(image): + """Check that the input image is as expected.""" + # Only accept batch size of 1 + if image.shape[0] > 1: + raise RuntimeError("Expected batch size of 1.") + return image + + +def _check_input_label(model, label, image): + """Check that the input label is as expected.""" + if label is None: + label = model(image).argmax(1) + # If necessary turn the label into a 1-element tensor + elif not isinstance(label, torch.Tensor): + label = torch.tensor([[label]], dtype=torch.int64).to(image.device) + # make sure there's only 1 element + if label.numel() != image.shape[0]: + raise RuntimeError("Expected as many labels as batches.") + return label + + +def _check_input_bounding_box(b_box, im_shape): + """Check that the bounding box (if supplied) is as expected.""" + # If no bounding box has been supplied, set min and max to None + if b_box is None: + b_box_min = b_box_max = None + + # Bounding box has been supplied + else: + # Should be twice as many elements in `b_box` as `im_shape` + if len(b_box) != 2 * len(im_shape): + raise ValueError("Bounding box should contain upper and lower for all dimensions (except batch number)") + + # If any min's or max's are -ve, set them to 0 and im_shape-1, respectively. + b_box_min = np.array(b_box[::2]) + b_box_max = np.array(b_box[1::2]) + b_box_min[b_box_min < 0] = 0 + b_box_max[b_box_max < 0] = im_shape[b_box_max < 0] - 1 + # Check all max's are < im_shape + if np.any(b_box_max >= im_shape): + raise ValueError("Max bounding box should be < image size for all values") + # Check all min's are <= max's + if np.any(b_box_min > b_box_max): + raise ValueError("Min bounding box should be <= max for all values") + + return b_box_min, b_box_max + + +def _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im): + """For given number of images, get probability of predicting + a given label. Append to previous evaluations.""" + batch_images = torch.cat(batch_images, dim=0) + batch_ids = torch.LongTensor(batch_ids).unsqueeze(1).to(sensitivity_im.device) + scores = model(batch_images).detach().gather(1, batch_ids) + return torch.cat((sensitivity_im, scores)) + + +class OcclusionSensitivity(NetVisualiser): + """ + This class computes the occlusion sensitivity for a model's prediction + of a given image. By occlusion sensitivity, we mean how the probability of a given + prediction changes as the occluded section of an image changes. This can + be useful to understand why a network is making certain decisions. + + The result is given as ``baseline`` (the probability of + a certain output) minus the probability of the output with the occluded + area. + + Therefore, higher values in the output image mean there was a + greater the drop in certainty, indicating the occluded region was more + important in the decision process. + + See: R. R. Selvaraju et al. Grad-CAM: Visual Explanations from Deep Networks via + Gradient-based Localization. https://doi.org/10.1109/ICCV.2017.74 + + Examples + + .. code-block:: python + + # densenet 2d + from monai.networks.nets import densenet121 + from monai.visualise import OcclusionSensitivity + + model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + occ_sens = OcclusionSensitivity(nn_module=model_2d) + result = occ_sens(x=torch.rand((1, 1, 48, 64)), class_idx=None, b_box=[-1, -1, 2, 40, 1, 62]) + + # densenet 3d + from monai.networks.nets import DenseNet + from monai.visualise import OcclusionSensitivity + + model_3d = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)) + occ_sens = OcclusionSensitivity(nn_module=model_3d, n_batch=10, stride=2) + result = occ_sens(torch.rand(1, 1, 6, 6, 6), class_idx=1, b_box=[-1, -1, 2, 3, -1, -1, -1, -1]) + + See Also: + + - :py:class:`monai.visualise.occlusion_sensitivity.OcclusionSensitivity.` + """ + + def __init__( + self, + nn_module: nn.Module, + pad_val: float = 0.0, + margin: Union[int, Sequence] = 2, + n_batch: int = 128, + stride: Union[int, Sequence] = 1, + upsampler: Callable = default_upsampler, + postprocessing: Callable = default_normalizer, + ) -> None: + """ + Args: + nn_module: classification model to use for inference + pad_val: when occluding part of the image, which values should we put + in the image? + margin: we'll create a cuboid/cube around the voxel to be occluded. if + ``margin==2``, then we'll create a cube that is +/- 2 voxels in + all directions (i.e., a cube of 5 x 5 x 5 voxels). A ``Sequence`` + can be supplied to have a margin of different sizes (i.e., create + a cuboid). + n_batch: number of images in a batch before inference. + b_box: Bounding box on which to perform the analysis. The output image + will also match in size. There should be a minimum and maximum for + all dimensions except batch: ``[min1, max1, min2, max2,...]``. + * By default, the whole image will be used. Decreasing the size will + speed the analysis up, which might be useful for larger images. + * Min and max are inclusive, so [0, 63, ...] will have size (64, ...). + * Use -ve to use 0 for min values and im.shape[x]-1 for xth dimension. + stride: Stride in spatial directions for performing occlusions. Can be single + value or sequence (for varying stride in the different directions). + Should be >= 1. Striding in the channel direction will always be 1. + upsampler: An upsampling method to upsample the output image. Default is + N dimensional linear (bilinear, trilinear, etc.) depending on num spatial + dimensions of input. + postprocessing: a callable that applies on the upsampled output image. + default is normalising between 0 and 1. + """ + + super().__init__( + nn_module=nn_module, + upsampler=upsampler, + postprocessing=postprocessing, + ) + + self.pad_val = pad_val + self.margin = margin + self.n_batch = n_batch + self.stride = stride + + def _compute_occlusion_sensitivity(self, x, class_idx, b_box): + + # Get bounding box + im_shape = np.array(x.shape[1:]) + b_box_min, b_box_max = _check_input_bounding_box(b_box, im_shape) + + # Get baseline probability + baseline = self.nn_module(x).detach()[0, class_idx].item() + + # Create some lists + batch_images = [] + batch_ids = [] + + sensitivity_im = torch.empty(0, dtype=torch.float32, device=x.device) + + # If no bounding box supplied, output shape is same as input shape. + # If bounding box is present, shape is max - min + 1 + output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1 + + # Calculate the downsampled shape + if not isinstance(self.stride, Sequence): + stride_np = np.full_like(im_shape, self.stride, dtype=np.int32) + stride_np[0] = 1 # always do stride 1 in channel dimension + else: + # Convert to numpy array and check dimensions match + stride_np = np.array(self.stride, dtype=np.int32) + if stride_np.size != im_shape - 1: # should be 1 less to get spatial dimensions + raise ValueError( + "If supplying stride as sequence, number of elements of stride should match number of spatial dimensions." + ) + + # Obviously if stride = 1, downsampled_im_shape == output_im_shape + downsampled_im_shape = np.floor(output_im_shape / stride_np).astype(np.int32) + downsampled_im_shape[downsampled_im_shape == 0] = 1 # make sure dimension sizes are >= 1 + num_required_predictions = np.prod(downsampled_im_shape) + + # Loop 1D over image + for i in trange(num_required_predictions): + # Get corresponding ND index + idx = np.unravel_index(i, downsampled_im_shape) + # Multiply by stride + idx *= stride_np + # If a bounding box is being used, we need to add on + # the min to shift to start of region of interest + if b_box_min is not None: + idx += b_box_min + + # Get min and max index of box to occlude + min_idx = [max(0, i - self.margin) for i in idx] + max_idx = [min(j, i + self.margin) for i, j in zip(idx, im_shape)] + + # Clone and replace target area with `pad_val` + occlu_im = x.clone() + occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = self.pad_val + + # Add to list + batch_images.append(occlu_im) + batch_ids.append(class_idx) + + # Once the batch is complete (or on last iteration) + if len(batch_images) == self.n_batch or i == num_required_predictions - 1: + # Do the predictions and append to sensitivity map + sensitivity_im = _append_to_sensitivity_im(self.nn_module, batch_images, batch_ids, sensitivity_im) + # Clear lists + batch_images = [] + batch_ids = [] + + # Subtract baseline from sensitivity so that +ve values mean more important in decision process + sensitivity_im = baseline - sensitivity_im + + # Reshape to match downsampled image, and unsqueeze to add batch dimension back in + sensitivity_im = sensitivity_im.reshape(tuple(downsampled_im_shape)).unsqueeze(0) + + return sensitivity_im, output_im_shape + + def __call__( # type: ignore + self, x: torch.Tensor, class_idx: Optional[Union[int, torch.Tensor]] = None, b_box: Optional[Sequence] = None + ): + """ + Args: + x: image to test. Should be tensor consisting of 1 batch, can be 2- or 3D. + class_idx: classification label to check for changes. This could be the true + label, or it could be the predicted label, etc. Use ``None`` to use generate + the predicted model. + b_box: Bounding box on which to perform the analysis. The output image + will also match in size. There should be a minimum and maximum for + all dimensions except batch: ``[min1, max1, min2, max2,...]``. + * By default, the whole image will be used. Decreasing the size will + speed the analysis up, which might be useful for larger images. + * Min and max are inclusive, so [0, 63, ...] will have size (64, ...). + * Use -ve to use 0 for min values and im.shape[x]-1 for xth dimension. + Returns: + Depends on the postprocessing, but the default return type is a Numpy array. + The returned image will occupy the same space as the input image, unless a + bounding box is supplied, in which case it will occupy that space. Unless + upsampling is disabled, the output image will have voxels of the same size + as the input image. + """ + + # Check input arguments + x = _check_input_image(x) + class_idx = _check_input_label(self.nn_module, class_idx, x) + + # Generate sensitivity image + sensitivity_im, output_im_shape = self._compute_occlusion_sensitivity(x, class_idx, b_box) + + # upsampling and postprocessing + if self.upsampler is not None: + if np.any(output_im_shape != x.shape[1:]): + img_spatial = tuple(output_im_shape[1:]) + sensitivity_im = self.upsampler(img_spatial)(sensitivity_im) + if self.postprocessing: + sensitivity_im = self.postprocessing(sensitivity_im) + + # Squeeze and return + return sensitivity_im diff --git a/monai/visualise/visualiser.py b/monai/visualise/visualiser.py new file mode 100644 index 0000000000..c5e46e4315 --- /dev/null +++ b/monai/visualise/visualiser.py @@ -0,0 +1,156 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from abc import ABC +from typing import Callable, Dict, Sequence, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from monai.transforms import ScaleIntensity +from monai.utils import InterpolateMode, ensure_tuple + +__all__ = ["default_upsampler", "default_normalizer", "ModelWithHooks", "NetVisualiser"] + + +def default_upsampler(spatial_size) -> Callable[[torch.Tensor], torch.Tensor]: + """ + A linear interpolation method for upsampling the feature map. + The output of this function is a callable `func`, + such that `func(x)` returns an upsampled tensor. + """ + + def up(x): + linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] + interp_mode = linear_mode[len(spatial_size) - 1] + return F.interpolate(x, size=spatial_size, mode=str(interp_mode.value), align_corners=False) + + return up + + +def default_normalizer(x) -> np.ndarray: + """ + A linear intensity scaling by mapping the (min, max) to (1, 0). + """ + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + scaler = ScaleIntensity(minv=1.0, maxv=0.0) + x = [scaler(x) for x in x] + return np.stack(x, axis=0) + + +class ModelWithHooks: + """ + A model wrapper to run model forward/backward steps and storing some intermediate feature/gradient information. + """ + + def __init__( + self, + nn_module, + target_layer_names: Union[str, Sequence[str]], + register_forward: bool = False, + register_backward: bool = False, + ): + """ + + Args: + nn_module: the model to be wrapped. + target_layer_names: the names of the layer to cache. + register_forward: whether to cache the forward pass output corresponding to `target_layer_names`. + register_backward: whether to cache the backward pass output corresponding to `target_layer_names`. + """ + self.model = nn_module + self.target_layers = ensure_tuple(target_layer_names) + + self.gradients: Dict[str, torch.Tensor] = {} + self.activations: Dict[str, torch.Tensor] = {} + self.score = None + self.class_idx = None + self.register_backward = register_backward + self.register_forward = register_forward + + _registered = [] + for name, mod in nn_module.named_modules(): + if name not in self.target_layers: + continue + _registered.append(name) + if self.register_backward: + mod.register_backward_hook(self.backward_hook(name)) + if self.register_forward: + mod.register_forward_hook(self.forward_hook(name)) + if len(_registered) != len(self.target_layers): + warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.") + + def backward_hook(self, name): + def _hook(_module, _grad_input, grad_output): + self.gradients[name] = grad_output[0] + + return _hook + + def forward_hook(self, name): + def _hook(_module, _input, output): + self.activations[name] = output + + return _hook + + def get_layer(self, layer_id: Union[str, Callable]): + """ + + Args: + layer_id: a layer name string or a callable. If it is a callable such as `lambda m: m.fc`, + this method will return the module `self.model.fc`. + + Returns: + a submodule from self.model. + """ + if callable(layer_id): + return layer_id(self.model) + if isinstance(layer_id, str): + for name, mod in self.model.named_modules(): + if name == layer_id: + return mod + raise NotImplementedError(f"Could not find {layer_id}.") + + def class_score(self, logits, class_idx=None): + if class_idx is not None: + return logits[:, class_idx].squeeze(), class_idx + class_idx = logits.max(1)[-1] + return logits[:, class_idx].squeeze(), class_idx + + def __call__(self, x, class_idx=None, retain_graph=False): + logits = self.model(x) + acti, grad = None, None + if self.register_forward: + acti = tuple(self.activations[layer] for layer in self.target_layers) + if self.register_backward: + score, class_idx = self.class_score(logits, class_idx) + self.model.zero_grad() + self.score, self.class_idx = score, class_idx + score.sum().backward(retain_graph=retain_graph) + grad = tuple(self.gradients[layer] for layer in self.target_layers) + return logits, acti, grad + + +class NetVisualiser(ABC): + def __init__( + self, + nn_module: torch.nn.Module, + upsampler: Callable, + postprocessing: Callable, + ) -> None: + self.nn_module = nn_module + self.upsampler = upsampler + self.postprocessing = postprocessing + + def __call__(self): + raise NotImplementedError() diff --git a/tests/min_tests.py b/tests/min_tests.py index ccfc789992..510e201b94 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -101,7 +101,7 @@ def run_testsuit(): "test_zoom", "test_zoom_affine", "test_zoomd", - "test_compute_occlusion_sensitivity", + "test_occlusion_sensitivity", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_img2tensorboard.py b/tests/test_img2tensorboard.py index 99761b4d11..5625038dba 100644 --- a/tests/test_img2tensorboard.py +++ b/tests/test_img2tensorboard.py @@ -15,7 +15,7 @@ import tensorboard import torch -from monai.visualize import make_animated_gif_summary +from monai.visualise import make_animated_gif_summary class TestImg2Tensorboard(unittest.TestCase): diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index 9de7dcf362..eb72024dfa 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -38,7 +38,7 @@ ToTensord, ) from monai.utils import set_determinism -from monai.visualize import plot_2d_or_3d_image +from monai.visualise import plot_2d_or_3d_image from tests.testing_data.integration_answers import test_integration_value from tests.utils import DistTestCase, TimedCall, skip_if_quick diff --git a/tests/test_compute_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py similarity index 58% rename from tests/test_compute_occlusion_sensitivity.py rename to tests/test_occlusion_sensitivity.py index 9f30162c47..3640920057 100644 --- a/tests/test_compute_occlusion_sensitivity.py +++ b/tests/test_occlusion_sensitivity.py @@ -14,8 +14,8 @@ import torch from parameterized import parameterized -from monai.metrics import compute_occlusion_sensitivity from monai.networks.nets import DenseNet, densenet121 +from monai.visualise import OcclusionSensitivity device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3).to(device) @@ -28,33 +28,56 @@ # 2D w/ bounding box TEST_CASE_0 = [ { - "model": model_2d, - "image": torch.rand(1, 1, 48, 64).to(device), - "label": torch.tensor([[0]], dtype=torch.int64).to(device), + "nn_module": model_2d, + }, + { + "x": torch.rand(1, 1, 48, 64).to(device), + "class_idx": torch.tensor([[0]], dtype=torch.int64).to(device), "b_box": [-1, -1, 2, 40, 1, 62], }, - (39, 62), + (1, 1, 39, 62), ] -# 3D w/ bounding box +# 3D w/ bounding box and stride TEST_CASE_1 = [ { - "model": model_3d, - "image": torch.rand(1, 1, 6, 6, 6).to(device), - "label": 0, - "b_box": [-1, -1, 2, 3, -1, -1, -1, -1], + "nn_module": model_3d, "n_batch": 10, "stride": 2, }, - (2, 6, 6), + { + "x": torch.rand(1, 1, 6, 6, 6).to(device), + "class_idx": None, + "b_box": [-1, -1, 2, 3, -1, -1, -1, -1], + }, + (1, 1, 2, 6, 6), +] + +TEST_CASE_FAIL = [ # 2D should fail, since 3 stride values given + { + "nn_module": model_2d, + "n_batch": 10, + "stride": (2, 2, 2), + }, + { + "x": torch.rand(1, 1, 48, 64).to(device), + "class_idx": None, + "b_box": [-1, -1, 2, 3, -1, -1], + }, ] class TestComputeOcclusionSensitivity(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) - def test_shape(self, input_data, expected_shape): - result = compute_occlusion_sensitivity(**input_data) + def test_shape(self, init_data, call_data, expected_shape): + occ_sens = OcclusionSensitivity(**init_data) + result = occ_sens(**call_data) self.assertTupleEqual(result.shape, expected_shape) + def test_fail(self): + occ_sens = OcclusionSensitivity(**TEST_CASE_FAIL[0]) + with self.assertRaises(ValueError): + occ_sens(**TEST_CASE_FAIL[1]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_plot_2d_or_3d_image.py b/tests/test_plot_2d_or_3d_image.py index a6ca41b4bc..9610405545 100644 --- a/tests/test_plot_2d_or_3d_image.py +++ b/tests/test_plot_2d_or_3d_image.py @@ -17,7 +17,7 @@ from parameterized import parameterized from torch.utils.tensorboard import SummaryWriter -from monai.visualize import plot_2d_or_3d_image +from monai.visualise import plot_2d_or_3d_image TEST_CASE_1 = [(1, 1, 10, 10)] diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py index e2ec119ec8..414a61e46d 100644 --- a/tests/test_vis_cam.py +++ b/tests/test_vis_cam.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.networks.nets import DenseNet, densenet121, se_resnet50 -from monai.visualize import CAM +from monai.visualise import CAM # 2D TEST_CASE_0 = [ diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index 3fb53b1fda..60f6d4f87b 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.networks.nets import DenseNet, densenet121, se_resnet50 -from monai.visualize import GradCAM +from monai.visualise import GradCAM # 2D TEST_CASE_0 = [ diff --git a/tests/test_vis_gradcampp.py b/tests/test_vis_gradcampp.py index c6bdef1647..5d404b2486 100644 --- a/tests/test_vis_gradcampp.py +++ b/tests/test_vis_gradcampp.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.networks.nets import DenseNet, densenet121, se_resnet50 -from monai.visualize import GradCAMpp +from monai.visualise import GradCAMpp # 2D TEST_CASE_0 = [ From 5082bdd45ac3ab304440be6839c4d38704650ace Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 17 Dec 2020 17:21:46 +0000 Subject: [PATCH 02/14] visualis* -> visualiz* Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- docs/source/conf.py | 2 +- docs/source/highlights.md | 10 +++---- docs/source/index.rst | 2 +- docs/source/visualise.rst | 20 -------------- docs/source/visualize.rst | 26 +++++++++++++++++++ monai/README.md | 2 +- monai/handlers/tensorboard_handlers.py | 4 +-- monai/{visualise => visualize}/__init__.py | 2 +- .../class_activation_maps.py | 26 +++++++++---------- .../img2tensorboard.py | 0 .../occlusion_sensitivity.py | 10 +++---- .../visualiser.py => visualize/visualizer.py} | 4 +-- tests/test_img2tensorboard.py | 2 +- tests/test_integration_segmentation_3d.py | 2 +- tests/test_occlusion_sensitivity.py | 2 +- tests/test_plot_2d_or_3d_image.py | 2 +- tests/test_vis_cam.py | 2 +- tests/test_vis_gradcam.py | 2 +- tests/test_vis_gradcampp.py | 2 +- 19 files changed, 64 insertions(+), 58 deletions(-) delete mode 100644 docs/source/visualise.rst create mode 100644 docs/source/visualize.rst rename monai/{visualise => visualize}/__init__.py (92%) rename monai/{visualise => visualize}/class_activation_maps.py (93%) rename monai/{visualise => visualize}/img2tensorboard.py (100%) rename monai/{visualise => visualize}/occlusion_sensitivity.py (97%) rename monai/{visualise/visualiser.py => visualize/visualizer.py} (99%) diff --git a/docs/source/conf.py b/docs/source/conf.py index 7c2527a569..534193c936 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -43,7 +43,7 @@ "config", "handlers", "losses", - "visualise", + "visualize", "utils", "inferers", "optimizers", diff --git a/docs/source/highlights.md b/docs/source/highlights.md index 0ea43ab639..f36358db12 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -21,7 +21,7 @@ The rest of this page provides more details for each module. * [Optimizers](#optimizers) * [Network architectures](#network-architectures) * [Evaluation](#evaluation) -* [Visualisation](#visualisation) +* [Vizualisation](#vizualisation) * [Result writing](#result-writing) * [Workflows](#workflows) * [Research](#research) @@ -111,7 +111,7 @@ MONAI also provides post-processing transforms for handling the model outputs. C - Removing segmentation noise based on Connected Component Analysis, as below figure (c). - Extracting contour of segmentation result, which can be used to map to original image and evaluate the model, as below figure (d) and (e). -After applying the post-processing transforms, it's easier to compute metrics, save model output into files or visualise data in the TensorBoard. [Post transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/post_transforms.ipynb) shows an example with several main post transforms. +After applying the post-processing transforms, it's easier to compute metrics, save model output into files or visualize data in the TensorBoard. [Post transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/post_transforms.ipynb) shows an example with several main post transforms. ![image](../images/post_transforms.png) ### 9. Integrate third-party transforms @@ -237,10 +237,10 @@ Various useful evaluation metrics have been used to measure the quality of medic For example, `Mean Dice` score can be used for segmentation tasks, and the area under the ROC curve(`ROCAUC`) for classification tasks. We continue to integrate more options. -## Visualisation -Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualise multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualising, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualisation is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). +## Visualization +Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by viualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). -And to visualise the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models: +And to visualize the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models: ![image](../images/cam.png) diff --git a/docs/source/index.rst b/docs/source/index.rst index d2b8c7970c..ea21428e6e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -61,7 +61,7 @@ Technical documentation is available at `docs.monai.io `_ engines inferers handlers - visualise + visualize utils .. toctree:: diff --git a/docs/source/visualise.rst b/docs/source/visualise.rst deleted file mode 100644 index d61c251710..0000000000 --- a/docs/source/visualise.rst +++ /dev/null @@ -1,20 +0,0 @@ -:github_url: https://github.com/Project-MONAI/MONAI - -.. _visualise: - -Visualisations -============== - -.. currentmodule:: monai.visualise - -Tensorboard visuals -------------------- - -.. automodule:: monai.visualise.img2tensorboard - :members: - -Class activation map --------------------- - -.. automodule:: monai.visualise.class_activation_maps - :members: \ No newline at end of file diff --git a/docs/source/visualize.rst b/docs/source/visualize.rst new file mode 100644 index 0000000000..5c2ac36a1a --- /dev/null +++ b/docs/source/visualize.rst @@ -0,0 +1,26 @@ +:github_url: https://github.com/Project-MONAI/MONAI + +.. _visualize: + +Visualizations +============== + +.. currentmodule:: monai.visualize + +Tensorboard visuals +------------------- + +.. automodule:: monai.visualize.img2tensorboard + :members: + +Class activation map +-------------------- + +.. automodule:: monai.visualize.class_activation_maps + :members: + +Occlusion sensitivity +--------------------- + +.. automodule:: monai.visualize.OcclusionSensitivity + :members: diff --git a/monai/README.md b/monai/README.md index 18714a8d77..89c1fa3653 100644 --- a/monai/README.md +++ b/monai/README.md @@ -27,4 +27,4 @@ * **utils**: generic utilities intended to be implemented in pure Python or using Numpy, and not with Pytorch, such as namespace aliasing, auto module loading. -* **visualise**: utilities for data visualisation. \ No newline at end of file +* **visualize**: utilities for data visualization. \ No newline at end of file diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index 5d567bd4d1..b9697d008f 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -16,7 +16,7 @@ import torch from monai.utils import exact_version, is_scalar, optional_import -from monai.visualise import plot_2d_or_3d_image +from monai.visualize import plot_2d_or_3d_image Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") if TYPE_CHECKING: @@ -174,7 +174,7 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No class TensorBoardImageHandler(object): """ - TensorBoardImageHandler is an Ignite Event handler that can visualise images, labels and outputs as 2D/3D images. + TensorBoardImageHandler is an Ignite Event handler that can visualize images, labels and outputs as 2D/3D images. 2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch, for 3D to ND output (shape in Batch, channel, H, W, D) input, each of ``self.max_channels`` number of images' last three dimensions will be shown as animated GIF along the last axis (typically Depth). diff --git a/monai/visualise/__init__.py b/monai/visualize/__init__.py similarity index 92% rename from monai/visualise/__init__.py rename to monai/visualize/__init__.py index 2b19a262fc..20b4574794 100644 --- a/monai/visualise/__init__.py +++ b/monai/visualize/__init__.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .visualiser import ModelWithHooks, NetVisualiser, default_normalizer, default_upsampler # isort:skip +from .visualizer import ModelWithHooks, NetVisualizer, default_normalizer, default_upsampler # isort:skip from .class_activation_maps import * from .img2tensorboard import * from .occlusion_sensitivity import OcclusionSensitivity diff --git a/monai/visualise/class_activation_maps.py b/monai/visualize/class_activation_maps.py similarity index 93% rename from monai/visualise/class_activation_maps.py rename to monai/visualize/class_activation_maps.py index 78ef15d510..727b12b660 100644 --- a/monai/visualise/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -15,12 +15,12 @@ import torch.nn as nn import torch.nn.functional as F -from monai.visualise import ModelWithHooks, NetVisualiser, default_normalizer, default_upsampler +from monai.visualize import ModelWithHooks, NetVisualizer, default_normalizer, default_upsampler __all__ = ["CAM", "GradCAM", "GradCAMpp"] -class CAM(NetVisualiser): +class CAM(NetVisualizer): """ Compute class activation map from the last fully-connected layers before the spatial pooling. @@ -30,7 +30,7 @@ class CAM(NetVisualiser): # densenet 2d from monai.networks.nets import densenet121 - from monai.visualise import CAM + from monai.visualize import CAM model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) cam = CAM(nn_module=model_2d, target_layers="class_layers.relu", fc_layers="class_layers.out") @@ -38,7 +38,7 @@ class CAM(NetVisualiser): # resnet 2d from monai.networks.nets import se_resnet50 - from monai.visualise import CAM + from monai.visualize import CAM model_2d = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4) cam = CAM(nn_module=model_2d, target_layers="layer4", fc_layers="last_linear") @@ -46,7 +46,7 @@ class CAM(NetVisualiser): See Also: - - :py:class:`monai.visualise.class_activation_maps.GradCAM` + - :py:class:`monai.visualize.class_activation_maps.GradCAM` """ @@ -60,7 +60,7 @@ def __init__( ) -> None: """ Args: - nn_module: the model to be visualised + nn_module: the model to be visualized target_layers: name of the model layer to generate the feature map. fc_layers: a string or a callable used to get fully-connected weights to compute activation map from the target_layers (without pooling). and evaluate it at every spatial location. @@ -117,7 +117,7 @@ def __call__(self, x, class_idx=None, layer_idx=-1): Args: x: input tensor, shape must be compatible with `nn_module`. - class_idx: index of the class to be visualised. Default to argmax(logits) + class_idx: index of the class to be visualized. Default to argmax(logits) layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. Returns: @@ -148,7 +148,7 @@ class GradCAM: # densenet 2d from monai.networks.nets import densenet121 - from monai.visualise import GradCAM + from monai.visualize import GradCAM model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) cam = GradCAM(nn_module=model_2d, target_layers="class_layers.relu") @@ -156,7 +156,7 @@ class GradCAM: # resnet 2d from monai.networks.nets import se_resnet50 - from monai.visualise import GradCAM + from monai.visualize import GradCAM model_2d = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4) cam = GradCAM(nn_module=model_2d, target_layers="layer4") @@ -164,7 +164,7 @@ class GradCAM: See Also: - - :py:class:`monai.visualise.class_activation_maps.CAM` + - :py:class:`monai.visualize.class_activation_maps.CAM` """ @@ -172,7 +172,7 @@ def __init__(self, nn_module, target_layers: str, upsampler=default_upsampler, p """ Args: - nn_module: the model to be used to generate the visualisations. + nn_module: the model to be used to generate the visualizations. target_layers: name of the model layer to generate the feature map. upsampler: an upsampling method to upsample the feature map. postprocessing: a callable that applies on the upsampled feature map. @@ -215,7 +215,7 @@ def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): Args: x: input tensor, shape must be compatible with `nn_module`. - class_idx: index of the class to be visualised. Default to argmax(logits) + class_idx: index of the class to be visualized. Default to argmax(logits) layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. retain_graph: whether to retain_graph for torch module backward call. @@ -243,7 +243,7 @@ class GradCAMpp(GradCAM): See Also: - - :py:class:`monai.visualise.class_activation_maps.GradCAM` + - :py:class:`monai.visualize.class_activation_maps.GradCAM` """ diff --git a/monai/visualise/img2tensorboard.py b/monai/visualize/img2tensorboard.py similarity index 100% rename from monai/visualise/img2tensorboard.py rename to monai/visualize/img2tensorboard.py diff --git a/monai/visualise/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py similarity index 97% rename from monai/visualise/occlusion_sensitivity.py rename to monai/visualize/occlusion_sensitivity.py index 1ee4e9c089..1ae2334d41 100644 --- a/monai/visualise/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn -from monai.visualise import NetVisualiser, default_normalizer, default_upsampler +from monai.visualize import NetVisualizer, default_normalizer, default_upsampler try: from tqdm import trange @@ -84,7 +84,7 @@ def _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im): return torch.cat((sensitivity_im, scores)) -class OcclusionSensitivity(NetVisualiser): +class OcclusionSensitivity(NetVisualizer): """ This class computes the occlusion sensitivity for a model's prediction of a given image. By occlusion sensitivity, we mean how the probability of a given @@ -108,7 +108,7 @@ class OcclusionSensitivity(NetVisualiser): # densenet 2d from monai.networks.nets import densenet121 - from monai.visualise import OcclusionSensitivity + from monai.visualize import OcclusionSensitivity model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) occ_sens = OcclusionSensitivity(nn_module=model_2d) @@ -116,7 +116,7 @@ class OcclusionSensitivity(NetVisualiser): # densenet 3d from monai.networks.nets import DenseNet - from monai.visualise import OcclusionSensitivity + from monai.visualize import OcclusionSensitivity model_3d = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)) occ_sens = OcclusionSensitivity(nn_module=model_3d, n_batch=10, stride=2) @@ -124,7 +124,7 @@ class OcclusionSensitivity(NetVisualiser): See Also: - - :py:class:`monai.visualise.occlusion_sensitivity.OcclusionSensitivity.` + - :py:class:`monai.visualize.occlusion_sensitivity.OcclusionSensitivity.` """ def __init__( diff --git a/monai/visualise/visualiser.py b/monai/visualize/visualizer.py similarity index 99% rename from monai/visualise/visualiser.py rename to monai/visualize/visualizer.py index c5e46e4315..62db5d26ee 100644 --- a/monai/visualise/visualiser.py +++ b/monai/visualize/visualizer.py @@ -20,7 +20,7 @@ from monai.transforms import ScaleIntensity from monai.utils import InterpolateMode, ensure_tuple -__all__ = ["default_upsampler", "default_normalizer", "ModelWithHooks", "NetVisualiser"] +__all__ = ["default_upsampler", "default_normalizer", "ModelWithHooks", "NetVisualizer"] def default_upsampler(spatial_size) -> Callable[[torch.Tensor], torch.Tensor]: @@ -141,7 +141,7 @@ def __call__(self, x, class_idx=None, retain_graph=False): return logits, acti, grad -class NetVisualiser(ABC): +class NetVisualizer(ABC): def __init__( self, nn_module: torch.nn.Module, diff --git a/tests/test_img2tensorboard.py b/tests/test_img2tensorboard.py index 5625038dba..99761b4d11 100644 --- a/tests/test_img2tensorboard.py +++ b/tests/test_img2tensorboard.py @@ -15,7 +15,7 @@ import tensorboard import torch -from monai.visualise import make_animated_gif_summary +from monai.visualize import make_animated_gif_summary class TestImg2Tensorboard(unittest.TestCase): diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index eb72024dfa..9de7dcf362 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -38,7 +38,7 @@ ToTensord, ) from monai.utils import set_determinism -from monai.visualise import plot_2d_or_3d_image +from monai.visualize import plot_2d_or_3d_image from tests.testing_data.integration_answers import test_integration_value from tests.utils import DistTestCase, TimedCall, skip_if_quick diff --git a/tests/test_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py index 3640920057..9f5dc44776 100644 --- a/tests/test_occlusion_sensitivity.py +++ b/tests/test_occlusion_sensitivity.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.networks.nets import DenseNet, densenet121 -from monai.visualise import OcclusionSensitivity +from monai.visualize import OcclusionSensitivity device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3).to(device) diff --git a/tests/test_plot_2d_or_3d_image.py b/tests/test_plot_2d_or_3d_image.py index 9610405545..a6ca41b4bc 100644 --- a/tests/test_plot_2d_or_3d_image.py +++ b/tests/test_plot_2d_or_3d_image.py @@ -17,7 +17,7 @@ from parameterized import parameterized from torch.utils.tensorboard import SummaryWriter -from monai.visualise import plot_2d_or_3d_image +from monai.visualize import plot_2d_or_3d_image TEST_CASE_1 = [(1, 1, 10, 10)] diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py index 414a61e46d..e2ec119ec8 100644 --- a/tests/test_vis_cam.py +++ b/tests/test_vis_cam.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.networks.nets import DenseNet, densenet121, se_resnet50 -from monai.visualise import CAM +from monai.visualize import CAM # 2D TEST_CASE_0 = [ diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index 60f6d4f87b..3fb53b1fda 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.networks.nets import DenseNet, densenet121, se_resnet50 -from monai.visualise import GradCAM +from monai.visualize import GradCAM # 2D TEST_CASE_0 = [ diff --git a/tests/test_vis_gradcampp.py b/tests/test_vis_gradcampp.py index 5d404b2486..c6bdef1647 100644 --- a/tests/test_vis_gradcampp.py +++ b/tests/test_vis_gradcampp.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.networks.nets import DenseNet, densenet121, se_resnet50 -from monai.visualise import GradCAMpp +from monai.visualize import GradCAMpp # 2D TEST_CASE_0 = [ From f73e737f7569f71fc42ef4357a70b598dae2ffbb Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 17 Dec 2020 17:23:25 +0000 Subject: [PATCH 03/14] typos Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- docs/source/highlights.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/highlights.md b/docs/source/highlights.md index f36358db12..41625aaac7 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -21,7 +21,7 @@ The rest of this page provides more details for each module. * [Optimizers](#optimizers) * [Network architectures](#network-architectures) * [Evaluation](#evaluation) -* [Vizualisation](#vizualisation) +* [Visualiz`ation](#visualization) * [Result writing](#result-writing) * [Workflows](#workflows) * [Research](#research) @@ -238,7 +238,7 @@ Various useful evaluation metrics have been used to measure the quality of medic For example, `Mean Dice` score can be used for segmentation tasks, and the area under the ROC curve(`ROCAUC`) for classification tasks. We continue to integrate more options. ## Visualization -Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by viualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). +Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). And to visualize the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models: From 20ec6fb9a68b52c09a86e5661765170237a4c62f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 17 Dec 2020 17:24:11 +0000 Subject: [PATCH 04/14] typo Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- docs/source/highlights.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/highlights.md b/docs/source/highlights.md index 41625aaac7..d8fe5c2ff9 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -21,7 +21,7 @@ The rest of this page provides more details for each module. * [Optimizers](#optimizers) * [Network architectures](#network-architectures) * [Evaluation](#evaluation) -* [Visualiz`ation](#visualization) +* [Visualization](#visualization) * [Result writing](#result-writing) * [Workflows](#workflows) * [Research](#research) From 77659ac0c521ca9938d789d0f83efa22581bd116 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 18 Dec 2020 10:51:39 +0000 Subject: [PATCH 05/14] fix documentation Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- docs/source/visualize.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/visualize.rst b/docs/source/visualize.rst index 5c2ac36a1a..850fd51770 100644 --- a/docs/source/visualize.rst +++ b/docs/source/visualize.rst @@ -22,5 +22,5 @@ Class activation map Occlusion sensitivity --------------------- -.. automodule:: monai.visualize.OcclusionSensitivity +.. automodule:: monai.visualize.occlusion_sensitivity :members: From b22fd401065c3a8c0fd4d86145880e2489be568f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 4 Jan 2021 11:18:40 +0000 Subject: [PATCH 06/14] update docstrings Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/visualize/occlusion_sensitivity.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index 1ae2334d41..15bb88bf85 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -137,31 +137,31 @@ def __init__( upsampler: Callable = default_upsampler, postprocessing: Callable = default_normalizer, ) -> None: - """ - Args: - nn_module: classification model to use for inference - pad_val: when occluding part of the image, which values should we put + """ Occlusion sensitivitiy constructor. + + :param nn_module: classification model to use for inference + :param pad_val: when occluding part of the image, which values should we put in the image? - margin: we'll create a cuboid/cube around the voxel to be occluded. if + :param margin: we'll create a cuboid/cube around the voxel to be occluded. if ``margin==2``, then we'll create a cube that is +/- 2 voxels in all directions (i.e., a cube of 5 x 5 x 5 voxels). A ``Sequence`` can be supplied to have a margin of different sizes (i.e., create a cuboid). - n_batch: number of images in a batch before inference. - b_box: Bounding box on which to perform the analysis. The output image + :param n_batch: number of images in a batch before inference. + :param b_box: Bounding box on which to perform the analysis. The output image will also match in size. There should be a minimum and maximum for all dimensions except batch: ``[min1, max1, min2, max2,...]``. * By default, the whole image will be used. Decreasing the size will speed the analysis up, which might be useful for larger images. * Min and max are inclusive, so [0, 63, ...] will have size (64, ...). * Use -ve to use 0 for min values and im.shape[x]-1 for xth dimension. - stride: Stride in spatial directions for performing occlusions. Can be single + :param stride: Stride in spatial directions for performing occlusions. Can be single value or sequence (for varying stride in the different directions). Should be >= 1. Striding in the channel direction will always be 1. - upsampler: An upsampling method to upsample the output image. Default is + :param upsampler: An upsampling method to upsample the output image. Default is N dimensional linear (bilinear, trilinear, etc.) depending on num spatial dimensions of input. - postprocessing: a callable that applies on the upsampled output image. + :param postprocessing: a callable that applies on the upsampled output image. default is normalising between 0 and 1. """ From 6ece488382eed32e50eeb4169316d41348deab86 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 4 Jan 2021 11:53:28 +0000 Subject: [PATCH 07/14] docstring change 2 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/visualize/occlusion_sensitivity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index 15bb88bf85..7cd132c3cf 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -137,7 +137,7 @@ def __init__( upsampler: Callable = default_upsampler, postprocessing: Callable = default_normalizer, ) -> None: - """ Occlusion sensitivitiy constructor. + """Occlusion sensitivitiy constructor. :param nn_module: classification model to use for inference :param pad_val: when occluding part of the image, which values should we put From 128b0a5dd6a9bde400f1203fc9cbfc372fc49a1b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 4 Jan 2021 14:49:01 +0000 Subject: [PATCH 08/14] use with_eval mode and reduce code duplication in cam classes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/visualize/class_activation_maps.py | 133 +++++++++++------------ monai/visualize/occlusion_sensitivity.py | 31 +++--- monai/visualize/visualizer.py | 11 +- 3 files changed, 87 insertions(+), 88 deletions(-) diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 727b12b660..fa13cf1f87 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -20,7 +20,57 @@ __all__ = ["CAM", "GradCAM", "GradCAMpp"] -class CAM(NetVisualizer): +class CAMBase(NetVisualizer): + """ + Base class for CAM methods. + """ + + def __init__( + self, + nn_module: nn.Module, + target_layers: str, + upsampler: Callable = default_upsampler, + postprocessing: Callable = default_normalizer, + register_backward: bool = True, + ) -> None: + # Convert to model with hooks if necessary + if not isinstance(nn_module, ModelWithHooks): + net = ModelWithHooks(nn_module, target_layers, register_forward=True, register_backward=register_backward) + else: + net = nn_module + + super().__init__( + nn_module=net, + upsampler=upsampler, + postprocessing=postprocessing, + ) + + def feature_map_size(self, input_size, device="cpu", layer_idx=-1): + """ + Computes the actual feature map size given `nn_module` and the target_layer name. + Args: + input_size: shape of the input tensor + device: the device used to initialise the input tensor + layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. + Returns: + shape of the actual feature map. + """ + return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape + + def _upsample_and_post_process(self, acti_map, x): + # upsampling and postprocessing + if self.upsampler: + img_spatial = x.shape[2:] + acti_map = self.upsampler(img_spatial)(acti_map) + if self.postprocessing: + acti_map = self.postprocessing(acti_map) + return acti_map + + def __call__(self): + raise NotImplementedError() + + +class CAM(CAMBase): """ Compute class activation map from the last fully-connected layers before the spatial pooling. @@ -70,15 +120,12 @@ def __init__( postprocessing: a callable that applies on the upsampled output image. default is normalising between 0 and 1. """ - if not isinstance(nn_module, ModelWithHooks): - self.net = ModelWithHooks(nn_module, target_layers, register_forward=True) - else: - self.net = nn_module - super().__init__( nn_module=nn_module, + target_layers=target_layers, upsampler=upsampler, postprocessing=postprocessing, + register_backward=False, ) self.fc_layers = fc_layers @@ -86,31 +133,17 @@ def compute_map(self, x, class_idx=None, layer_idx=-1): """ Compute the actual feature map with input tensor `x`. """ - logits, acti, _ = self.net(x) + logits, acti, _ = self.nn_module(x) acti = acti[layer_idx] if class_idx is None: class_idx = logits.max(1)[-1] b, c, *spatial = acti.shape acti = torch.split(acti.reshape(b, c, -1), 1, dim=2) # make the spatial dims 1D - fc_layers = self.net.get_layer(self.fc_layers) + fc_layers = self.nn_module.get_layer(self.fc_layers) output = torch.stack([fc_layers(a[..., 0]) for a in acti], dim=2) output = torch.stack([output[i, b : b + 1] for i, b in enumerate(class_idx)], dim=0) return output.reshape(b, 1, *spatial) # resume the spatial dims on the selected class - def feature_map_size(self, input_size, device="cpu", layer_idx=-1): - """ - Computes the actual feature map size given `nn_module` and the target_layer name. - - Args: - input_size: shape of the input tensor - device: the device used to initialise the input tensor - layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. - - Returns: - shape of the actual feature map. - """ - return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape - def __call__(self, x, class_idx=None, layer_idx=-1): """ Compute the activation map with upsampling and postprocessing. @@ -124,17 +157,10 @@ def __call__(self, x, class_idx=None, layer_idx=-1): activation maps """ acti_map = self.compute_map(x, class_idx, layer_idx) - - # upsampling and postprocessing - if self.upsampler: - img_spatial = x.shape[2:] - acti_map = self.upsampler(img_spatial)(acti_map) - if self.postprocessing: - acti_map = self.postprocessing(acti_map) - return acti_map + return self._upsample_and_post_process(acti_map, x) -class GradCAM: +class GradCAM(CAMBase): """ Computes Gradient-weighted Class Activation Mapping (Grad-CAM). This implementation is based on: @@ -168,47 +194,17 @@ class GradCAM: """ - def __init__(self, nn_module, target_layers: str, upsampler=default_upsampler, postprocessing=default_normalizer): - """ - - Args: - nn_module: the model to be used to generate the visualizations. - target_layers: name of the model layer to generate the feature map. - upsampler: an upsampling method to upsample the feature map. - postprocessing: a callable that applies on the upsampled feature map. - """ - if not isinstance(nn_module, ModelWithHooks): - self.net = ModelWithHooks(nn_module, target_layers, register_forward=True, register_backward=True) - else: - self.net = nn_module - self.upsampler = upsampler - self.postprocessing = postprocessing - def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1): """ Compute the actual feature map with input tensor `x`. """ - logits, acti, grad = self.net(x, class_idx=class_idx, retain_graph=retain_graph) + _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph) acti, grad = acti[layer_idx], grad[layer_idx] b, c, *spatial = grad.shape weights = grad.view(b, c, -1).mean(2).view(b, c, *[1] * len(spatial)) acti_map = (weights * acti).sum(1, keepdim=True) return F.relu(acti_map) - def feature_map_size(self, input_size, device="cpu", layer_idx=-1): - """ - Computes the actual feature map size given `nn_module` and the target_layer name. - - Args: - input_size: shape of the input tensor - device: the device used to initialise the input tensor - layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. - - Returns: - shape of the actual feature map. - """ - return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape - def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): """ Compute the activation map with upsampling and postprocessing. @@ -223,14 +219,7 @@ def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): activation maps """ acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx) - - # upsampling and postprocessing - if self.upsampler: - img_spatial = x.shape[2:] - acti_map = self.upsampler(img_spatial)(acti_map) - if self.postprocessing: - acti_map = self.postprocessing(acti_map) - return acti_map + return self._upsample_and_post_process(acti_map, x) class GradCAMpp(GradCAM): @@ -251,14 +240,14 @@ def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1): """ Compute the actual feature map with input tensor `x`. """ - logits, acti, grad = self.net(x, class_idx=class_idx, retain_graph=retain_graph) + _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph) acti, grad = acti[layer_idx], grad[layer_idx] b, c, *spatial = grad.shape alpha_nr = grad.pow(2) alpha_dr = alpha_nr.mul(2) + acti.mul(grad.pow(3)).view(b, c, -1).sum(-1).view(b, c, *[1] * len(spatial)) alpha_dr = torch.where(alpha_dr != 0.0, alpha_dr, torch.ones_like(alpha_dr)) alpha = alpha_nr.div(alpha_dr + 1e-7) - relu_grad = F.relu(self.net.score.exp() * grad) + relu_grad = F.relu(self.nn_module.score.exp() * grad) weights = (alpha * relu_grad).view(b, c, -1).sum(-1).view(b, c, *[1] * len(spatial)) acti_map = (weights * acti).sum(1, keepdim=True) return F.relu(acti_map) diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index 7cd132c3cf..a0835d86ee 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -17,6 +17,7 @@ import torch import torch.nn as nn +from monai.networks.utils import eval_mode from monai.visualize import NetVisualizer, default_normalizer, default_upsampler try: @@ -275,20 +276,22 @@ def __call__( # type: ignore as the input image. """ - # Check input arguments - x = _check_input_image(x) - class_idx = _check_input_label(self.nn_module, class_idx, x) + with eval_mode(self.nn_module): - # Generate sensitivity image - sensitivity_im, output_im_shape = self._compute_occlusion_sensitivity(x, class_idx, b_box) + # Check input arguments + x = _check_input_image(x) + class_idx = _check_input_label(self.nn_module, class_idx, x) - # upsampling and postprocessing - if self.upsampler is not None: - if np.any(output_im_shape != x.shape[1:]): - img_spatial = tuple(output_im_shape[1:]) - sensitivity_im = self.upsampler(img_spatial)(sensitivity_im) - if self.postprocessing: - sensitivity_im = self.postprocessing(sensitivity_im) + # Generate sensitivity image + sensitivity_im, output_im_shape = self._compute_occlusion_sensitivity(x, class_idx, b_box) - # Squeeze and return - return sensitivity_im + # upsampling and postprocessing + if self.upsampler is not None: + if np.any(output_im_shape != x.shape[1:]): + img_spatial = tuple(output_im_shape[1:]) + sensitivity_im = self.upsampler(img_spatial)(sensitivity_im) + if self.postprocessing: + sensitivity_im = self.postprocessing(sensitivity_im) + + # Squeeze and return + return sensitivity_im diff --git a/monai/visualize/visualizer.py b/monai/visualize/visualizer.py index 62db5d26ee..964f12b84d 100644 --- a/monai/visualize/visualizer.py +++ b/monai/visualize/visualizer.py @@ -11,12 +11,14 @@ import warnings from abc import ABC +from contextlib import nullcontext from typing import Callable, Dict, Sequence, Union import numpy as np import torch import torch.nn.functional as F +from monai.networks.utils import eval_mode from monai.transforms import ScaleIntensity from monai.utils import InterpolateMode, ensure_tuple @@ -128,7 +130,9 @@ def class_score(self, logits, class_idx=None): return logits[:, class_idx].squeeze(), class_idx def __call__(self, x, class_idx=None, retain_graph=False): - logits = self.model(x) + # Can only use eval mode if back grad isn't required + with eval_mode(self.model) if not self.register_backward else nullcontext(): + logits = self.model(x) acti, grad = None, None if self.register_forward: acti = tuple(self.activations[layer] for layer in self.target_layers) @@ -140,11 +144,14 @@ def __call__(self, x, class_idx=None, retain_graph=False): grad = tuple(self.gradients[layer] for layer in self.target_layers) return logits, acti, grad + def get_wrapped_net(self): + return self.model + class NetVisualizer(ABC): def __init__( self, - nn_module: torch.nn.Module, + nn_module: Union[torch.nn.Module, ModelWithHooks], upsampler: Callable, postprocessing: Callable, ) -> None: From 68b194b9f80a9486288f0704436423637343bacf Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 4 Jan 2021 16:53:04 +0000 Subject: [PATCH 09/14] add `train_mode` Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/utils.py | 36 +++++++++++++++++++++++++++++++++++ monai/visualize/visualizer.py | 26 ++++++++++++------------- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 3af177f1f8..231eefd626 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -31,6 +31,7 @@ "icnr_init", "pixelshuffle", "eval_mode", + "train_mode", ] @@ -276,3 +277,38 @@ def eval_mode(*nets: nn.Module): # Return required networks to training for n in training: n.train() + + +@contextmanager +def train_mode(*nets: nn.Module): + """ + Set network(s) to train mode and then return to original state at the end. + + Args: + nets: Input network(s) + + Examples + + .. code-block:: python + + t=torch.rand(1,1,16,16) + p=torch.nn.Conv2d(1,1,3) + p.eval() + print(p.training) # False + with train_mode(p): + print(p.training) # True + print(p(t).sum().backward()) # No exception + """ + + # Get original state of network(s) + eval = [n for n in nets if not n.training] + + try: + # set to train mode + with torch.set_grad_enabled(True): + yield [n.eval() for n in nets] + finally: + # Return required networks to eval + for n in eval: + n.eval() + diff --git a/monai/visualize/visualizer.py b/monai/visualize/visualizer.py index 964f12b84d..f9d270f396 100644 --- a/monai/visualize/visualizer.py +++ b/monai/visualize/visualizer.py @@ -11,14 +11,13 @@ import warnings from abc import ABC -from contextlib import nullcontext from typing import Callable, Dict, Sequence, Union import numpy as np import torch import torch.nn.functional as F -from monai.networks.utils import eval_mode +from monai.networks.utils import eval_mode, train_mode from monai.transforms import ScaleIntensity from monai.utils import InterpolateMode, ensure_tuple @@ -130,18 +129,19 @@ def class_score(self, logits, class_idx=None): return logits[:, class_idx].squeeze(), class_idx def __call__(self, x, class_idx=None, retain_graph=False): - # Can only use eval mode if back grad isn't required - with eval_mode(self.model) if not self.register_backward else nullcontext(): + # Use train_mode if grad is required, else eval_mode + mode = train_mode if self.register_backward else eval_mode + with mode(self.model): logits = self.model(x) - acti, grad = None, None - if self.register_forward: - acti = tuple(self.activations[layer] for layer in self.target_layers) - if self.register_backward: - score, class_idx = self.class_score(logits, class_idx) - self.model.zero_grad() - self.score, self.class_idx = score, class_idx - score.sum().backward(retain_graph=retain_graph) - grad = tuple(self.gradients[layer] for layer in self.target_layers) + acti, grad = None, None + if self.register_forward: + acti = tuple(self.activations[layer] for layer in self.target_layers) + if self.register_backward: + score, class_idx = self.class_score(logits, class_idx) + self.model.zero_grad() + self.score, self.class_idx = score, class_idx + score.sum().backward(retain_graph=retain_graph) + grad = tuple(self.gradients[layer] for layer in self.target_layers) return logits, acti, grad def get_wrapped_net(self): From 538936de57944c174a9de93b79f171904b438697 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 4 Jan 2021 17:35:43 +0000 Subject: [PATCH 10/14] small changes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_train_mode.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 tests/test_train_mode.py diff --git a/tests/test_train_mode.py b/tests/test_train_mode.py new file mode 100644 index 0000000000..b12f80f5c7 --- /dev/null +++ b/tests/test_train_mode.py @@ -0,0 +1,31 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.networks.utils import train_mode + + +class TestEvalMode(unittest.TestCase): + def test_eval_mode(self): + t = torch.rand(1, 1, 4, 4) + p = torch.nn.Conv2d(1, 1, 3) + p.eval() + self.assertFalse(p.training) # False + with train_mode(p): + self.assertFalse(p.training) # True + p(t).sum().backward() + + +if __name__ == "__main__": + unittest.main() From deef782c2134f12c2b63a171910807e6d903bff6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 4 Jan 2021 22:15:25 +0000 Subject: [PATCH 11/14] small changes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/utils.py | 2 +- monai/visualize/occlusion_sensitivity.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 231eefd626..9e19773334 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -306,7 +306,7 @@ def train_mode(*nets: nn.Module): try: # set to train mode with torch.set_grad_enabled(True): - yield [n.eval() for n in nets] + yield [n.train() for n in nets] finally: # Return required networks to eval for n in eval: diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index a0835d86ee..a34a4daaaa 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -33,7 +33,6 @@ def _check_input_image(image): # Only accept batch size of 1 if image.shape[0] > 1: raise RuntimeError("Expected batch size of 1.") - return image def _check_input_label(model, label, image): @@ -229,7 +228,7 @@ def _compute_occlusion_sensitivity(self, x, class_idx, b_box): max_idx = [min(j, i + self.margin) for i, j in zip(idx, im_shape)] # Clone and replace target area with `pad_val` - occlu_im = x.clone() + occlu_im = x.detach().clone() occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = self.pad_val # Add to list @@ -279,7 +278,7 @@ def __call__( # type: ignore with eval_mode(self.nn_module): # Check input arguments - x = _check_input_image(x) + _check_input_image(x) class_idx = _check_input_label(self.nn_module, class_idx, x) # Generate sensitivity image From 05668c38b8041149a71f0507a093ef08ca71d996 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 6 Jan 2021 14:05:54 +0000 Subject: [PATCH 12/14] occ sens add verbose option Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/visualize/occlusion_sensitivity.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index a34a4daaaa..8ac95a2577 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -136,6 +136,7 @@ def __init__( stride: Union[int, Sequence] = 1, upsampler: Callable = default_upsampler, postprocessing: Callable = default_normalizer, + verbose: bool = True, ) -> None: """Occlusion sensitivitiy constructor. @@ -163,6 +164,7 @@ def __init__( dimensions of input. :param postprocessing: a callable that applies on the upsampled output image. default is normalising between 0 and 1. + :param verbose: use ``tdqm.trange`` output (if available). """ super().__init__( @@ -175,6 +177,7 @@ def __init__( self.margin = margin self.n_batch = n_batch self.stride = stride + self.verbose = verbose def _compute_occlusion_sensitivity(self, x, class_idx, b_box): @@ -213,7 +216,8 @@ def _compute_occlusion_sensitivity(self, x, class_idx, b_box): num_required_predictions = np.prod(downsampled_im_shape) # Loop 1D over image - for i in trange(num_required_predictions): + verbose_range = trange if self.verbose else range + for i in verbose_range(num_required_predictions): # Get corresponding ND index idx = np.unravel_index(i, downsampled_im_shape) # Multiply by stride From a21eae8d3fc7d944ba600bdfe0c34c3b9df2c2cb Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 6 Jan 2021 14:34:43 +0000 Subject: [PATCH 13/14] remove NetVisualizer base class Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/utils.py | 1 - monai/visualize/__init__.py | 2 +- monai/visualize/class_activation_maps.py | 125 +++++++++++++++++++++-- monai/visualize/occlusion_sensitivity.py | 13 +-- monai/visualize/visualizer.py | 123 +--------------------- 5 files changed, 124 insertions(+), 140 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 9e19773334..61e859d602 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -311,4 +311,3 @@ def train_mode(*nets: nn.Module): # Return required networks to eval for n in eval: n.eval() - diff --git a/monai/visualize/__init__.py b/monai/visualize/__init__.py index 20b4574794..ea66d9dcf7 100644 --- a/monai/visualize/__init__.py +++ b/monai/visualize/__init__.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .visualizer import ModelWithHooks, NetVisualizer, default_normalizer, default_upsampler # isort:skip +from .visualizer import default_normalizer, default_upsampler # isort:skip from .class_activation_maps import * from .img2tensorboard import * from .occlusion_sensitivity import OcclusionSensitivity diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index fa13cf1f87..33808f28e8 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -9,18 +9,119 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Union +import warnings +from typing import Callable, Dict, Sequence, Union import torch import torch.nn as nn import torch.nn.functional as F -from monai.visualize import ModelWithHooks, NetVisualizer, default_normalizer, default_upsampler +from monai.networks.utils import eval_mode, train_mode +from monai.utils import ensure_tuple +from monai.visualize import default_normalizer, default_upsampler -__all__ = ["CAM", "GradCAM", "GradCAMpp"] +__all__ = ["CAM", "GradCAM", "GradCAMpp", "ModelWithHooks"] -class CAMBase(NetVisualizer): +class ModelWithHooks: + """ + A model wrapper to run model forward/backward steps and storing some intermediate feature/gradient information. + """ + + def __init__( + self, + nn_module, + target_layer_names: Union[str, Sequence[str]], + register_forward: bool = False, + register_backward: bool = False, + ): + """ + + Args: + nn_module: the model to be wrapped. + target_layer_names: the names of the layer to cache. + register_forward: whether to cache the forward pass output corresponding to `target_layer_names`. + register_backward: whether to cache the backward pass output corresponding to `target_layer_names`. + """ + self.model = nn_module + self.target_layers = ensure_tuple(target_layer_names) + + self.gradients: Dict[str, torch.Tensor] = {} + self.activations: Dict[str, torch.Tensor] = {} + self.score = None + self.class_idx = None + self.register_backward = register_backward + self.register_forward = register_forward + + _registered = [] + for name, mod in nn_module.named_modules(): + if name not in self.target_layers: + continue + _registered.append(name) + if self.register_backward: + mod.register_backward_hook(self.backward_hook(name)) + if self.register_forward: + mod.register_forward_hook(self.forward_hook(name)) + if len(_registered) != len(self.target_layers): + warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.") + + def backward_hook(self, name): + def _hook(_module, _grad_input, grad_output): + self.gradients[name] = grad_output[0] + + return _hook + + def forward_hook(self, name): + def _hook(_module, _input, output): + self.activations[name] = output + + return _hook + + def get_layer(self, layer_id: Union[str, Callable]): + """ + + Args: + layer_id: a layer name string or a callable. If it is a callable such as `lambda m: m.fc`, + this method will return the module `self.model.fc`. + + Returns: + a submodule from self.model. + """ + if callable(layer_id): + return layer_id(self.model) + if isinstance(layer_id, str): + for name, mod in self.model.named_modules(): + if name == layer_id: + return mod + raise NotImplementedError(f"Could not find {layer_id}.") + + def class_score(self, logits, class_idx=None): + if class_idx is not None: + return logits[:, class_idx].squeeze(), class_idx + class_idx = logits.max(1)[-1] + return logits[:, class_idx].squeeze(), class_idx + + def __call__(self, x, class_idx=None, retain_graph=False): + # Use train_mode if grad is required, else eval_mode + mode = train_mode if self.register_backward else eval_mode + with mode(self.model): + logits = self.model(x) + acti, grad = None, None + if self.register_forward: + acti = tuple(self.activations[layer] for layer in self.target_layers) + if self.register_backward: + score, class_idx = self.class_score(logits, class_idx) + self.model.zero_grad() + self.score, self.class_idx = score, class_idx + score.sum().backward(retain_graph=retain_graph) + grad = tuple(self.gradients[layer] for layer in self.target_layers) + return logits, acti, grad + + def get_wrapped_net(self): + return self.model + + +class CAMBase: """ Base class for CAM methods. """ @@ -35,15 +136,14 @@ def __init__( ) -> None: # Convert to model with hooks if necessary if not isinstance(nn_module, ModelWithHooks): - net = ModelWithHooks(nn_module, target_layers, register_forward=True, register_backward=register_backward) + self.nn_module = ModelWithHooks( + nn_module, target_layers, register_forward=True, register_backward=register_backward + ) else: - net = nn_module + self.nn_module = nn_module - super().__init__( - nn_module=net, - upsampler=upsampler, - postprocessing=postprocessing, - ) + self.upsampler = upsampler + self.postprocessing = postprocessing def feature_map_size(self, input_size, device="cpu", layer_idx=-1): """ @@ -57,6 +157,9 @@ def feature_map_size(self, input_size, device="cpu", layer_idx=-1): """ return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape + def compute_map(self, x, class_idx=None, layer_idx=-1): + raise NotImplementedError() + def _upsample_and_post_process(self, acti_map, x): # upsampling and postprocessing if self.upsampler: diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py index 8ac95a2577..00935f1aaa 100644 --- a/monai/visualize/occlusion_sensitivity.py +++ b/monai/visualize/occlusion_sensitivity.py @@ -18,7 +18,7 @@ import torch.nn as nn from monai.networks.utils import eval_mode -from monai.visualize import NetVisualizer, default_normalizer, default_upsampler +from monai.visualize import default_normalizer, default_upsampler try: from tqdm import trange @@ -84,7 +84,7 @@ def _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im): return torch.cat((sensitivity_im, scores)) -class OcclusionSensitivity(NetVisualizer): +class OcclusionSensitivity: """ This class computes the occlusion sensitivity for a model's prediction of a given image. By occlusion sensitivity, we mean how the probability of a given @@ -167,12 +167,9 @@ def __init__( :param verbose: use ``tdqm.trange`` output (if available). """ - super().__init__( - nn_module=nn_module, - upsampler=upsampler, - postprocessing=postprocessing, - ) - + self.nn_module = nn_module + self.upsampler = upsampler + self.postprocessing = postprocessing self.pad_val = pad_val self.margin = margin self.n_batch = n_batch diff --git a/monai/visualize/visualizer.py b/monai/visualize/visualizer.py index f9d270f396..9a56e0781d 100644 --- a/monai/visualize/visualizer.py +++ b/monai/visualize/visualizer.py @@ -9,19 +9,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings -from abc import ABC -from typing import Callable, Dict, Sequence, Union + +from typing import Callable import numpy as np import torch import torch.nn.functional as F -from monai.networks.utils import eval_mode, train_mode from monai.transforms import ScaleIntensity -from monai.utils import InterpolateMode, ensure_tuple +from monai.utils import InterpolateMode -__all__ = ["default_upsampler", "default_normalizer", "ModelWithHooks", "NetVisualizer"] +__all__ = ["default_upsampler", "default_normalizer"] def default_upsampler(spatial_size) -> Callable[[torch.Tensor], torch.Tensor]: @@ -48,116 +46,3 @@ def default_normalizer(x) -> np.ndarray: scaler = ScaleIntensity(minv=1.0, maxv=0.0) x = [scaler(x) for x in x] return np.stack(x, axis=0) - - -class ModelWithHooks: - """ - A model wrapper to run model forward/backward steps and storing some intermediate feature/gradient information. - """ - - def __init__( - self, - nn_module, - target_layer_names: Union[str, Sequence[str]], - register_forward: bool = False, - register_backward: bool = False, - ): - """ - - Args: - nn_module: the model to be wrapped. - target_layer_names: the names of the layer to cache. - register_forward: whether to cache the forward pass output corresponding to `target_layer_names`. - register_backward: whether to cache the backward pass output corresponding to `target_layer_names`. - """ - self.model = nn_module - self.target_layers = ensure_tuple(target_layer_names) - - self.gradients: Dict[str, torch.Tensor] = {} - self.activations: Dict[str, torch.Tensor] = {} - self.score = None - self.class_idx = None - self.register_backward = register_backward - self.register_forward = register_forward - - _registered = [] - for name, mod in nn_module.named_modules(): - if name not in self.target_layers: - continue - _registered.append(name) - if self.register_backward: - mod.register_backward_hook(self.backward_hook(name)) - if self.register_forward: - mod.register_forward_hook(self.forward_hook(name)) - if len(_registered) != len(self.target_layers): - warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.") - - def backward_hook(self, name): - def _hook(_module, _grad_input, grad_output): - self.gradients[name] = grad_output[0] - - return _hook - - def forward_hook(self, name): - def _hook(_module, _input, output): - self.activations[name] = output - - return _hook - - def get_layer(self, layer_id: Union[str, Callable]): - """ - - Args: - layer_id: a layer name string or a callable. If it is a callable such as `lambda m: m.fc`, - this method will return the module `self.model.fc`. - - Returns: - a submodule from self.model. - """ - if callable(layer_id): - return layer_id(self.model) - if isinstance(layer_id, str): - for name, mod in self.model.named_modules(): - if name == layer_id: - return mod - raise NotImplementedError(f"Could not find {layer_id}.") - - def class_score(self, logits, class_idx=None): - if class_idx is not None: - return logits[:, class_idx].squeeze(), class_idx - class_idx = logits.max(1)[-1] - return logits[:, class_idx].squeeze(), class_idx - - def __call__(self, x, class_idx=None, retain_graph=False): - # Use train_mode if grad is required, else eval_mode - mode = train_mode if self.register_backward else eval_mode - with mode(self.model): - logits = self.model(x) - acti, grad = None, None - if self.register_forward: - acti = tuple(self.activations[layer] for layer in self.target_layers) - if self.register_backward: - score, class_idx = self.class_score(logits, class_idx) - self.model.zero_grad() - self.score, self.class_idx = score, class_idx - score.sum().backward(retain_graph=retain_graph) - grad = tuple(self.gradients[layer] for layer in self.target_layers) - return logits, acti, grad - - def get_wrapped_net(self): - return self.model - - -class NetVisualizer(ABC): - def __init__( - self, - nn_module: Union[torch.nn.Module, ModelWithHooks], - upsampler: Callable, - postprocessing: Callable, - ) -> None: - self.nn_module = nn_module - self.upsampler = upsampler - self.postprocessing = postprocessing - - def __call__(self): - raise NotImplementedError() From 00748f9d92d865c59e05275299383fd57f3b0ad6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 6 Jan 2021 16:11:35 +0000 Subject: [PATCH 14/14] fix test_train_mode.oy Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_train_mode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_train_mode.py b/tests/test_train_mode.py index b12f80f5c7..2ed48bcb15 100644 --- a/tests/test_train_mode.py +++ b/tests/test_train_mode.py @@ -23,7 +23,7 @@ def test_eval_mode(self): p.eval() self.assertFalse(p.training) # False with train_mode(p): - self.assertFalse(p.training) # True + self.assertTrue(p.training) # True p(t).sum().backward()