diff --git a/monai/visualize/__init__.py b/monai/visualize/__init__.py
index cd980846b3..49a628b66f 100644
--- a/monai/visualize/__init__.py
+++ b/monai/visualize/__init__.py
@@ -10,6 +10,7 @@
 # limitations under the License.
 
 from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks, default_normalizer
+from .gradient_based import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad
 from .img2tensorboard import add_animated_gif, make_animated_gif_summary, plot_2d_or_3d_image
 from .occlusion_sensitivity import OcclusionSensitivity
 from .utils import blend_images, matshow3d
diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py
index 16fb64cb46..ba1f5d2589 100644
--- a/monai/visualize/class_activation_maps.py
+++ b/monai/visualize/class_activation_maps.py
@@ -89,7 +89,7 @@ def __init__(
                 mod.register_backward_hook(self.backward_hook(name))
             if self.register_forward:
                 mod.register_forward_hook(self.forward_hook(name))
-        if len(_registered) != len(self.target_layers):
+        if self.target_layers and (len(_registered) != len(self.target_layers)):
             warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.")
 
     def backward_hook(self, name):
@@ -139,10 +139,10 @@ def __call__(self, x, class_idx=None, retain_graph=False):
         self.score.sum().backward(retain_graph=retain_graph)
         for layer in self.target_layers:
             if layer not in self.gradients:
-                raise RuntimeError(
+                warnings.warn(
                     f"Backward hook for {layer} is not triggered; `requires_grad` of {layer} should be `True`."
                 )
-        grad = tuple(self.gradients[layer] for layer in self.target_layers)
+        grad = tuple(self.gradients[layer] for layer in self.target_layers if layer in self.gradients)
         if train:
             self.model.train()
         return logits, acti, grad
diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py
new file mode 100644
index 0000000000..32b8110b6d
--- /dev/null
+++ b/monai/visualize/gradient_based.py
@@ -0,0 +1,137 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from functools import partial
+from typing import Callable
+
+import torch
+
+from monai.networks.utils import replace_modules_temp
+from monai.utils.module import optional_import
+from monai.visualize.class_activation_maps import ModelWithHooks
+
+trange, has_trange = optional_import("tqdm", name="trange")
+
+
+__all__ = ["VanillaGrad", "SmoothGrad", "GuidedBackpropGrad", "GuidedBackpropSmoothGrad"]
+
+
+class _AutoGradReLU(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        pos_mask = (x > 0).type_as(x)
+        output = torch.mul(x, pos_mask)
+        ctx.save_for_backward(x, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, _ = ctx.saved_tensors
+        pos_mask_1 = (x > 0).type_as(grad_output)
+        pos_mask_2 = (grad_output > 0).type_as(grad_output)
+        y = torch.mul(grad_output, pos_mask_1)
+        grad_input = torch.mul(y, pos_mask_2)
+        return grad_input
+
+
+class _GradReLU(torch.nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out: torch.Tensor = _AutoGradReLU.apply(x)
+        return out
+
+
+class VanillaGrad:
+    def __init__(self, model: torch.nn.Module) -> None:
+        if not isinstance(model, ModelWithHooks):  # Convert to model with hooks if necessary
+            self._model = ModelWithHooks(model, target_layer_names=(), register_backward=True)
+        else:
+            self._model = model
+
+    @property
+    def model(self):
+        return self._model.model
+
+    @model.setter
+    def model(self, m):
+        if not isinstance(m, ModelWithHooks):  # regular model as ModelWithHooks
+            self._model.model = m
+        else:
+            self._model = m  # replace the ModelWithHooks
+
+    def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph=True) -> torch.Tensor:
+        if x.shape[0] != 1:
+            raise ValueError("expect batch size of 1")
+        x.requires_grad = True
+
+        self._model(x, class_idx=index, retain_graph=retain_graph)
+        grad: torch.Tensor = x.grad.detach()
+        return grad
+
+    def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
+        return self.get_grad(x, index)
+
+
+class SmoothGrad(VanillaGrad):
+    """
+    See also:
+        - Smilkov et al. SmoothGrad: removing noise by adding noise https://arxiv.org/abs/1706.03825
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        stdev_spread: float = 0.15,
+        n_samples: int = 25,
+        magnitude: bool = True,
+        verbose: bool = True,
+    ) -> None:
+        super().__init__(model)
+        self.stdev_spread = stdev_spread
+        self.n_samples = n_samples
+        self.magnitude = magnitude
+        self.range: Callable
+        if verbose and has_trange:
+            self.range = partial(trange, desc=f"Computing {self.__class__.__name__}")
+        else:
+            self.range = range
+
+    def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
+        stdev = (self.stdev_spread * (x.max() - x.min())).item()
+        total_gradients = torch.zeros_like(x)
+        for _ in self.range(self.n_samples):
+            # create noisy image
+            noise = torch.normal(0, stdev, size=x.shape, dtype=torch.float32, device=x.device)
+            x_plus_noise = x + noise
+            x_plus_noise = x_plus_noise.detach()
+
+            # get gradient and accumulate
+            grad = self.get_grad(x_plus_noise, index)
+            total_gradients += (grad * grad) if self.magnitude else grad
+
+        # average
+        if self.magnitude:
+            total_gradients = total_gradients**0.5
+
+        return total_gradients / self.n_samples
+
+
+class GuidedBackpropGrad(VanillaGrad):
+    def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
+        with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False):
+            return super().__call__(x, index)
+
+
+class GuidedBackpropSmoothGrad(SmoothGrad):
+    def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
+        with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False):
+            return super().__call__(x, index)
diff --git a/tests/test_vis_gradbased.py b/tests/test_vis_gradbased.py
new file mode 100644
index 0000000000..7655ca661e
--- /dev/null
+++ b/tests/test_vis_gradbased.py
@@ -0,0 +1,50 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks.nets import DenseNet, DenseNet121, SEResNet50
+from monai.visualize import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad
+
+DENSENET2D = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
+DENSENET3D = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,))
+SENET2D = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)
+SENET3D = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)
+
+TESTS = []
+for type in (VanillaGrad, SmoothGrad, GuidedBackpropGrad, GuidedBackpropSmoothGrad):
+    # 2D densenet
+    TESTS.append([type, DENSENET2D, (1, 1, 48, 64), (1, 1, 48, 64)])
+    # 3D densenet
+    TESTS.append([type, DENSENET3D, (1, 1, 6, 6, 6), (1, 1, 6, 6, 6)])
+    # 2D senet
+    TESTS.append([type, SENET2D, (1, 3, 64, 64), (1, 1, 64, 64)])
+    # 3D senet
+    TESTS.append([type, SENET3D, (1, 3, 8, 8, 48), (1, 1, 8, 8, 48)])
+
+
+class TestGradientClassActivationMap(unittest.TestCase):
+    @parameterized.expand(TESTS)
+    def test_shape(self, vis_type, model, shape, expected_shape):
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        vis = vis_type(model)
+        x = torch.rand(shape, device=device)
+        result = vis(x)
+        self.assertTupleEqual(result.shape, x.shape)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py
index acca06d405..755f4d49ae 100644
--- a/tests/test_vis_gradcam.py
+++ b/tests/test_vis_gradcam.py
@@ -85,7 +85,7 @@ def test_ill(self):
                 x.requires_grad = False
         cam = GradCAM(nn_module=model, target_layers="class_layers.relu")
         image = torch.rand((2, 1, 48, 64))
-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(IndexError):
             cam(x=image)
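
Usage sketch (editor's addition, not part of the patch above): a minimal example of the visualizers introduced in monai/visualize/gradient_based.py. It reuses the 2D DenseNet121 configuration and the single-sample input shape from tests/test_vis_gradbased.py; the network is randomly initialised and the keyword values (n_samples, verbose, index) are arbitrary illustrative choices, so treat this as a sketch rather than a prescribed workflow.

    import torch

    from monai.networks.nets import DenseNet121
    from monai.visualize import GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad

    # a trained model would normally be used here; the config matches the new test file
    model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
    model.eval()

    # the visualizers expect a batch of exactly one sample
    x = torch.rand(1, 1, 48, 64)

    # each call returns a sensitivity map with the same shape as x
    vanilla_map = VanillaGrad(model)(x)  # raw input gradients for the predicted class
    smooth_map = SmoothGrad(model, n_samples=10, verbose=False)(x, index=0)  # gradients w.r.t. class 0, aggregated over noisy copies of x
    guided_map = GuidedBackpropSmoothGrad(model)(x)  # smoothed gradients with ReLU modules temporarily replaced for guided backprop

Each returned map has the same shape as the input, which is what the new test_shape case asserts.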