diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 0bcfbd4240..32a3faf380 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -37,7 +37,3 @@ Metrics .. autoclass:: SurfaceDistanceMetric :members: - -`Occlusion sensitivity` ------------------------ -.. autofunction:: compute_occlusion_sensitivity \ No newline at end of file diff --git a/docs/source/visualize.rst b/docs/source/visualize.rst index 9668d48114..850fd51770 100644 --- a/docs/source/visualize.rst +++ b/docs/source/visualize.rst @@ -17,4 +17,10 @@ Class activation map -------------------- .. automodule:: monai.visualize.class_activation_maps - :members: \ No newline at end of file + :members: + +Occlusion sensitivity +--------------------- + +.. automodule:: monai.visualize.occlusion_sensitivity + :members: diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index a9d7d661ec..b9697d008f 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -174,7 +174,7 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No class TensorBoardImageHandler(object): """ - TensorBoardImageHandler is an Ignite Event handler that can visualise images, labels and outputs as 2D/3D images. + TensorBoardImageHandler is an Ignite Event handler that can visualize images, labels and outputs as 2D/3D images. 2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch, for 3D to ND output (shape in Batch, channel, H, W, D) input, each of ``self.max_channels`` number of images' last three dimensions will be shown as animated GIF along the last axis (typically Depth). diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index a0d626f45b..f43fb444db 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -12,7 +12,6 @@ from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix from .hausdorff_distance import * from .meandice import DiceMetric, compute_meandice -from .occlusion_sensitivity import compute_occlusion_sensitivity from .rocauc import compute_roc_auc from .surface_distance import SurfaceDistanceMetric, compute_average_surface_distance from .utils import * diff --git a/monai/metrics/occlusion_sensitivity.py b/monai/metrics/occlusion_sensitivity.py deleted file mode 100644 index 9879f472a9..0000000000 --- a/monai/metrics/occlusion_sensitivity.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright 2020 MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections.abc import Sequence -from functools import partial -from typing import Optional, Union - -import numpy as np -import torch -import torch.nn as nn - -try: - from tqdm import trange - - trange = partial(trange, desc="Computing occlusion sensitivity") -except (ImportError, AttributeError): - trange = range - - -def _check_input_image(image): - """Check that the input image is as expected.""" - # Only accept batch size of 1 - if image.shape[0] > 1: - raise RuntimeError("Expected batch size of 1.") - return image - - -def _check_input_label(label, image): - """Check that the input label is as expected.""" - # If necessary turn the label into a 1-element tensor - if isinstance(label, int): - label = torch.tensor([[label]], dtype=torch.int64).to(image.device) - # If the label is a tensor, make sure there's only 1 element - elif label.numel() != image.shape[0]: - raise RuntimeError("Expected as many labels as batches.") - return label - - -def _check_input_bounding_box(b_box, im_shape): - """Check that the bounding box (if supplied) is as expected.""" - # If no bounding box has been supplied, set min and max to None - if b_box is None: - b_box_min = b_box_max = None - - # Bounding box has been supplied - else: - # Should be twice as many elements in `b_box` as `im_shape` - if len(b_box) != 2 * len(im_shape): - raise ValueError("Bounding box should contain upper and lower for all dimensions (except batch number)") - - # If any min's or max's are -ve, set them to 0 and im_shape-1, respectively. - b_box_min = np.array(b_box[::2]) - b_box_max = np.array(b_box[1::2]) - b_box_min[b_box_min < 0] = 0 - b_box_max[b_box_max < 0] = im_shape[b_box_max < 0] - 1 - # Check all max's are < im_shape - if np.any(b_box_max >= im_shape): - raise ValueError("Max bounding box should be < image size for all values") - # Check all min's are <= max's - if np.any(b_box_min > b_box_max): - raise ValueError("Min bounding box should be <= max for all values") - - return b_box_min, b_box_max - - -def _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im): - """For given number of images, get probability of predicting - a given label. Append to previous evaluations.""" - batch_images = torch.cat(batch_images, dim=0) - batch_ids = torch.LongTensor(batch_ids).unsqueeze(1).to(sensitivity_im.device) - scores = model(batch_images).detach().gather(1, batch_ids) - return torch.cat((sensitivity_im, scores)) - - -def compute_occlusion_sensitivity( - model: nn.Module, - image: torch.Tensor, - label: Union[int, torch.Tensor], - pad_val: float = 0.0, - margin: Union[int, Sequence] = 2, - n_batch: int = 128, - b_box: Optional[Sequence] = None, - stride: Union[int, Sequence] = 1, - upsample_mode: str = "nearest", -) -> np.ndarray: - """ - This function computes the occlusion sensitivity for a model's prediction - of a given image. By occlusion sensitivity, we mean how the probability of a given - prediction changes as the occluded section of an image changes. This can - be useful to understand why a network is making certain decisions. - - The result is given as ``baseline`` (the probability of - a certain output) minus the probability of the output with the occluded - area. - - Therefore, higher values in the output image mean there was a - greater the drop in certainty, indicating the occluded region was more - important in the decision process. - - See: R. R. Selvaraju et al. Grad-CAM: Visual Explanations from Deep Networks via - Gradient-based Localization. https://doi.org/10.1109/ICCV.2017.74 - - Args: - model: classification model to use for inference - image: image to test. Should be tensor consisting of 1 batch, can be 2- or 3D. - label: classification label to check for changes (normally the true - label, but doesn't have to be) - pad_val: when occluding part of the image, which values should we put - in the image? - margin: we'll create a cuboid/cube around the voxel to be occluded. if - ``margin==2``, then we'll create a cube that is +/- 2 voxels in - all directions (i.e., a cube of 5 x 5 x 5 voxels). A ``Sequence`` - can be supplied to have a margin of different sizes (i.e., create - a cuboid). - n_batch: number of images in a batch before inference. - b_box: Bounding box on which to perform the analysis. The output image - will also match in size. There should be a minimum and maximum for - all dimensions except batch: ``[min1, max1, min2, max2,...]``. - * By default, the whole image will be used. Decreasing the size will - speed the analysis up, which might be useful for larger images. - * Min and max are inclusive, so [0, 63, ...] will have size (64, ...). - * Use -ve to use 0 for min values and im.shape[x]-1 for xth dimension. - stride: Stride for performing occlusions. Can be single value or sequence - (for varying stride in the different directions). Should be >= 1. - upsample_mode: If stride != 1 is used, we'll upsample such that the size - of the voxels in the output image match the input. Upsampling is done with - ``torch.nn.Upsample``, and mode can be set to: - * ``nearest``, ``linear``, ``bilinear``, ``bicubic`` and ``trilinear`` - * default is ``nearest``. - Returns: - Numpy array. If no bounding box is supplied, this will be the same size - as the input image. If a bounding box is used, the output image will be - cropped to this size. - """ - - # Check input arguments - image = _check_input_image(image) - label = _check_input_label(label, image) - im_shape = np.array(image.shape[1:]) - b_box_min, b_box_max = _check_input_bounding_box(b_box, im_shape) - - # Get baseline probability - baseline = model(image).detach()[0, label].item() - - # Create some lists - batch_images = [] - batch_ids = [] - - sensitivity_im = torch.empty(0, dtype=torch.float32, device=image.device) - - # If no bounding box supplied, output shape is same as input shape. - # If bounding box is present, shape is max - min + 1 - output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1 - - # Calculate the downsampled shape - if not isinstance(stride, Sequence): - stride_np = np.full_like(im_shape, stride, dtype=np.int32) - stride_np[0] = 1 # always do stride 1 in channel dimension - else: - # Convert to numpy array and check dimensions match - stride_np = np.array(stride, dtype=np.int32) - if stride_np.size != im_shape.size: - raise ValueError("Sizes of image shape and stride should match.") - - # Obviously if stride = 1, downsampled_im_shape == output_im_shape - downsampled_im_shape = np.floor(output_im_shape / stride_np).astype(np.int32) - downsampled_im_shape[downsampled_im_shape == 0] = 1 # make sure dimension sizes are >= 1 - num_required_predictions = np.prod(downsampled_im_shape) - - # Loop 1D over image - for i in trange(num_required_predictions): - # Get corresponding ND index - idx = np.unravel_index(i, downsampled_im_shape) - # Multiply by stride - idx *= stride_np - # If a bounding box is being used, we need to add on - # the min to shift to start of region of interest - if b_box_min is not None: - idx += b_box_min - - # Get min and max index of box to occlude - min_idx = [max(0, i - margin) for i in idx] - max_idx = [min(j, i + margin) for i, j in zip(idx, im_shape)] - - # Clone and replace target area with `pad_val` - occlu_im = image.clone() - occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = pad_val - - # Add to list - batch_images.append(occlu_im) - batch_ids.append(label) - - # Once the batch is complete (or on last iteration) - if len(batch_images) == n_batch or i == num_required_predictions - 1: - # Do the predictions and append to sensitivity map - sensitivity_im = _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im) - # Clear lists - batch_images = [] - batch_ids = [] - - # Subtract from baseline - sensitivity_im = baseline - sensitivity_im - - # Reshape to match downsampled image - sensitivity_im = sensitivity_im.reshape(tuple(downsampled_im_shape)) - - # If necessary, upsample - if np.any(stride_np != 1): - output_im_shape = tuple(output_im_shape[1:]) # needs to be given as 3D tuple - upsampler = nn.Upsample(size=output_im_shape, mode=upsample_mode) - sensitivity_im = upsampler(sensitivity_im.unsqueeze(0)) - - # Convert tensor to numpy - sensitivity_im = sensitivity_im.cpu().numpy() - - # Squeeze and return - return np.squeeze(sensitivity_im) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 3af177f1f8..61e859d602 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -31,6 +31,7 @@ "icnr_init", "pixelshuffle", "eval_mode", + "train_mode", ] @@ -276,3 +277,37 @@ def eval_mode(*nets: nn.Module): # Return required networks to training for n in training: n.train() + + +@contextmanager +def train_mode(*nets: nn.Module): + """ + Set network(s) to train mode and then return to original state at the end. + + Args: + nets: Input network(s) + + Examples + + .. code-block:: python + + t=torch.rand(1,1,16,16) + p=torch.nn.Conv2d(1,1,3) + p.eval() + print(p.training) # False + with train_mode(p): + print(p.training) # True + print(p(t).sum().backward()) # No exception + """ + + # Get original state of network(s) + eval = [n for n in nets if not n.training] + + try: + # set to train mode + with torch.set_grad_enabled(True): + yield [n.train() for n in nets] + finally: + # Return required networks to eval + for n in eval: + n.eval() diff --git a/monai/visualize/__init__.py b/monai/visualize/__init__.py index 2fbd1dcf66..ea66d9dcf7 100644 --- a/monai/visualize/__init__.py +++ b/monai/visualize/__init__.py @@ -9,5 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .visualizer import default_normalizer, default_upsampler # isort:skip from .class_activation_maps import * from .img2tensorboard import * +from .occlusion_sensitivity import OcclusionSensitivity diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 6fd29d1c96..33808f28e8 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -12,14 +12,15 @@ import warnings from typing import Callable, Dict, Sequence, Union -import numpy as np import torch +import torch.nn as nn import torch.nn.functional as F -from monai.transforms import ScaleIntensity -from monai.utils import InterpolateMode, ensure_tuple +from monai.networks.utils import eval_mode, train_mode +from monai.utils import ensure_tuple +from monai.visualize import default_normalizer, default_upsampler -__all__ = ["ModelWithHooks", "default_upsampler", "default_normalizer", "CAM", "GradCAM", "GradCAMpp"] +__all__ = ["CAM", "GradCAM", "GradCAMpp", "ModelWithHooks"] class ModelWithHooks: @@ -101,46 +102,78 @@ def class_score(self, logits, class_idx=None): return logits[:, class_idx].squeeze(), class_idx def __call__(self, x, class_idx=None, retain_graph=False): - logits = self.model(x) - acti, grad = None, None - if self.register_forward: - acti = tuple(self.activations[layer] for layer in self.target_layers) - if self.register_backward: - score, class_idx = self.class_score(logits, class_idx) - self.model.zero_grad() - self.score, self.class_idx = score, class_idx - score.sum().backward(retain_graph=retain_graph) - grad = tuple(self.gradients[layer] for layer in self.target_layers) + # Use train_mode if grad is required, else eval_mode + mode = train_mode if self.register_backward else eval_mode + with mode(self.model): + logits = self.model(x) + acti, grad = None, None + if self.register_forward: + acti = tuple(self.activations[layer] for layer in self.target_layers) + if self.register_backward: + score, class_idx = self.class_score(logits, class_idx) + self.model.zero_grad() + self.score, self.class_idx = score, class_idx + score.sum().backward(retain_graph=retain_graph) + grad = tuple(self.gradients[layer] for layer in self.target_layers) return logits, acti, grad + def get_wrapped_net(self): + return self.model + -def default_upsampler(spatial_size) -> Callable[[torch.Tensor], torch.Tensor]: +class CAMBase: """ - A linear interpolation method for upsampling the feature map. - The output of this function is a callable `func`, - such that `func(activation_map)` returns an upsampled tensor. + Base class for CAM methods. """ - def up(acti_map): - linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] - interp_mode = linear_mode[len(spatial_size) - 1] - return F.interpolate(acti_map, size=spatial_size, mode=str(interp_mode.value), align_corners=False) + def __init__( + self, + nn_module: nn.Module, + target_layers: str, + upsampler: Callable = default_upsampler, + postprocessing: Callable = default_normalizer, + register_backward: bool = True, + ) -> None: + # Convert to model with hooks if necessary + if not isinstance(nn_module, ModelWithHooks): + self.nn_module = ModelWithHooks( + nn_module, target_layers, register_forward=True, register_backward=register_backward + ) + else: + self.nn_module = nn_module - return up + self.upsampler = upsampler + self.postprocessing = postprocessing + def feature_map_size(self, input_size, device="cpu", layer_idx=-1): + """ + Computes the actual feature map size given `nn_module` and the target_layer name. + Args: + input_size: shape of the input tensor + device: the device used to initialise the input tensor + layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. + Returns: + shape of the actual feature map. + """ + return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape -def default_normalizer(acti_map) -> np.ndarray: - """ - A linear intensity scaling by mapping the (min, max) to (1, 0). - """ - if isinstance(acti_map, torch.Tensor): - acti_map = acti_map.detach().cpu().numpy() - scaler = ScaleIntensity(minv=1.0, maxv=0.0) - acti_map = [scaler(x) for x in acti_map] - return np.stack(acti_map, axis=0) + def compute_map(self, x, class_idx=None, layer_idx=-1): + raise NotImplementedError() + + def _upsample_and_post_process(self, acti_map, x): + # upsampling and postprocessing + if self.upsampler: + img_spatial = x.shape[2:] + acti_map = self.upsampler(img_spatial)(acti_map) + if self.postprocessing: + acti_map = self.postprocessing(acti_map) + return acti_map + def __call__(self): + raise NotImplementedError() -class CAM: + +class CAM(CAMBase): """ Compute class activation map from the last fully-connected layers before the spatial pooling. @@ -172,83 +205,65 @@ class CAM: def __init__( self, - nn_module, + nn_module: nn.Module, target_layers: str, fc_layers: Union[str, Callable] = "fc", - upsampler=default_upsampler, + upsampler: Callable = default_upsampler, postprocessing: Callable = default_normalizer, - ): + ) -> None: """ - Args: - nn_module: the model to be visualised + nn_module: the model to be visualized target_layers: name of the model layer to generate the feature map. fc_layers: a string or a callable used to get fully-connected weights to compute activation map from the target_layers (without pooling). and evaluate it at every spatial location. - upsampler: an upsampling method to upsample the feature map. - postprocessing: a callable that applies on the upsampled feature map. + upsampler: An upsampling method to upsample the output image. Default is + N dimensional linear (bilinear, trilinear, etc.) depending on num spatial + dimensions of input. + postprocessing: a callable that applies on the upsampled output image. + default is normalising between 0 and 1. """ - if not isinstance(nn_module, ModelWithHooks): - self.net = ModelWithHooks(nn_module, target_layers, register_forward=True) - else: - self.net = nn_module - self.upsampler = upsampler - self.postprocessing = postprocessing + super().__init__( + nn_module=nn_module, + target_layers=target_layers, + upsampler=upsampler, + postprocessing=postprocessing, + register_backward=False, + ) self.fc_layers = fc_layers def compute_map(self, x, class_idx=None, layer_idx=-1): """ Compute the actual feature map with input tensor `x`. """ - logits, acti, _ = self.net(x) + logits, acti, _ = self.nn_module(x) acti = acti[layer_idx] if class_idx is None: class_idx = logits.max(1)[-1] b, c, *spatial = acti.shape acti = torch.split(acti.reshape(b, c, -1), 1, dim=2) # make the spatial dims 1D - fc_layers = self.net.get_layer(self.fc_layers) + fc_layers = self.nn_module.get_layer(self.fc_layers) output = torch.stack([fc_layers(a[..., 0]) for a in acti], dim=2) output = torch.stack([output[i, b : b + 1] for i, b in enumerate(class_idx)], dim=0) return output.reshape(b, 1, *spatial) # resume the spatial dims on the selected class - def feature_map_size(self, input_size, device="cpu", layer_idx=-1): - """ - Computes the actual feature map size given `nn_module` and the target_layer name. - - Args: - input_size: shape of the input tensor - device: the device used to initialise the input tensor - layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. - - Returns: - shape of the actual feature map. - """ - return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape - def __call__(self, x, class_idx=None, layer_idx=-1): """ Compute the activation map with upsampling and postprocessing. Args: x: input tensor, shape must be compatible with `nn_module`. - class_idx: index of the class to be visualised. Default to argmax(logits) + class_idx: index of the class to be visualized. Default to argmax(logits) layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. Returns: activation maps """ acti_map = self.compute_map(x, class_idx, layer_idx) - - # upsampling and postprocessing - if self.upsampler: - img_spatial = x.shape[2:] - acti_map = self.upsampler(img_spatial)(acti_map) - if self.postprocessing: - acti_map = self.postprocessing(acti_map) - return acti_map + return self._upsample_and_post_process(acti_map, x) -class GradCAM: +class GradCAM(CAMBase): """ Computes Gradient-weighted Class Activation Mapping (Grad-CAM). This implementation is based on: @@ -282,54 +297,24 @@ class GradCAM: """ - def __init__(self, nn_module, target_layers: str, upsampler=default_upsampler, postprocessing=default_normalizer): - """ - - Args: - nn_module: the model to be used to generate the visualisations. - target_layers: name of the model layer to generate the feature map. - upsampler: an upsampling method to upsample the feature map. - postprocessing: a callable that applies on the upsampled feature map. - """ - if not isinstance(nn_module, ModelWithHooks): - self.net = ModelWithHooks(nn_module, target_layers, register_forward=True, register_backward=True) - else: - self.net = nn_module - self.upsampler = upsampler - self.postprocessing = postprocessing - def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1): """ Compute the actual feature map with input tensor `x`. """ - logits, acti, grad = self.net(x, class_idx=class_idx, retain_graph=retain_graph) + _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph) acti, grad = acti[layer_idx], grad[layer_idx] b, c, *spatial = grad.shape weights = grad.view(b, c, -1).mean(2).view(b, c, *[1] * len(spatial)) acti_map = (weights * acti).sum(1, keepdim=True) return F.relu(acti_map) - def feature_map_size(self, input_size, device="cpu", layer_idx=-1): - """ - Computes the actual feature map size given `nn_module` and the target_layer name. - - Args: - input_size: shape of the input tensor - device: the device used to initialise the input tensor - layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. - - Returns: - shape of the actual feature map. - """ - return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape - def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): """ Compute the activation map with upsampling and postprocessing. Args: x: input tensor, shape must be compatible with `nn_module`. - class_idx: index of the class to be visualised. Default to argmax(logits) + class_idx: index of the class to be visualized. Default to argmax(logits) layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. retain_graph: whether to retain_graph for torch module backward call. @@ -337,14 +322,7 @@ def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): activation maps """ acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx) - - # upsampling and postprocessing - if self.upsampler: - img_spatial = x.shape[2:] - acti_map = self.upsampler(img_spatial)(acti_map) - if self.postprocessing: - acti_map = self.postprocessing(acti_map) - return acti_map + return self._upsample_and_post_process(acti_map, x) class GradCAMpp(GradCAM): @@ -365,14 +343,14 @@ def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1): """ Compute the actual feature map with input tensor `x`. """ - logits, acti, grad = self.net(x, class_idx=class_idx, retain_graph=retain_graph) + _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph) acti, grad = acti[layer_idx], grad[layer_idx] b, c, *spatial = grad.shape alpha_nr = grad.pow(2) alpha_dr = alpha_nr.mul(2) + acti.mul(grad.pow(3)).view(b, c, -1).sum(-1).view(b, c, *[1] * len(spatial)) alpha_dr = torch.where(alpha_dr != 0.0, alpha_dr, torch.ones_like(alpha_dr)) alpha = alpha_nr.div(alpha_dr + 1e-7) - relu_grad = F.relu(self.net.score.exp() * grad) + relu_grad = F.relu(self.nn_module.score.exp() * grad) weights = (alpha * relu_grad).view(b, c, -1).sum(-1).view(b, c, *[1] * len(spatial)) acti_map = (weights * acti).sum(1, keepdim=True) return F.relu(acti_map) diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py new file mode 100644 index 0000000000..00935f1aaa --- /dev/null +++ b/monai/visualize/occlusion_sensitivity.py @@ -0,0 +1,297 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from functools import partial +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn + +from monai.networks.utils import eval_mode +from monai.visualize import default_normalizer, default_upsampler + +try: + from tqdm import trange + + trange = partial(trange, desc="Computing occlusion sensitivity") +except (ImportError, AttributeError): + trange = range + + +def _check_input_image(image): + """Check that the input image is as expected.""" + # Only accept batch size of 1 + if image.shape[0] > 1: + raise RuntimeError("Expected batch size of 1.") + + +def _check_input_label(model, label, image): + """Check that the input label is as expected.""" + if label is None: + label = model(image).argmax(1) + # If necessary turn the label into a 1-element tensor + elif not isinstance(label, torch.Tensor): + label = torch.tensor([[label]], dtype=torch.int64).to(image.device) + # make sure there's only 1 element + if label.numel() != image.shape[0]: + raise RuntimeError("Expected as many labels as batches.") + return label + + +def _check_input_bounding_box(b_box, im_shape): + """Check that the bounding box (if supplied) is as expected.""" + # If no bounding box has been supplied, set min and max to None + if b_box is None: + b_box_min = b_box_max = None + + # Bounding box has been supplied + else: + # Should be twice as many elements in `b_box` as `im_shape` + if len(b_box) != 2 * len(im_shape): + raise ValueError("Bounding box should contain upper and lower for all dimensions (except batch number)") + + # If any min's or max's are -ve, set them to 0 and im_shape-1, respectively. + b_box_min = np.array(b_box[::2]) + b_box_max = np.array(b_box[1::2]) + b_box_min[b_box_min < 0] = 0 + b_box_max[b_box_max < 0] = im_shape[b_box_max < 0] - 1 + # Check all max's are < im_shape + if np.any(b_box_max >= im_shape): + raise ValueError("Max bounding box should be < image size for all values") + # Check all min's are <= max's + if np.any(b_box_min > b_box_max): + raise ValueError("Min bounding box should be <= max for all values") + + return b_box_min, b_box_max + + +def _append_to_sensitivity_im(model, batch_images, batch_ids, sensitivity_im): + """For given number of images, get probability of predicting + a given label. Append to previous evaluations.""" + batch_images = torch.cat(batch_images, dim=0) + batch_ids = torch.LongTensor(batch_ids).unsqueeze(1).to(sensitivity_im.device) + scores = model(batch_images).detach().gather(1, batch_ids) + return torch.cat((sensitivity_im, scores)) + + +class OcclusionSensitivity: + """ + This class computes the occlusion sensitivity for a model's prediction + of a given image. By occlusion sensitivity, we mean how the probability of a given + prediction changes as the occluded section of an image changes. This can + be useful to understand why a network is making certain decisions. + + The result is given as ``baseline`` (the probability of + a certain output) minus the probability of the output with the occluded + area. + + Therefore, higher values in the output image mean there was a + greater the drop in certainty, indicating the occluded region was more + important in the decision process. + + See: R. R. Selvaraju et al. Grad-CAM: Visual Explanations from Deep Networks via + Gradient-based Localization. https://doi.org/10.1109/ICCV.2017.74 + + Examples + + .. code-block:: python + + # densenet 2d + from monai.networks.nets import densenet121 + from monai.visualize import OcclusionSensitivity + + model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + occ_sens = OcclusionSensitivity(nn_module=model_2d) + result = occ_sens(x=torch.rand((1, 1, 48, 64)), class_idx=None, b_box=[-1, -1, 2, 40, 1, 62]) + + # densenet 3d + from monai.networks.nets import DenseNet + from monai.visualize import OcclusionSensitivity + + model_3d = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)) + occ_sens = OcclusionSensitivity(nn_module=model_3d, n_batch=10, stride=2) + result = occ_sens(torch.rand(1, 1, 6, 6, 6), class_idx=1, b_box=[-1, -1, 2, 3, -1, -1, -1, -1]) + + See Also: + + - :py:class:`monai.visualize.occlusion_sensitivity.OcclusionSensitivity.` + """ + + def __init__( + self, + nn_module: nn.Module, + pad_val: float = 0.0, + margin: Union[int, Sequence] = 2, + n_batch: int = 128, + stride: Union[int, Sequence] = 1, + upsampler: Callable = default_upsampler, + postprocessing: Callable = default_normalizer, + verbose: bool = True, + ) -> None: + """Occlusion sensitivitiy constructor. + + :param nn_module: classification model to use for inference + :param pad_val: when occluding part of the image, which values should we put + in the image? + :param margin: we'll create a cuboid/cube around the voxel to be occluded. if + ``margin==2``, then we'll create a cube that is +/- 2 voxels in + all directions (i.e., a cube of 5 x 5 x 5 voxels). A ``Sequence`` + can be supplied to have a margin of different sizes (i.e., create + a cuboid). + :param n_batch: number of images in a batch before inference. + :param b_box: Bounding box on which to perform the analysis. The output image + will also match in size. There should be a minimum and maximum for + all dimensions except batch: ``[min1, max1, min2, max2,...]``. + * By default, the whole image will be used. Decreasing the size will + speed the analysis up, which might be useful for larger images. + * Min and max are inclusive, so [0, 63, ...] will have size (64, ...). + * Use -ve to use 0 for min values and im.shape[x]-1 for xth dimension. + :param stride: Stride in spatial directions for performing occlusions. Can be single + value or sequence (for varying stride in the different directions). + Should be >= 1. Striding in the channel direction will always be 1. + :param upsampler: An upsampling method to upsample the output image. Default is + N dimensional linear (bilinear, trilinear, etc.) depending on num spatial + dimensions of input. + :param postprocessing: a callable that applies on the upsampled output image. + default is normalising between 0 and 1. + :param verbose: use ``tdqm.trange`` output (if available). + """ + + self.nn_module = nn_module + self.upsampler = upsampler + self.postprocessing = postprocessing + self.pad_val = pad_val + self.margin = margin + self.n_batch = n_batch + self.stride = stride + self.verbose = verbose + + def _compute_occlusion_sensitivity(self, x, class_idx, b_box): + + # Get bounding box + im_shape = np.array(x.shape[1:]) + b_box_min, b_box_max = _check_input_bounding_box(b_box, im_shape) + + # Get baseline probability + baseline = self.nn_module(x).detach()[0, class_idx].item() + + # Create some lists + batch_images = [] + batch_ids = [] + + sensitivity_im = torch.empty(0, dtype=torch.float32, device=x.device) + + # If no bounding box supplied, output shape is same as input shape. + # If bounding box is present, shape is max - min + 1 + output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1 + + # Calculate the downsampled shape + if not isinstance(self.stride, Sequence): + stride_np = np.full_like(im_shape, self.stride, dtype=np.int32) + stride_np[0] = 1 # always do stride 1 in channel dimension + else: + # Convert to numpy array and check dimensions match + stride_np = np.array(self.stride, dtype=np.int32) + if stride_np.size != im_shape - 1: # should be 1 less to get spatial dimensions + raise ValueError( + "If supplying stride as sequence, number of elements of stride should match number of spatial dimensions." + ) + + # Obviously if stride = 1, downsampled_im_shape == output_im_shape + downsampled_im_shape = np.floor(output_im_shape / stride_np).astype(np.int32) + downsampled_im_shape[downsampled_im_shape == 0] = 1 # make sure dimension sizes are >= 1 + num_required_predictions = np.prod(downsampled_im_shape) + + # Loop 1D over image + verbose_range = trange if self.verbose else range + for i in verbose_range(num_required_predictions): + # Get corresponding ND index + idx = np.unravel_index(i, downsampled_im_shape) + # Multiply by stride + idx *= stride_np + # If a bounding box is being used, we need to add on + # the min to shift to start of region of interest + if b_box_min is not None: + idx += b_box_min + + # Get min and max index of box to occlude + min_idx = [max(0, i - self.margin) for i in idx] + max_idx = [min(j, i + self.margin) for i, j in zip(idx, im_shape)] + + # Clone and replace target area with `pad_val` + occlu_im = x.detach().clone() + occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = self.pad_val + + # Add to list + batch_images.append(occlu_im) + batch_ids.append(class_idx) + + # Once the batch is complete (or on last iteration) + if len(batch_images) == self.n_batch or i == num_required_predictions - 1: + # Do the predictions and append to sensitivity map + sensitivity_im = _append_to_sensitivity_im(self.nn_module, batch_images, batch_ids, sensitivity_im) + # Clear lists + batch_images = [] + batch_ids = [] + + # Subtract baseline from sensitivity so that +ve values mean more important in decision process + sensitivity_im = baseline - sensitivity_im + + # Reshape to match downsampled image, and unsqueeze to add batch dimension back in + sensitivity_im = sensitivity_im.reshape(tuple(downsampled_im_shape)).unsqueeze(0) + + return sensitivity_im, output_im_shape + + def __call__( # type: ignore + self, x: torch.Tensor, class_idx: Optional[Union[int, torch.Tensor]] = None, b_box: Optional[Sequence] = None + ): + """ + Args: + x: image to test. Should be tensor consisting of 1 batch, can be 2- or 3D. + class_idx: classification label to check for changes. This could be the true + label, or it could be the predicted label, etc. Use ``None`` to use generate + the predicted model. + b_box: Bounding box on which to perform the analysis. The output image + will also match in size. There should be a minimum and maximum for + all dimensions except batch: ``[min1, max1, min2, max2,...]``. + * By default, the whole image will be used. Decreasing the size will + speed the analysis up, which might be useful for larger images. + * Min and max are inclusive, so [0, 63, ...] will have size (64, ...). + * Use -ve to use 0 for min values and im.shape[x]-1 for xth dimension. + Returns: + Depends on the postprocessing, but the default return type is a Numpy array. + The returned image will occupy the same space as the input image, unless a + bounding box is supplied, in which case it will occupy that space. Unless + upsampling is disabled, the output image will have voxels of the same size + as the input image. + """ + + with eval_mode(self.nn_module): + + # Check input arguments + _check_input_image(x) + class_idx = _check_input_label(self.nn_module, class_idx, x) + + # Generate sensitivity image + sensitivity_im, output_im_shape = self._compute_occlusion_sensitivity(x, class_idx, b_box) + + # upsampling and postprocessing + if self.upsampler is not None: + if np.any(output_im_shape != x.shape[1:]): + img_spatial = tuple(output_im_shape[1:]) + sensitivity_im = self.upsampler(img_spatial)(sensitivity_im) + if self.postprocessing: + sensitivity_im = self.postprocessing(sensitivity_im) + + # Squeeze and return + return sensitivity_im diff --git a/monai/visualize/visualizer.py b/monai/visualize/visualizer.py new file mode 100644 index 0000000000..9a56e0781d --- /dev/null +++ b/monai/visualize/visualizer.py @@ -0,0 +1,48 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable + +import numpy as np +import torch +import torch.nn.functional as F + +from monai.transforms import ScaleIntensity +from monai.utils import InterpolateMode + +__all__ = ["default_upsampler", "default_normalizer"] + + +def default_upsampler(spatial_size) -> Callable[[torch.Tensor], torch.Tensor]: + """ + A linear interpolation method for upsampling the feature map. + The output of this function is a callable `func`, + such that `func(x)` returns an upsampled tensor. + """ + + def up(x): + linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] + interp_mode = linear_mode[len(spatial_size) - 1] + return F.interpolate(x, size=spatial_size, mode=str(interp_mode.value), align_corners=False) + + return up + + +def default_normalizer(x) -> np.ndarray: + """ + A linear intensity scaling by mapping the (min, max) to (1, 0). + """ + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + scaler = ScaleIntensity(minv=1.0, maxv=0.0) + x = [scaler(x) for x in x] + return np.stack(x, axis=0) diff --git a/tests/min_tests.py b/tests/min_tests.py index ccfc789992..510e201b94 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -101,7 +101,7 @@ def run_testsuit(): "test_zoom", "test_zoom_affine", "test_zoomd", - "test_compute_occlusion_sensitivity", + "test_occlusion_sensitivity", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_compute_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py similarity index 58% rename from tests/test_compute_occlusion_sensitivity.py rename to tests/test_occlusion_sensitivity.py index 9f30162c47..9f5dc44776 100644 --- a/tests/test_compute_occlusion_sensitivity.py +++ b/tests/test_occlusion_sensitivity.py @@ -14,8 +14,8 @@ import torch from parameterized import parameterized -from monai.metrics import compute_occlusion_sensitivity from monai.networks.nets import DenseNet, densenet121 +from monai.visualize import OcclusionSensitivity device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3).to(device) @@ -28,33 +28,56 @@ # 2D w/ bounding box TEST_CASE_0 = [ { - "model": model_2d, - "image": torch.rand(1, 1, 48, 64).to(device), - "label": torch.tensor([[0]], dtype=torch.int64).to(device), + "nn_module": model_2d, + }, + { + "x": torch.rand(1, 1, 48, 64).to(device), + "class_idx": torch.tensor([[0]], dtype=torch.int64).to(device), "b_box": [-1, -1, 2, 40, 1, 62], }, - (39, 62), + (1, 1, 39, 62), ] -# 3D w/ bounding box +# 3D w/ bounding box and stride TEST_CASE_1 = [ { - "model": model_3d, - "image": torch.rand(1, 1, 6, 6, 6).to(device), - "label": 0, - "b_box": [-1, -1, 2, 3, -1, -1, -1, -1], + "nn_module": model_3d, "n_batch": 10, "stride": 2, }, - (2, 6, 6), + { + "x": torch.rand(1, 1, 6, 6, 6).to(device), + "class_idx": None, + "b_box": [-1, -1, 2, 3, -1, -1, -1, -1], + }, + (1, 1, 2, 6, 6), +] + +TEST_CASE_FAIL = [ # 2D should fail, since 3 stride values given + { + "nn_module": model_2d, + "n_batch": 10, + "stride": (2, 2, 2), + }, + { + "x": torch.rand(1, 1, 48, 64).to(device), + "class_idx": None, + "b_box": [-1, -1, 2, 3, -1, -1], + }, ] class TestComputeOcclusionSensitivity(unittest.TestCase): @parameterized.expand([TEST_CASE_0, TEST_CASE_1]) - def test_shape(self, input_data, expected_shape): - result = compute_occlusion_sensitivity(**input_data) + def test_shape(self, init_data, call_data, expected_shape): + occ_sens = OcclusionSensitivity(**init_data) + result = occ_sens(**call_data) self.assertTupleEqual(result.shape, expected_shape) + def test_fail(self): + occ_sens = OcclusionSensitivity(**TEST_CASE_FAIL[0]) + with self.assertRaises(ValueError): + occ_sens(**TEST_CASE_FAIL[1]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_train_mode.py b/tests/test_train_mode.py new file mode 100644 index 0000000000..2ed48bcb15 --- /dev/null +++ b/tests/test_train_mode.py @@ -0,0 +1,31 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.networks.utils import train_mode + + +class TestEvalMode(unittest.TestCase): + def test_eval_mode(self): + t = torch.rand(1, 1, 4, 4) + p = torch.nn.Conv2d(1, 1, 3) + p.eval() + self.assertFalse(p.training) # False + with train_mode(p): + self.assertTrue(p.training) # True + p(t).sum().backward() + + +if __name__ == "__main__": + unittest.main()