From 844958664bdb3f8a03b47a75b3b468acfb2d8180 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Nov 2020 14:56:02 +0000 Subject: [PATCH 1/6] fixes #713 #1212 Signed-off-by: Wenqi Li --- monai/visualize/class_activation_maps.py | 160 +++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 monai/visualize/class_activation_maps.py diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py new file mode 100644 index 0000000000..49a0bb71b0 --- /dev/null +++ b/monai/visualize/class_activation_maps.py @@ -0,0 +1,160 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from monai.utils import ensure_tuple + + +class ModelWithHooks: + """ + A model wrapper to run model forward and store intermediate forward, backward information. + """ + + def __init__(self, model, target_layers, register_forward: bool = False, register_backward: bool = False): + self.model = model + self.target_layers = ensure_tuple(target_layers) + + self.gradients = {} + self.activations = {} + self.register_backward = register_backward + self.register_forward = register_forward + + for name, mod in model.named_modules(): + if name not in self.target_layers: + continue + if self.register_backward: + mod.register_backward_hook(self.backward_hook(name)) + if self.register_forward: + mod.register_forward_hook(self.forward_hook(name)) + + def backward_hook(self, name): + def _hook(_module, _grad_input, grad_output): + self.gradients[name] = grad_output[0] + + return _hook + + def forward_hook(self, name): + def _hook(_module, _input, output): + self.activations[name] = output + + return _hook + + def class_score(self, logits, class_idx=None): + if class_idx is not None: + return logits[:, class_idx].squeeze() + return logits[:, logits.argmax(1)].squeeze() + + def __call__(self, x, class_idx=None, retain_graph=False): + logits = self.model(x) + acti, grad = None, None + if self.register_forward: + acti = tuple(self.activations[layer] for layer in self.target_layers) + if self.register_backward: + score = self.class_score(logits, class_idx) + self.model.zero_grad() + score.backward(retain_graph=retain_graph) + grad = tuple(self.gradients[layer] for layer in self.target_layers) + return logits, acti, grad + + +class CAM: + def __init__(self, model, target_layers): + self.net = ModelWithHooks(model, target_layers, register_forward=True) + + def norm_features(self, feature): + feature -= feature.min() + feature /= feature.max() + return 1.0 - feature + + def __call__(self, x, class_idx=None): + logits, acti, _ = self.net(x) + acti = acti[0] + if class_idx is None: + class_idx = torch.argmax(logits, dim=1) + b, c, *spatial = acti.shape + weight = self.net.model.fc.weight[class_idx].view(c, *[1 for _ in spatial]) + map = [self.norm_features((weight * a).sum(0)) for a in acti] + return torch.stack(map, dim=0) + + +class GradCAM: + def __init__(self, model, target_layers): + self.net = ModelWithHooks(model, target_layers, 
register_forward=True, register_backward=True) + + def norm_features(self, feature): + feature -= feature.min() + feature /= feature.max() + return 1.0 - feature + + def __call__(self, x): + logits, acti, grad = self.net(x) + acti = acti[0] + grad = grad[0] + b, c, *spatial = grad.shape + grad_ave = grad.view(b, c, -1).mean(2) + weights = grad_ave.view(b, c, 1, 1) + map = (weights * acti).sum(1) + return self.norm_features(map) + + +# if __name__ == "__main__": +# from torchvision import transforms +# import glob +# import numpy as np +# import PIL +# import cv2 +# from matplotlib import pyplot as plt +# import argparse +# +# parser = argparse.ArgumentParser() +# parser.add_argument("--device", default="0", type=str) +# parser.add_argument( +# "--target_layer", +# default="layer4", +# type=str, +# help="Specify the name of the target layer (before the pooling layer)", +# ) +# parser.add_argument( +# "--final_layer", default="fc", type=str, help="Specify the name of the last classification layer" +# ) +# args = parser.parse_args() +# device = torch.device("cuda:" + args.device) if torch.cuda.is_available() else torch.device("cpu") +# model = torch.load("temp/resnet-cam.pt", map_location=device) +# # print(model) +# if torch.cuda.is_available(): +# model.cuda() +# model.eval() +# # cam_computer = CAM(model, target_layers=[args.target_layer, args.final_layer]) +# cam_computer = GradCAM(model, target_layers=args.target_layer) +# resize_param = (224, 224) +# norm_mean = [0.5528, 0.5528, 0.5528] +# norm_std = [0.1583, 0.1583, 0.1583] +# disp_size = 10 +# preprocess = transforms.Compose( +# [transforms.Resize(resize_param), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)] +# ) +# plt.figure(figsize=(disp_size, disp_size)) +# for i, file in enumerate(glob.glob("./temp/test_images/*")): +# image = PIL.Image.open(file) +# h, w, b = np.shape(np.array(image)) +# img_tensor = preprocess(image).unsqueeze(0).to(device) +# cam_img = cam_computer(img_tensor)[0].detach().cpu().numpy() +# img = np.array(image) +# cam_img = cv2.resize(cam_img, (h, w), interpolation=cv2.INTER_CUBIC) +# cam_img = np.uint8(cam_img * 255) +# height, width, _ = img.shape +# heatmap = cv2.applyColorMap(cam_img, cv2.COLORMAP_JET) +# result = heatmap * 0.3 + img * 0.6 +# plt.subplot(2, 1, i + 1) +# plt.imshow(result.astype(np.int)) +# +# plt.show() From 6790d801d7661717dbf0bfe487f03735de089279 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Nov 2020 18:55:31 +0000 Subject: [PATCH 2/6] update based on comments Signed-off-by: Wenqi Li --- monai/visualize/class_activation_maps.py | 150 ++++++++++++++++++----- 1 file changed, 117 insertions(+), 33 deletions(-) diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 49a0bb71b0..b11995c1eb 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -9,22 +9,35 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable, Dict + +import numpy as np import torch +import torch.nn.functional as F -from monai.utils import ensure_tuple +from monai.transforms import ScaleIntensity +from monai.utils import InterpolateMode, ensure_tuple class ModelWithHooks: """ - A model wrapper to run model forward and store intermediate forward, backward information. + A model wrapper to run model forward/backward steps and storing some intermediate feature/gradient information. 
""" - def __init__(self, model, target_layers, register_forward: bool = False, register_backward: bool = False): + def __init__(self, model, target_layer_names, register_forward: bool = False, register_backward: bool = False): + """ + + Args: + model: the model to be wrapped. + target_layer_names: the names of the layer to + register_forward: whether to cache the forward pass output corresponding to `target_layer_names`. + register_backward: whether to cache the backward pass output corresponding to `target_layer_names`. + """ self.model = model - self.target_layers = ensure_tuple(target_layers) + self.target_layers = ensure_tuple(target_layer_names) - self.gradients = {} - self.activations = {} + self.gradients: Dict[str, torch.Tensor] = {} + self.activations: Dict[str, torch.Tensor] = {} self.register_backward = register_backward self.register_forward = register_forward @@ -66,54 +79,125 @@ def __call__(self, x, class_idx=None, retain_graph=False): return logits, acti, grad +def default_upsampler(spatial_size) -> Callable[[torch.Tensor], torch.Tensor]: + """ + A linear interpolation method for upsampling the feature map. + The output of this function is a callable `func`, + such that `func(activation_map)` returns an upsampled tensor. + """ + + def up(acti_map): + linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR] + interp_mode = linear_mode[len(spatial_size) - 1] + return F.interpolate(acti_map, size=spatial_size, mode=str(interp_mode.value), align_corners=False) + + return up + + +def default_normalizer(acti_map) -> np.ndarray: + """ + A linear intensity scaling by mapping the (min, max) to (1, 0). + """ + if isinstance(acti_map, torch.Tensor): + acti_map = acti_map.detach().cpu().numpy() + scaler = ScaleIntensity(minv=1.0, maxv=0.0) + acti_map = [scaler(x) for x in acti_map] + return np.stack(acti_map, axis=0) + + class CAM: - def __init__(self, model, target_layers): + def __init__( + self, + model, + target_layers, + fc_weights=lambda m: m.fc.weight, + upsampler=default_upsampler, + postprocessing: Callable = default_normalizer, + ): + """ + + Args: + model: the model to be visualised + target_layers: name of the model layer to generate the feature map. + fc_weights: a callable used to get fully-connected weights to compute activation map + from the target_layers (without pooling). The default is `lambda m: m.fc.weight`, that is + getting the fully-connected layer by `model.fc.weight`. + upsampler: an upsampling method to upsample the feature map. + postprocessing: a callable that applies on the upsampled feature map. 
+ """ self.net = ModelWithHooks(model, target_layers, register_forward=True) + self.upsampler = upsampler + self.postprocessing = postprocessing + self.fc_weights = fc_weights - def norm_features(self, feature): - feature -= feature.min() - feature /= feature.max() - return 1.0 - feature - - def __call__(self, x, class_idx=None): + def compute_map(self, x, class_idx=None): logits, acti, _ = self.net(x) acti = acti[0] if class_idx is None: class_idx = torch.argmax(logits, dim=1) b, c, *spatial = acti.shape - weight = self.net.model.fc.weight[class_idx].view(c, *[1 for _ in spatial]) - map = [self.norm_features((weight * a).sum(0)) for a in acti] - return torch.stack(map, dim=0) + weights = self.fc_weights(self.net.model) + weights = weights[class_idx].view(c, *[1 for _ in spatial]) + return (weights * acti).sum(1, keepdim=True) + + def feature_map_size(self, input_size, device="cpu"): + return self.compute_map(torch.zeros(*input_size, device=device)).shape + + def __call__(self, x, class_idx=None): + acti_map = self.compute_map(x, class_idx) + if self.upsampler: + img_spatial = x.shape[2:] + acti_map = self.upsampler(img_spatial)(acti_map) + if self.postprocessing: + acti_map = self.postprocessing(acti_map) + return acti_map class GradCAM: - def __init__(self, model, target_layers): + def __init__(self, model, target_layers, upsampler=default_upsampler, postprocessing=default_normalizer): + """ + + Args: + model: the model to be visualised + target_layers: name of the model layer to generate the feature map. + upsampler: an upsampling method to upsample the feature map. + postprocessing: a callable that applies on the upsampled feature map. + """ self.net = ModelWithHooks(model, target_layers, register_forward=True, register_backward=True) + self.upsampler = upsampler + self.postprocessing = postprocessing - def norm_features(self, feature): - feature -= feature.min() - feature /= feature.max() - return 1.0 - feature - - def __call__(self, x): - logits, acti, grad = self.net(x) - acti = acti[0] - grad = grad[0] + def compute_map(self, x, class_idx=None, retain_graph=False): + logits, acti, grad = self.net(x, class_idx=class_idx, retain_graph=retain_graph) + acti, grad = acti[0], grad[0] b, c, *spatial = grad.shape grad_ave = grad.view(b, c, -1).mean(2) - weights = grad_ave.view(b, c, 1, 1) - map = (weights * acti).sum(1) - return self.norm_features(map) + weights = grad_ave.view(b, c, *[1 for _ in spatial]) + acti_map = (weights * acti).sum(1, keepdim=True) + return F.relu(acti_map) + + def feature_map_size(self, input_size, device="cpu"): + return self.compute_map(torch.zeros(*input_size, device=device)).shape + + def __call__(self, x, class_idx=None, retain_graph=False): + acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph) + if self.upsampler: + img_spatial = x.shape[2:] + acti_map = self.upsampler(img_spatial)(acti_map) + if self.postprocessing: + acti_map = self.postprocessing(acti_map) + return acti_map # if __name__ == "__main__": -# from torchvision import transforms +# import argparse # import glob +# +# import cv2 # import numpy as np # import PIL -# import cv2 # from matplotlib import pyplot as plt -# import argparse +# from torchvision import transforms # # parser = argparse.ArgumentParser() # parser.add_argument("--device", default="0", type=str) @@ -147,7 +231,7 @@ def __call__(self, x): # image = PIL.Image.open(file) # h, w, b = np.shape(np.array(image)) # img_tensor = preprocess(image).unsqueeze(0).to(device) -# cam_img = 
cam_computer(img_tensor)[0].detach().cpu().numpy() +# cam_img = cam_computer(img_tensor)[0, 0] # img = np.array(image) # cam_img = cv2.resize(cam_img, (h, w), interpolation=cv2.INTER_CUBIC) # cam_img = np.uint8(cam_img * 255) From d23d6521dc79c7035b0cafaed98412e73efbd1a8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 30 Nov 2020 20:04:41 +0000 Subject: [PATCH 3/6] update docstring Signed-off-by: Wenqi Li --- monai/visualize/class_activation_maps.py | 69 +++++++++++++++--------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index b11995c1eb..e1c1d62531 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -29,7 +29,7 @@ def __init__(self, model, target_layer_names, register_forward: bool = False, re Args: model: the model to be wrapped. - target_layer_names: the names of the layer to + target_layer_names: the names of the layer to cache. register_forward: whether to cache the forward pass output corresponding to `target_layer_names`. register_backward: whether to cache the backward pass output corresponding to `target_layer_names`. """ @@ -106,11 +106,15 @@ def default_normalizer(acti_map) -> np.ndarray: class CAM: + """ + Compute class activation map from the last fully-connected layers before the spatial pooling. + """ + def __init__( self, model, target_layers, - fc_weights=lambda m: m.fc.weight, + fc_layers=lambda m: m.fc, upsampler=default_upsampler, postprocessing: Callable = default_normalizer, ): @@ -119,32 +123,36 @@ def __init__( Args: model: the model to be visualised target_layers: name of the model layer to generate the feature map. - fc_weights: a callable used to get fully-connected weights to compute activation map - from the target_layers (without pooling). The default is `lambda m: m.fc.weight`, that is - getting the fully-connected layer by `model.fc.weight`. + fc_layers: a callable used to get fully-connected weights to compute activation map + from the target_layers (without pooling). The default is `lambda m: m.fc`, that is + to get the fully-connected layer by `model.fc` and evaluate it at every spatial location. upsampler: an upsampling method to upsample the feature map. postprocessing: a callable that applies on the upsampled feature map. 
""" self.net = ModelWithHooks(model, target_layers, register_forward=True) self.upsampler = upsampler self.postprocessing = postprocessing - self.fc_weights = fc_weights + self.fc_layers = fc_layers - def compute_map(self, x, class_idx=None): + def compute_map(self, x, class_idx=None, layer_idx=-1): logits, acti, _ = self.net(x) - acti = acti[0] + acti = acti[layer_idx] if class_idx is None: class_idx = torch.argmax(logits, dim=1) b, c, *spatial = acti.shape - weights = self.fc_weights(self.net.model) - weights = weights[class_idx].view(c, *[1 for _ in spatial]) - return (weights * acti).sum(1, keepdim=True) + acti = torch.split(acti.reshape(b, c, -1), 1, dim=2) # make the spatial dims 1D + fc_layers = self.fc_layers(self.net.model) + output = torch.stack([fc_layers(a[..., 0]) for a in acti], dim=2) + output = output[:, class_idx : class_idx + 1] # only retain the spatial map of the selected class + return output.reshape(b, -1, *spatial) # resume the spatial dims on the selected class - def feature_map_size(self, input_size, device="cpu"): - return self.compute_map(torch.zeros(*input_size, device=device)).shape + def feature_map_size(self, input_size, device="cpu", layer_idx=-1): + return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape - def __call__(self, x, class_idx=None): - acti_map = self.compute_map(x, class_idx) + def __call__(self, x, class_idx=None, layer_idx=-1): + acti_map = self.compute_map(x, class_idx, layer_idx) + + # upsampling and postprocessing if self.upsampler: img_spatial = x.shape[2:] acti_map = self.upsampler(img_spatial)(acti_map) @@ -154,11 +162,20 @@ def __call__(self, x, class_idx=None): class GradCAM: + """ + Computes Gradient-weighted Class Activation Mapping (Grad-CAM). + This implementation is based on: + + Selvaraju et al., Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, + https://arxiv.org/abs/1610.02391 + + """ + def __init__(self, model, target_layers, upsampler=default_upsampler, postprocessing=default_normalizer): """ Args: - model: the model to be visualised + model: the model to be used to generate the visualisations. target_layers: name of the model layer to generate the feature map. upsampler: an upsampling method to upsample the feature map. postprocessing: a callable that applies on the upsampled feature map. 
@@ -167,20 +184,22 @@ def __init__(self, model, target_layers, upsampler=default_upsampler, postproces self.upsampler = upsampler self.postprocessing = postprocessing - def compute_map(self, x, class_idx=None, retain_graph=False): + def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1): logits, acti, grad = self.net(x, class_idx=class_idx, retain_graph=retain_graph) - acti, grad = acti[0], grad[0] + acti, grad = acti[layer_idx], grad[layer_idx] b, c, *spatial = grad.shape grad_ave = grad.view(b, c, -1).mean(2) - weights = grad_ave.view(b, c, *[1 for _ in spatial]) + weights = grad_ave.view(b, c, [1] * len(spatial)) acti_map = (weights * acti).sum(1, keepdim=True) return F.relu(acti_map) - def feature_map_size(self, input_size, device="cpu"): - return self.compute_map(torch.zeros(*input_size, device=device)).shape + def feature_map_size(self, input_size, device="cpu", layer_idx=-1): + return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape - def __call__(self, x, class_idx=None, retain_graph=False): - acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph) + def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): + acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx) + + # upsampling and postprocessing if self.upsampler: img_spatial = x.shape[2:] acti_map = self.upsampler(img_spatial)(acti_map) @@ -217,8 +236,8 @@ def __call__(self, x, class_idx=None, retain_graph=False): # if torch.cuda.is_available(): # model.cuda() # model.eval() -# # cam_computer = CAM(model, target_layers=[args.target_layer, args.final_layer]) -# cam_computer = GradCAM(model, target_layers=args.target_layer) +# cam_computer = CAM(model, target_layers=args.target_layer) +# # cam_computer = GradCAM(model, target_layers=args.target_layer) # resize_param = (224, 224) # norm_mean = [0.5528, 0.5528, 0.5528] # norm_std = [0.1583, 0.1583, 0.1583] From 2ab243c46f29c08ce3691cbb63d7e6e1bd43dfda Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 1 Dec 2020 17:10:50 +0000 Subject: [PATCH 4/6] adds cpu tests Signed-off-by: Wenqi Li --- docs/source/visualize.rst | 8 + monai/visualize/__init__.py | 1 + monai/visualize/class_activation_maps.py | 244 +++++++++++++++-------- tests/test_vis_cam.py | 92 +++++++++ tests/test_vis_gradcam.py | 88 ++++++++ 5 files changed, 353 insertions(+), 80 deletions(-) create mode 100644 tests/test_vis_cam.py create mode 100644 tests/test_vis_gradcam.py diff --git a/docs/source/visualize.rst b/docs/source/visualize.rst index e6506f6849..9668d48114 100644 --- a/docs/source/visualize.rst +++ b/docs/source/visualize.rst @@ -5,8 +5,16 @@ Visualizations ============== +.. currentmodule:: monai.visualize + Tensorboard visuals ------------------- .. automodule:: monai.visualize.img2tensorboard :members: + +Class activation map +-------------------- + +.. automodule:: monai.visualize.class_activation_maps + :members: \ No newline at end of file diff --git a/monai/visualize/__init__.py b/monai/visualize/__init__.py index ea1a43ff47..2fbd1dcf66 100644 --- a/monai/visualize/__init__.py +++ b/monai/visualize/__init__.py @@ -9,4 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from .class_activation_maps import * from .img2tensorboard import * diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index e1c1d62531..02467abb89 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -9,7 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict +import warnings +from typing import Callable, Dict, Sequence, Union import numpy as np import torch @@ -18,36 +19,50 @@ from monai.transforms import ScaleIntensity from monai.utils import InterpolateMode, ensure_tuple +__all__ = ["ModelWithHooks", "default_upsampler", "default_normalizer", "CAM", "GradCAM"] + class ModelWithHooks: """ A model wrapper to run model forward/backward steps and storing some intermediate feature/gradient information. """ - def __init__(self, model, target_layer_names, register_forward: bool = False, register_backward: bool = False): + def __init__( + self, + nn_module, + target_layer_names: Union[str, Sequence[str]], + register_forward: bool = False, + register_backward: bool = False, + ): """ Args: - model: the model to be wrapped. + nn_module: the model to be wrapped. target_layer_names: the names of the layer to cache. register_forward: whether to cache the forward pass output corresponding to `target_layer_names`. register_backward: whether to cache the backward pass output corresponding to `target_layer_names`. """ - self.model = model + self.model = nn_module self.target_layers = ensure_tuple(target_layer_names) self.gradients: Dict[str, torch.Tensor] = {} self.activations: Dict[str, torch.Tensor] = {} + self.score = None + self.class_idx = None self.register_backward = register_backward self.register_forward = register_forward - for name, mod in model.named_modules(): + _registered = [] + for name, mod in nn_module.named_modules(): if name not in self.target_layers: continue + _registered.append(name) if self.register_backward: mod.register_backward_hook(self.backward_hook(name)) if self.register_forward: mod.register_forward_hook(self.forward_hook(name)) + if len(_registered) != len(self.target_layers): + warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.") def backward_hook(self, name): def _hook(_module, _grad_input, grad_output): @@ -61,10 +76,29 @@ def _hook(_module, _input, output): return _hook + def get_layer(self, layer_id: Union[str, Callable]): + """ + + Args: + layer_id: a layer name string or a callable. If it is a callable such as `lambda m: m.fc`, + this method will return the module `self.model.fc`. + + Returns: + a submodule from self.model. 
+        """
+        if callable(layer_id):
+            return layer_id(self.model)
+        if isinstance(layer_id, str):
+            for name, mod in self.model.named_modules():
+                if name == layer_id:
+                    return mod
+        raise NotImplementedError(f"Could not find {layer_id}.")
+
     def class_score(self, logits, class_idx=None):
         if class_idx is not None:
-            return logits[:, class_idx].squeeze()
-        return logits[:, logits.argmax(1)].squeeze()
+            return logits[:, class_idx].squeeze(), class_idx
+        class_idx = logits.argmax(1)
+        return logits[:, class_idx].squeeze(), class_idx
 
     def __call__(self, x, class_idx=None, retain_graph=False):
         logits = self.model(x)
@@ -72,9 +106,10 @@ def __call__(self, x, class_idx=None, retain_graph=False):
         if self.register_forward:
             acti = tuple(self.activations[layer] for layer in self.target_layers)
         if self.register_backward:
-            score = self.class_score(logits, class_idx)
+            score, class_idx = self.class_score(logits, class_idx)
             self.model.zero_grad()
-            score.backward(retain_graph=retain_graph)
+            self.score, self.class_idx = score, class_idx
+            score.sum().backward(retain_graph=retain_graph)
             grad = tuple(self.gradients[layer] for layer in self.target_layers)
         return logits, acti, grad
 
@@ -108,48 +143,100 @@ def default_normalizer(acti_map) -> np.ndarray:
 class CAM:
     """
     Compute class activation map from the last fully-connected layers before the spatial pooling.
+
+    Examples
+
+    .. code-block:: python
+
+        # densenet 2d
+        from monai.networks.nets import densenet121
+        from monai.visualize import CAM
+
+        model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3)
+        cam = CAM(nn_module=model_2d, target_layers="class_layers.relu", fc_layers="class_layers.out")
+        result = cam(x=torch.rand((1, 1, 48, 64)))
+
+        # resnet 2d
+        from monai.networks.nets import se_resnet50
+        from monai.visualize import CAM
+
+        model_2d = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4)
+        cam = CAM(nn_module=model_2d, target_layers="layer4", fc_layers="last_linear")
+        result = cam(x=torch.rand((2, 3, 48, 64)))
+
+    See Also:
+
+        - :py:class:`monai.visualize.class_activation_maps.GradCAM`
+
     """
 
     def __init__(
         self,
-        model,
-        target_layers,
-        fc_layers=lambda m: m.fc,
+        nn_module,
+        target_layers: str,
+        fc_layers: Union[str, Callable] = "fc",
         upsampler=default_upsampler,
         postprocessing: Callable = default_normalizer,
     ):
         """
 
         Args:
-            model: the model to be visualised
+            nn_module: the model to be visualised.
             target_layers: name of the model layer to generate the feature map.
-            fc_layers: a callable used to get fully-connected weights to compute activation map
-                from the target_layers (without pooling). The default is `lambda m: m.fc`, that is
-                to get the fully-connected layer by `model.fc` and evaluate it at every spatial location.
+            fc_layers: a string or a callable used to get fully-connected weights to compute activation map
+                from the target_layers (without pooling), and evaluate it at every spatial location.
             upsampler: an upsampling method to upsample the feature map.
             postprocessing: a callable that applies on the upsampled feature map.
         """
-        self.net = ModelWithHooks(model, target_layers, register_forward=True)
+        if not isinstance(nn_module, ModelWithHooks):
+            self.net = ModelWithHooks(nn_module, target_layers, register_forward=True)
+        else:
+            self.net = nn_module
         self.upsampler = upsampler
         self.postprocessing = postprocessing
         self.fc_layers = fc_layers
 
     def compute_map(self, x, class_idx=None, layer_idx=-1):
+        """
+        Compute the actual feature map with input tensor `x`.
+ """ logits, acti, _ = self.net(x) acti = acti[layer_idx] if class_idx is None: class_idx = torch.argmax(logits, dim=1) b, c, *spatial = acti.shape acti = torch.split(acti.reshape(b, c, -1), 1, dim=2) # make the spatial dims 1D - fc_layers = self.fc_layers(self.net.model) + fc_layers = self.net.get_layer(self.fc_layers) output = torch.stack([fc_layers(a[..., 0]) for a in acti], dim=2) - output = output[:, class_idx : class_idx + 1] # only retain the spatial map of the selected class - return output.reshape(b, -1, *spatial) # resume the spatial dims on the selected class + output = torch.stack([output[i, b : b + 1] for i, b in enumerate(class_idx)], dim=0) + return output.reshape(b, 1, *spatial) # resume the spatial dims on the selected class def feature_map_size(self, input_size, device="cpu", layer_idx=-1): + """ + Computes the actual feature map size given `nn_module` and the target_layer name. + + Args: + input_size: shape of the input tensor + device: the device used to initialise the input tensor + layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. + + Returns: + shape of the actual feature map. + """ return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape def __call__(self, x, class_idx=None, layer_idx=-1): + """ + Compute the activation map with upsampling and postprocessing. + + Args: + x: input tensor, shape must be compatible with `nn_module`. + class_idx: index of the class to be visualised. Default to argmax(logits) + layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. + + Returns: + activation maps + """ acti_map = self.compute_map(x, class_idx, layer_idx) # upsampling and postprocessing @@ -169,34 +256,86 @@ class GradCAM: Selvaraju et al., Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization, https://arxiv.org/abs/1610.02391 + Examples + + .. code-block:: python + + # densenet 2d + from monai.networks.nets import densenet121 + from monai.visualize import GradCAM + + model_2d = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + cam = GradCAM(nn_module=model_2d, target_layers="class_layers.relu") + result = cam(x=torch.rand((1, 1, 48, 64))) + + # resnet 2d + from monai.networks.nets import se_resnet50 + from monai.visualize import GradCAM + + model_2d = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4) + cam = GradCAM(nn_module=model_2d, target_layers="layer4") + result = cam(x=torch.rand((2, 3, 48, 64))) + + See Also: + + - :py:class:`monai.visualize.class_activation_maps.CAM` + """ - def __init__(self, model, target_layers, upsampler=default_upsampler, postprocessing=default_normalizer): + def __init__(self, nn_module, target_layers: str, upsampler=default_upsampler, postprocessing=default_normalizer): """ Args: - model: the model to be used to generate the visualisations. + nn_module: the model to be used to generate the visualisations. target_layers: name of the model layer to generate the feature map. upsampler: an upsampling method to upsample the feature map. postprocessing: a callable that applies on the upsampled feature map. 
""" - self.net = ModelWithHooks(model, target_layers, register_forward=True, register_backward=True) + if not isinstance(nn_module, ModelWithHooks): + self.net = ModelWithHooks(nn_module, target_layers, register_forward=True, register_backward=True) + else: + self.net = nn_module self.upsampler = upsampler self.postprocessing = postprocessing def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1): + """ + Compute the actual feature map with input tensor `x`. + """ logits, acti, grad = self.net(x, class_idx=class_idx, retain_graph=retain_graph) acti, grad = acti[layer_idx], grad[layer_idx] b, c, *spatial = grad.shape - grad_ave = grad.view(b, c, -1).mean(2) - weights = grad_ave.view(b, c, [1] * len(spatial)) + weights = grad.view(b, c, -1).mean(2).view(b, c, *[1] * len(spatial)) acti_map = (weights * acti).sum(1, keepdim=True) return F.relu(acti_map) def feature_map_size(self, input_size, device="cpu", layer_idx=-1): + """ + Computes the actual feature map size given `nn_module` and the target_layer name. + + Args: + input_size: shape of the input tensor + device: the device used to initialise the input tensor + layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. + + Returns: + shape of the actual feature map. + """ return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): + """ + Compute the activation map with upsampling and postprocessing. + + Args: + x: input tensor, shape must be compatible with `nn_module`. + class_idx: index of the class to be visualised. Default to argmax(logits) + layer_idx: index of the target layer if there are multiple target layers. Defaults to -1. + retain_graph: whether to retain_graph for torch module backward call. 
+ + Returns: + activation maps + """ acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx) # upsampling and postprocessing @@ -206,58 +345,3 @@ def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): if self.postprocessing: acti_map = self.postprocessing(acti_map) return acti_map - - -# if __name__ == "__main__": -# import argparse -# import glob -# -# import cv2 -# import numpy as np -# import PIL -# from matplotlib import pyplot as plt -# from torchvision import transforms -# -# parser = argparse.ArgumentParser() -# parser.add_argument("--device", default="0", type=str) -# parser.add_argument( -# "--target_layer", -# default="layer4", -# type=str, -# help="Specify the name of the target layer (before the pooling layer)", -# ) -# parser.add_argument( -# "--final_layer", default="fc", type=str, help="Specify the name of the last classification layer" -# ) -# args = parser.parse_args() -# device = torch.device("cuda:" + args.device) if torch.cuda.is_available() else torch.device("cpu") -# model = torch.load("temp/resnet-cam.pt", map_location=device) -# # print(model) -# if torch.cuda.is_available(): -# model.cuda() -# model.eval() -# cam_computer = CAM(model, target_layers=args.target_layer) -# # cam_computer = GradCAM(model, target_layers=args.target_layer) -# resize_param = (224, 224) -# norm_mean = [0.5528, 0.5528, 0.5528] -# norm_std = [0.1583, 0.1583, 0.1583] -# disp_size = 10 -# preprocess = transforms.Compose( -# [transforms.Resize(resize_param), transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)] -# ) -# plt.figure(figsize=(disp_size, disp_size)) -# for i, file in enumerate(glob.glob("./temp/test_images/*")): -# image = PIL.Image.open(file) -# h, w, b = np.shape(np.array(image)) -# img_tensor = preprocess(image).unsqueeze(0).to(device) -# cam_img = cam_computer(img_tensor)[0, 0] -# img = np.array(image) -# cam_img = cv2.resize(cam_img, (h, w), interpolation=cv2.INTER_CUBIC) -# cam_img = np.uint8(cam_img * 255) -# height, width, _ = img.shape -# heatmap = cv2.applyColorMap(cam_img, cv2.COLORMAP_JET) -# result = heatmap * 0.3 + img * 0.6 -# plt.subplot(2, 1, i + 1) -# plt.imshow(result.astype(np.int)) -# -# plt.show() diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py new file mode 100644 index 0000000000..ecebe91117 --- /dev/null +++ b/tests/test_vis_cam.py @@ -0,0 +1,92 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets import DenseNet, densenet121, se_resnet50 +from monai.visualize import CAM + +# 2D +TEST_CASE_0 = [ + { + "model": "densenet2d", + "shape": (2, 1, 48, 64), + "feature_shape": (2, 1, 1, 2), + "target_layers": "class_layers.relu", + "fc_layers": "class_layers.out", + }, + (2, 1, 48, 64), +] +# 3D +TEST_CASE_1 = [ + { + "model": "densenet3d", + "shape": (2, 1, 6, 6, 6), + "feature_shape": (2, 1, 2, 2, 2), + "target_layers": "class_layers.relu", + "fc_layers": "class_layers.out", + }, + (2, 1, 6, 6, 6), +] +# 2D +TEST_CASE_2 = [ + { + "model": "senet2d", + "shape": (2, 3, 64, 64), + "feature_shape": (2, 1, 2, 2), + "target_layers": "layer4", + "fc_layers": "last_linear", + }, + (2, 1, 64, 64), +] + +# 3D +TEST_CASE_3 = [ + { + "model": "senet3d", + "shape": (2, 3, 8, 8, 48), + "feature_shape": (2, 1, 1, 1, 2), + "target_layers": "layer4", + "fc_layers": "last_linear", + }, + (2, 1, 8, 8, 48), +] + + +class TestClassActivationMap(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_data, expected_shape): + if input_data["model"] == "densenet2d": + model = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + if input_data["model"] == "densenet3d": + model = DenseNet( + spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,) + ) + if input_data["model"] == "senet2d": + model = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4) + if input_data["model"] == "senet3d": + model = se_resnet50(spatial_dims=3, in_channels=3, num_classes=4) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + cam = CAM(nn_module=model, target_layers=input_data["target_layers"], fc_layers=input_data["fc_layers"]) + image = torch.rand(input_data["shape"], device=device) + result = cam(x=image, layer_idx=-1) + fea_shape = cam.feature_map_size(input_data["shape"]) + self.assertTupleEqual(fea_shape, input_data["feature_shape"]) + self.assertTupleEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py new file mode 100644 index 0000000000..c20605459d --- /dev/null +++ b/tests/test_vis_gradcam.py @@ -0,0 +1,88 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets import DenseNet, densenet121, se_resnet50 +from monai.visualize import GradCAM + +# 2D +TEST_CASE_0 = [ + { + "model": "densenet2d", + "shape": (2, 1, 48, 64), + "feature_shape": (2, 1, 1, 2), + "target_layers": "class_layers.relu", + }, + (2, 1, 48, 64), +] +# 3D +TEST_CASE_1 = [ + { + "model": "densenet3d", + "shape": (2, 1, 6, 6, 6), + "feature_shape": (2, 1, 2, 2, 2), + "target_layers": "class_layers.relu", + }, + (2, 1, 6, 6, 6), +] +# 2D +TEST_CASE_2 = [ + { + "model": "senet2d", + "shape": (2, 3, 64, 64), + "feature_shape": (2, 1, 2, 2), + "target_layers": "layer4", + }, + (2, 1, 64, 64), +] + +# 3D +TEST_CASE_3 = [ + { + "model": "senet3d", + "shape": (2, 3, 8, 8, 48), + "feature_shape": (2, 1, 1, 1, 2), + "target_layers": "layer4", + }, + (2, 1, 8, 8, 48), +] + + +class TestGradientClassActivationMap(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_data, expected_shape): + if input_data["model"] == "densenet2d": + model = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + if input_data["model"] == "densenet3d": + model = DenseNet( + spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,) + ) + if input_data["model"] == "senet2d": + model = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4) + if input_data["model"] == "senet3d": + model = se_resnet50(spatial_dims=3, in_channels=3, num_classes=4) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + cam = GradCAM(nn_module=model, target_layers=input_data["target_layers"]) + image = torch.rand(input_data["shape"], device=device) + result = cam(x=image, layer_idx=-1) + fea_shape = cam.feature_map_size(input_data["shape"]) + self.assertTupleEqual(fea_shape, input_data["feature_shape"]) + self.assertTupleEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() From ac94a6919cbcaedd76287e73ee8106c1c096f5cf Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 1 Dec 2020 17:40:20 +0000 Subject: [PATCH 5/6] adds gradcam++ Signed-off-by: Wenqi Li --- monai/visualize/class_activation_maps.py | 33 ++++++++- tests/test_vis_cam.py | 2 +- tests/test_vis_gradcam.py | 2 +- tests/test_vis_gradcampp.py | 88 ++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 3 deletions(-) create mode 100644 tests/test_vis_gradcampp.py diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 02467abb89..004d6c8c37 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -19,7 +19,7 @@ from monai.transforms import ScaleIntensity from monai.utils import InterpolateMode, ensure_tuple -__all__ = ["ModelWithHooks", "default_upsampler", "default_normalizer", "CAM", "GradCAM"] +__all__ = ["ModelWithHooks", "default_upsampler", "default_normalizer", "CAM", "GradCAM", "GradCAMpp"] class ModelWithHooks: @@ -345,3 +345,34 @@ def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False): if self.postprocessing: acti_map = self.postprocessing(acti_map) return acti_map + + +class GradCAMpp(GradCAM): + """ + Computes Gradient-weighted Class Activation Mapping (Grad-CAM++). 
+ This implementation is based on: + + Chattopadhyay et al., Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks, + https://arxiv.org/abs/1710.11063 + + See Also: + + - :py:class:`monai.visualize.class_activation_maps.GradCAM` + + """ + + def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1): + """ + Compute the actual feature map with input tensor `x`. + """ + logits, acti, grad = self.net(x, class_idx=class_idx, retain_graph=retain_graph) + acti, grad = acti[layer_idx], grad[layer_idx] + b, c, *spatial = grad.shape + alpha_nr = grad.pow(2) + alpha_dr = alpha_nr.mul(2) + acti.mul(grad.pow(3)).view(b, c, -1).sum(-1).view(b, c, *[1] * len(spatial)) + alpha_dr = torch.where(alpha_dr != 0.0, alpha_dr, torch.ones_like(alpha_dr)) + alpha = alpha_nr.div(alpha_dr + 1e-7) + relu_grad = F.relu(self.net.score.exp() * grad) + weights = (alpha * relu_grad).view(b, c, -1).sum(-1).view(b, c, *[1] * len(spatial)) + acti_map = (weights * acti).sum(1, keepdim=True) + return F.relu(acti_map) diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py index ecebe91117..e2ec119ec8 100644 --- a/tests/test_vis_cam.py +++ b/tests/test_vis_cam.py @@ -83,7 +83,7 @@ def test_shape(self, input_data, expected_shape): cam = CAM(nn_module=model, target_layers=input_data["target_layers"], fc_layers=input_data["fc_layers"]) image = torch.rand(input_data["shape"], device=device) result = cam(x=image, layer_idx=-1) - fea_shape = cam.feature_map_size(input_data["shape"]) + fea_shape = cam.feature_map_size(input_data["shape"], device=device) self.assertTupleEqual(fea_shape, input_data["feature_shape"]) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index c20605459d..3fb53b1fda 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -79,7 +79,7 @@ def test_shape(self, input_data, expected_shape): cam = GradCAM(nn_module=model, target_layers=input_data["target_layers"]) image = torch.rand(input_data["shape"], device=device) result = cam(x=image, layer_idx=-1) - fea_shape = cam.feature_map_size(input_data["shape"]) + fea_shape = cam.feature_map_size(input_data["shape"], device=device) self.assertTupleEqual(fea_shape, input_data["feature_shape"]) self.assertTupleEqual(result.shape, expected_shape) diff --git a/tests/test_vis_gradcampp.py b/tests/test_vis_gradcampp.py new file mode 100644 index 0000000000..c6bdef1647 --- /dev/null +++ b/tests/test_vis_gradcampp.py @@ -0,0 +1,88 @@ +# Copyright 2020 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets import DenseNet, densenet121, se_resnet50 +from monai.visualize import GradCAMpp + +# 2D +TEST_CASE_0 = [ + { + "model": "densenet2d", + "shape": (2, 1, 48, 64), + "feature_shape": (2, 1, 1, 2), + "target_layers": "class_layers.relu", + }, + (2, 1, 48, 64), +] +# 3D +TEST_CASE_1 = [ + { + "model": "densenet3d", + "shape": (2, 1, 6, 6, 6), + "feature_shape": (2, 1, 2, 2, 2), + "target_layers": "class_layers.relu", + }, + (2, 1, 6, 6, 6), +] +# 2D +TEST_CASE_2 = [ + { + "model": "senet2d", + "shape": (2, 3, 64, 64), + "feature_shape": (2, 1, 2, 2), + "target_layers": "layer4", + }, + (2, 1, 64, 64), +] + +# 3D +TEST_CASE_3 = [ + { + "model": "senet3d", + "shape": (2, 3, 8, 8, 48), + "feature_shape": (2, 1, 1, 1, 2), + "target_layers": "layer4", + }, + (2, 1, 8, 8, 48), +] + + +class TestGradientClassActivationMapPP(unittest.TestCase): + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) + def test_shape(self, input_data, expected_shape): + if input_data["model"] == "densenet2d": + model = densenet121(spatial_dims=2, in_channels=1, out_channels=3) + if input_data["model"] == "densenet3d": + model = DenseNet( + spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,) + ) + if input_data["model"] == "senet2d": + model = se_resnet50(spatial_dims=2, in_channels=3, num_classes=4) + if input_data["model"] == "senet3d": + model = se_resnet50(spatial_dims=3, in_channels=3, num_classes=4) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + cam = GradCAMpp(nn_module=model, target_layers=input_data["target_layers"]) + image = torch.rand(input_data["shape"], device=device) + result = cam(x=image, layer_idx=-1) + fea_shape = cam.feature_map_size(input_data["shape"], device=device) + self.assertTupleEqual(fea_shape, input_data["feature_shape"]) + self.assertTupleEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() From 1b01e544b7fd638794387c7c4678f535394f95b0 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 1 Dec 2020 18:05:03 +0000 Subject: [PATCH 6/6] fixes compatibility issue Signed-off-by: Wenqi Li --- monai/visualize/class_activation_maps.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 004d6c8c37..6fd29d1c96 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -97,7 +97,7 @@ def get_layer(self, layer_id: Union[str, Callable]): def class_score(self, logits, class_idx=None): if class_idx is not None: return logits[:, class_idx].squeeze(), class_idx - class_idx = logits.argmax(1) + class_idx = logits.max(1)[-1] return logits[:, class_idx].squeeze(), class_idx def __call__(self, x, class_idx=None, retain_graph=False): @@ -203,7 +203,7 @@ def compute_map(self, x, class_idx=None, layer_idx=-1): logits, acti, _ = self.net(x) acti = acti[layer_idx] if class_idx is None: - class_idx = torch.argmax(logits, dim=1) + class_idx = logits.max(1)[-1] b, c, *spatial = acti.shape acti = torch.split(acti.reshape(b, c, -1), 1, dim=2) # make the spatial dims 1D fc_layers = self.net.get_layer(self.fc_layers)
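
Taken together, the six patches expose `CAM`, `GradCAM` and `GradCAMpp` from `monai.visualize`. Below is a minimal end-to-end sketch, assembled from the docstring examples and the unit tests added above (the DenseNet 2D case, reduced to batch size 1); it assumes a MONAI build that includes all six patches, and the shapes in the comments follow TEST_CASE_0 of the new tests.

    # A minimal CPU sketch, assuming MONAI with these patches applied.
    import torch

    from monai.networks.nets import densenet121
    from monai.visualize import CAM, GradCAM, GradCAMpp

    model = densenet121(spatial_dims=2, in_channels=1, out_channels=3)
    model.eval()

    x = torch.rand(1, 1, 48, 64)  # (batch, channel, height, width)

    # CAM re-evaluates the fully-connected layer at every spatial location of the
    # target feature map, so it needs both the feature-layer and fc-layer names.
    cam = CAM(nn_module=model, target_layers="class_layers.relu", fc_layers="class_layers.out")
    print(cam.feature_map_size((1, 1, 48, 64)))  # raw feature-map shape before upsampling: (1, 1, 1, 2)
    print(cam(x).shape)  # upsampled and normalised map: (1, 1, 48, 64)

    # GradCAM/GradCAMpp weight the feature map by functions of the gradients of the
    # class score, so only the target layer is needed; class_idx defaults to the
    # predicted class (argmax of the logits).
    grad_cam = GradCAM(nn_module=model, target_layers="class_layers.relu")
    print(grad_cam(x, class_idx=0).shape)  # (1, 1, 48, 64)

    grad_campp = GradCAMpp(nn_module=model, target_layers="class_layers.relu")
    print(grad_campp(x).shape)  # (1, 1, 48, 64)

Since `default_normalizer` maps each map's (min, max) to (1, 0) and returns a numpy array, the outputs can be overlaid directly on the input image as a heatmap, as in the demo script that patch 4 removed.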