1 change: 1 addition & 0 deletions monai/visualize/__init__.py
@@ -10,6 +10,7 @@
 # limitations under the License.

 from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks, default_normalizer
+from .gradient_based import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad
 from .img2tensorboard import add_animated_gif, make_animated_gif_summary, plot_2d_or_3d_image
 from .occlusion_sensitivity import OcclusionSensitivity
 from .utils import blend_images, matshow3d
6 changes: 3 additions & 3 deletions monai/visualize/class_activation_maps.py
Expand Up @@ -89,7 +89,7 @@ def __init__(
mod.register_backward_hook(self.backward_hook(name))
if self.register_forward:
mod.register_forward_hook(self.forward_hook(name))
if len(_registered) != len(self.target_layers):
if self.target_layers and (len(_registered) != len(self.target_layers)):
warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.")

def backward_hook(self, name):
@@ -139,10 +139,10 @@ def __call__(self, x, class_idx=None, retain_graph=False):
             self.score.sum().backward(retain_graph=retain_graph)
             for layer in self.target_layers:
                 if layer not in self.gradients:
-                    raise RuntimeError(
+                    warnings.warn(
                         f"Backward hook for {layer} is not triggered; `requires_grad` of {layer} should be `True`."
                     )
-            grad = tuple(self.gradients[layer] for layer in self.target_layers)
+            grad = tuple(self.gradients[layer] for layer in self.target_layers if layer in self.gradients)
         if train:
             self.model.train()
         return logits, acti, grad
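With this change, a registered target layer whose backward hook never fires produces a warning instead of a RuntimeError, and the layer is skipped in the returned gradient tuple. A minimal standalone sketch of the new fallback (the layer names and placeholder gradients are illustrative, not MONAI API):

import warnings

target_layers = ["layer1", "layer2"]
gradients = {"layer1": "grad1"}  # the backward hook for "layer2" never fired

for layer in target_layers:
    if layer not in gradients:
        warnings.warn(f"Backward hook for {layer} is not triggered.")
# missing layers are skipped, so callers may receive a shorter (possibly empty) tuple
grad = tuple(gradients[layer] for layer in target_layers if layer in gradients)
assert grad == ("grad1",)

An empty tuple is why the `test_ill` case in tests/test_vis_gradcam.py below now expects an IndexError from downstream indexing rather than a RuntimeError.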
137 changes: 137 additions & 0 deletions monai/visualize/gradient_based.py
@@ -0,0 +1,137 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from functools import partial
from typing import Callable

import torch

from monai.networks.utils import replace_modules_temp
from monai.utils.module import optional_import
from monai.visualize.class_activation_maps import ModelWithHooks

trange, has_trange = optional_import("tqdm", name="trange")


__all__ = ["VanillaGrad", "SmoothGrad", "GuidedBackpropGrad", "GuidedBackpropSmoothGrad"]


class _AutoGradReLU(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
pos_mask = (x > 0).type_as(x)
output = torch.mul(x, pos_mask)
ctx.save_for_backward(x, output)
return output

@staticmethod
def backward(ctx, grad_output):
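        # guided backprop: pass gradient only where both the forward input and the incoming gradient are positive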
x, _ = ctx.saved_tensors
pos_mask_1 = (x > 0).type_as(grad_output)
pos_mask_2 = (grad_output > 0).type_as(grad_output)
y = torch.mul(grad_output, pos_mask_1)
grad_input = torch.mul(y, pos_mask_2)
return grad_input


class _GradReLU(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
out: torch.Tensor = _AutoGradReLU.apply(x)
return out


class VanillaGrad:
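    """Saliency via vanilla input gradients: the gradient of the target class score with respect to the input."""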
def __init__(self, model: torch.nn.Module) -> None:
if not isinstance(model, ModelWithHooks): # Convert to model with hooks if necessary
self._model = ModelWithHooks(model, target_layer_names=(), register_backward=True)
else:
self._model = model

@property
def model(self):
return self._model.model

    @model.setter
    def model(self, m):
        if not isinstance(m, ModelWithHooks):  # replace the inner network, keeping the existing hooks wrapper
            self._model.model = m
        else:
            self._model = m  # replace the whole ModelWithHooks instance

def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph=True) -> torch.Tensor:
        if x.shape[0] != 1:
            raise ValueError(f"expected a batch size of 1, got {x.shape[0]}")
x.requires_grad = True

self._model(x, class_idx=index, retain_graph=retain_graph)
grad: torch.Tensor = x.grad.detach()
return grad

def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
return self.get_grad(x, index)


class SmoothGrad(VanillaGrad):
"""
See also:
- Smilkov et al. SmoothGrad: removing noise by adding noise https://arxiv.org/abs/1706.03825
"""

def __init__(
self,
model: torch.nn.Module,
stdev_spread: float = 0.15,
n_samples: int = 25,
magnitude: bool = True,
verbose: bool = True,
) -> None:
super().__init__(model)
self.stdev_spread = stdev_spread
self.n_samples = n_samples
self.magnitude = magnitude
self.range: Callable
if verbose and has_trange:
self.range = partial(trange, desc=f"Computing {self.__class__.__name__}")
else:
self.range = range

def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
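        # noise standard deviation is a fixed fraction (stdev_spread) of the input's dynamic range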
stdev = (self.stdev_spread * (x.max() - x.min())).item()
total_gradients = torch.zeros_like(x)
for _ in self.range(self.n_samples):
# create noisy image
noise = torch.normal(0, stdev, size=x.shape, dtype=torch.float32, device=x.device)
x_plus_noise = x + noise
x_plus_noise = x_plus_noise.detach()

# get gradient and accumulate
grad = self.get_grad(x_plus_noise, index)
total_gradients += (grad * grad) if self.magnitude else grad

# average
if self.magnitude:
total_gradients = total_gradients**0.5

return total_gradients / self.n_samples


class GuidedBackpropGrad(VanillaGrad):
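    """Guided backpropagation: computes input gradients while every `relu` submodule is temporarily replaced by a ReLU whose backward pass also zeroes negative incoming gradients."""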
def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False):
return super().__call__(x, index)


class GuidedBackpropSmoothGrad(SmoothGrad):
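    """SmoothGrad computed with guided-backprop ReLUs in place during each noisy backward pass."""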
def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False):
return super().__call__(x, index)
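A minimal usage sketch of the new visualizers, assuming a 2D DenseNet121 and an illustrative input shape (the batch size must be 1, per `get_grad`; `n_samples=10` is an arbitrary choice):

import torch

from monai.networks.nets import DenseNet121
from monai.visualize import GuidedBackpropGrad, SmoothGrad, VanillaGrad

net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3).eval()
x = torch.rand(1, 1, 48, 64)

# fresh clones so input gradients do not accumulate across calls
saliency = VanillaGrad(net)(x.clone())               # raw input gradients for the top class
smoothed = SmoothGrad(net, n_samples=10)(x.clone())  # gradients averaged over noisy copies
guided = GuidedBackpropGrad(net)(x.clone())          # gradients with guided-backprop ReLUs
assert saliency.shape == smoothed.shape == guided.shape == x.shape  # maps match the input shape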
50 changes: 50 additions & 0 deletions tests/test_vis_gradbased.py
@@ -0,0 +1,50 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.networks.nets import DenseNet, DenseNet121, SEResNet50
from monai.visualize import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad

DENSENET2D = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
DENSENET3D = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,))
SENET2D = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)
SENET3D = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)

TESTS = []
for vis_type in (VanillaGrad, SmoothGrad, GuidedBackpropGrad, GuidedBackpropSmoothGrad):
    # 2D densenet
    TESTS.append([vis_type, DENSENET2D, (1, 1, 48, 64)])
    # 3D densenet
    TESTS.append([vis_type, DENSENET3D, (1, 1, 6, 6, 6)])
    # 2D senet
    TESTS.append([vis_type, SENET2D, (1, 3, 64, 64)])
    # 3D senet
    TESTS.append([vis_type, SENET3D, (1, 3, 8, 8, 48)])


class TestGradientClassActivationMap(unittest.TestCase):
    @parameterized.expand(TESTS)
    def test_shape(self, vis_type, model, shape):
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()
        vis = vis_type(model)
        x = torch.rand(shape, device=device)
        result = vis(x)
        # the saliency map has the same shape as the input, including its channel count
        self.assertTupleEqual(result.shape, x.shape)


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tests/test_vis_gradcam.py
@@ -85,7 +85,7 @@ def test_ill(self):
         x.requires_grad = False
         cam = GradCAM(nn_module=model, target_layers="class_layers.relu")
         image = torch.rand((2, 1, 48, 64))
-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(IndexError):
             cam(x=image)

