From bb6e9e5cf9a3e45217ef8497f7db4a3b50bb6029 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 6 May 2022 10:53:59 +0100 Subject: [PATCH 01/12] gradient based saliency Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/visualize/__init__.py | 1 + monai/visualize/gradient_based.py | 158 ++++++++++++++++++++++++++++++ tests/test_vis_gradbased.py | 54 ++++++++++ 3 files changed, 213 insertions(+) create mode 100644 monai/visualize/gradient_based.py create mode 100644 tests/test_vis_gradbased.py diff --git a/monai/visualize/__init__.py b/monai/visualize/__init__.py index cd980846b3..49a628b66f 100644 --- a/monai/visualize/__init__.py +++ b/monai/visualize/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. from .class_activation_maps import CAM, GradCAM, GradCAMpp, ModelWithHooks, default_normalizer +from .gradient_based import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad from .img2tensorboard import add_animated_gif, make_animated_gif_summary, plot_2d_or_3d_image from .occlusion_sensitivity import OcclusionSensitivity from .utils import blend_images, matshow3d diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py new file mode 100644 index 0000000000..ba614ec69f --- /dev/null +++ b/monai/visualize/gradient_based.py @@ -0,0 +1,158 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from contextlib import contextmanager +from functools import partial +from typing import Callable + +import torch + +from monai.utils.module import optional_import + +trange, has_trange = optional_import("tqdm", name="trange") + + +__all__ = ["VanillaGrad", "SmoothGrad", "GuidedBackpropGrad", "GuidedBackpropSmoothGrad"] + + +def replace_module(parent: torch.nn.Module, name: str, new_module: torch.nn.Module) -> None: + idx = name.find(".") + if idx == -1: + setattr(parent, name, new_module) + else: + parent = getattr(parent, name[:idx]) + name = name[idx + 1 :] + replace_module(parent, name, new_module) + + +class _AutoGradReLU(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + pos_mask = (x > 0).type_as(x) + output = torch.mul(x, pos_mask) + ctx.save_for_backward(x, output) + return output + + @staticmethod + def backward(ctx, grad_output): + x, _ = ctx.saved_tensors + pos_mask_1 = (x > 0).type_as(grad_output) + pos_mask_2 = (grad_output > 0).type_as(grad_output) + y = torch.mul(grad_output, pos_mask_1) + grad_input = torch.mul(y, pos_mask_2) + return grad_input + + +class _GradReLU(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + out: torch.Tensor = _AutoGradReLU().apply(x) + return out + + def backward(self, x: torch.Tensor) -> torch.Tensor: + out: torch.Tensor = _AutoGradReLU().backward(x) + return out + + +class VanillaGrad: + def __init__(self, model: torch.nn.Module) -> None: + self.model = model + + def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None) -> torch.Tensor: + if x.shape[0] != 1: + raise ValueError("expect batch size of 1") + x.requires_grad = True + + output: torch.Tensor = self.model(x) + + if index is None: + index = output.argmax().detach() + + num_classes = output.shape[-1] + one_hot = torch.zeros((1, num_classes), dtype=torch.float32, device=x.device) + one_hot[0][index] = 1 + one_hot.requires_grad = True + one_hot = torch.sum(one_hot * output) + + one_hot.backward(retain_graph=True) + grad: torch.Tensor = x.grad.detach() + return grad + + def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor: + return self.get_grad(x, index) + + +class SmoothGrad(VanillaGrad): + def __init__( + self, + model: torch.nn.Module, + stdev_spread: float = 0.15, + n_samples: int = 25, + magnitude: bool = True, + verbose: bool = True, + ) -> None: + super().__init__(model) + self.stdev_spread = stdev_spread + self.n_samples = n_samples + self.magnitude = magnitude + self.range: Callable + if verbose and has_trange: + self.range = partial(trange, desc=f"Computing {self.__class__.__name__}") + else: + self.range = range + + def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor: + stdev = (self.stdev_spread * (x.max() - x.min())).item() + total_gradients = torch.zeros_like(x) + for _ in self.range(self.n_samples): + # create noisy image + noise = torch.normal(0, stdev, size=x.shape, dtype=torch.float32, device=x.device) + x_plus_noise = x + noise + x_plus_noise = x_plus_noise.detach() + + # get gradient and accumulate + grad = self.get_grad(x_plus_noise, index) + total_gradients += (grad * grad) if self.magnitude else grad + + # average + return total_gradients / self.n_samples + + +@contextmanager +def replace_modules(model: torch.nn.Module, name_to_replace: str = "relu", replace_with: torch.nn.Module = _GradReLU): + # replace + to_replace = [] + try: + for name, module in model.named_modules(): + if name_to_replace in 
name: + to_replace.append((name, module)) + + for name, _ in to_replace: + replace_module(model, name, replace_with()) + + yield + finally: + # regardless of success or not, revert model + for name, module in to_replace: + replace_module(model, name, module) + + +class GuidedBackpropGrad(VanillaGrad): + def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor: + with replace_modules(self.model): + return super().__call__(x, index) + + +class GuidedBackpropSmoothGrad(SmoothGrad): + def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor: + with replace_modules(self.model): + return super().__call__(x, index) diff --git a/tests/test_vis_gradbased.py b/tests/test_vis_gradbased.py new file mode 100644 index 0000000000..5bd6d0f100 --- /dev/null +++ b/tests/test_vis_gradbased.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets import DenseNet, DenseNet121, SEResNet50 +from monai.visualize import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad + +DENSENET2D = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) +DENSENET3D = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)) +SENET2D = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4) +SENET3D = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4) + +TESTS = [] +for type in (VanillaGrad, SmoothGrad, GuidedBackpropGrad, GuidedBackpropSmoothGrad): + # 2D densenet + TESTS.append([type, DENSENET2D, (1, 1, 48, 64), (1, 1, 48, 64)]) + # 3D densenet + TESTS.append([type, DENSENET3D, (1, 1, 6, 6, 6), (1, 1, 6, 6, 6)]) + # 2D senet + TESTS.append([type, SENET2D, (1, 3, 64, 64), (1, 1, 64, 64)]) + # 3D senet + TESTS.append([type, SENET3D, (1, 3, 8, 8, 48), (1, 1, 8, 8, 48)]) + + +class TestGradientClassActivationMap(unittest.TestCase): + @parameterized.expand(TESTS) + def test_shape(self, vis_type, model, shape, expected_shape): + print(shape) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + vis = vis_type(model) + x = torch.rand(shape, device=device) + result = vis(x) + self.assertTupleEqual(result.shape, x.shape) + # check result is same whether class_idx=None is used or not + result2 = vis(x, index=model(x).argmax(1)) + torch.testing.assert_allclose(result, result2) + + +if __name__ == "__main__": + unittest.main() From e0d581dc1cb265d2e6cac984c9ec26b70aae4edb Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 6 May 2022 12:18:32 +0100 Subject: [PATCH 02/12] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/visualize/gradient_based.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index ba614ec69f..e8e3e7135c 100644 --- 
a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -13,7 +13,7 @@ from contextlib import contextmanager from functools import partial -from typing import Callable +from typing import Callable, Type import torch @@ -124,11 +124,14 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> total_gradients += (grad * grad) if self.magnitude else grad # average + if self.magnitude: + total_gradients = total_gradients ** 0.5 + return total_gradients / self.n_samples @contextmanager -def replace_modules(model: torch.nn.Module, name_to_replace: str = "relu", replace_with: torch.nn.Module = _GradReLU): +def replace_modules(model: torch.nn.Module, name_to_replace: str = "relu", replace_with: Type[torch.nn.Module] = _GradReLU): # replace to_replace = [] try: From 6c089d4dfb01a78316e87a5bc2b9e96d7baf7e2e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 May 2022 11:18:59 +0000 Subject: [PATCH 03/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/visualize/gradient_based.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index e8e3e7135c..f97cbf0ce2 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -131,7 +131,7 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> @contextmanager -def replace_modules(model: torch.nn.Module, name_to_replace: str = "relu", replace_with: Type[torch.nn.Module] = _GradReLU): +def replace_modules(model: torch.nn.Module, name_to_replace: str = "relu", replace_with: type[torch.nn.Module] = _GradReLU): # replace to_replace = [] try: From eeb68be37893fadd2a2acc009ccce3334213af1b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 6 May 2022 13:25:09 +0100 Subject: [PATCH 04/12] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/visualize/gradient_based.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index f97cbf0ce2..f0d929d8cd 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -125,13 +125,15 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> # average if self.magnitude: - total_gradients = total_gradients ** 0.5 + total_gradients = total_gradients**0.5 return total_gradients / self.n_samples @contextmanager -def replace_modules(model: torch.nn.Module, name_to_replace: str = "relu", replace_with: type[torch.nn.Module] = _GradReLU): +def replace_modules( + model: torch.nn.Module, name_to_replace: str = "relu", replace_with: type[torch.nn.Module] = _GradReLU +): # replace to_replace = [] try: From 544398f694c6b580832a195e8020b70172a411f9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 6 May 2022 13:33:03 +0100 Subject: [PATCH 05/12] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/visualize/gradient_based.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index f0d929d8cd..c01f5661ed 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -13,7 +13,7 @@ from contextlib import 
contextmanager from functools import partial -from typing import Callable, Type +from typing import Callable import torch From facf5cc658e645c6cba57529a97902ebbddee2b7 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 6 May 2022 15:59:57 +0100 Subject: [PATCH 06/12] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_vis_gradbased.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_vis_gradbased.py b/tests/test_vis_gradbased.py index 5bd6d0f100..25dd530860 100644 --- a/tests/test_vis_gradbased.py +++ b/tests/test_vis_gradbased.py @@ -45,9 +45,6 @@ def test_shape(self, vis_type, model, shape, expected_shape): x = torch.rand(shape, device=device) result = vis(x) self.assertTupleEqual(result.shape, x.shape) - # check result is same whether class_idx=None is used or not - result2 = vis(x, index=model(x).argmax(1)) - torch.testing.assert_allclose(result, result2) if __name__ == "__main__": From 65edd5e6d6faa1a8833a513c2bc7375048c83d51 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 9 May 2022 13:28:28 +0100 Subject: [PATCH 07/12] replace modules Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/__init__.py | 2 + monai/networks/utils.py | 104 +++++++++++++++++++++++++++++- monai/visualize/gradient_based.py | 37 +---------- tests/test_replace_module.py | 97 ++++++++++++++++++++++++++++ tests/test_vis_gradbased.py | 1 - 5 files changed, 205 insertions(+), 36 deletions(-) create mode 100644 tests/test_replace_module.py diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 76223dfaef..b563a89d8c 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -20,6 +20,8 @@ one_hot, pixelshuffle, predict_segmentation, + replace_module, + replace_module_temp, save_state, slice_channels, to_norm_affine, diff --git a/monai/networks/utils.py b/monai/networks/utils.py index f22be31524..fb37c8bdbf 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -15,7 +15,8 @@ import warnings from collections import OrderedDict from contextlib import contextmanager -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union +from copy import deepcopy +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -41,6 +42,8 @@ "save_state", "convert_to_torchscript", "meshgrid_ij", + "replace_module", + "replace_module_temp", ] @@ -551,3 +554,102 @@ def meshgrid_ij(*tensors): if pytorch_after(1, 10): return torch.meshgrid(*tensors, indexing="ij") return torch.meshgrid(*tensors) + + +def _replace_module( + parent: torch.nn.Module, + name: str, + new_module: torch.nn.Module, + out: list, + strict_match: bool = True, + match_device: bool = True, +) -> None: + """ + Helper function for :py:class:`monai.networks.utils.replace_module`. + """ + if match_device: + devices = list({i.device for i in parent.parameters()}) + # if only one device for whole of model + if len(devices) == 1: + new_module.to(devices[0]) + idx = name.find(".") + # if there is "." in name, call recursively + if idx != -1: + parent_name = name[:idx] + parent = getattr(parent, parent_name) + name = name[idx + 1 :] + _out = [] + _replace_module(parent, name, new_module, _out) + # prepend the parent name + out += [(f"{parent_name}.{r[0]}", r[1]) for r in _out] + # no "." 
in module name, do the actual replacing + else: + if strict_match: + old_module = getattr(parent, name) + setattr(parent, name, new_module) + out += [(name, old_module)] + else: + for mod_name, _ in parent.named_modules(): + if name in mod_name: + _replace_module(parent, mod_name, deepcopy(new_module), out, strict_match=True) + + +def replace_module( + parent: torch.nn.Module, + name: str, + new_module: torch.nn.Module, + strict_match: bool = True, + match_device: bool = True, +) -> List[Tuple[str, torch.nn.Module]]: + """ + Replace sub-module(s) in a parent module. + + The name of the module to be replaced can be nested e.g., + `features.denseblock1.denselayer1.layers.relu1`. If this is the case (there are "." + in the module name), then this function will recursively call itself. + + Args: + parent: module that contains the module to be replaced + name: name of module to be replaced. Can include ".". + new_module: `torch.nn.Module` to be placed at position `name` inside `parent`. This will + be deep copied if `strict_match == False` so that multiple instances are independent. + strict_match: if `True`, module name must `== name`. If `False`, then + `name in named_modules()` will be used. `True` can be used to change just + one module, whereas `False` can be used to replace all modules with similar + name (e.g., `relu`). + match_device: if `True`, the device of the new module will match the model. Requires all + of `parent` to be on the same device. + + Returns: + List of tuples of replaced modules. Element 0 is module name, element 1 is the replaced module. + + Raises: + AttributeError: if `strict_match` is `True` and `name` is not a named module in `parent`. + """ + out = [] + _replace_module(parent, name, new_module, out, strict_match, match_device) + return out + + +@contextmanager +def replace_module_temp( + parent: torch.nn.Module, + name: str, + new_module: torch.nn.Module, + strict_match: bool = True, + match_device: bool = True, +) -> None: + """ + Temporarily replace sub-module(s) in a parent module (context manager). + + See :py:class:`monai.networks.utils.replace_module`. 
+ """ + replaced = [] + try: + # replace + _replace_module(parent, name, new_module, replaced, strict_match, match_device) + yield + finally: + # revert + for name, module in replaced: + _replace_module(parent, name, module, [], strict_match=True, match_device=match_device) diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index c01f5661ed..98f7095ad0 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -11,12 +11,12 @@ from __future__ import annotations -from contextlib import contextmanager from functools import partial from typing import Callable import torch +from monai.networks.utils import replace_module_temp from monai.utils.module import optional_import trange, has_trange = optional_import("tqdm", name="trange") @@ -25,16 +25,6 @@ __all__ = ["VanillaGrad", "SmoothGrad", "GuidedBackpropGrad", "GuidedBackpropSmoothGrad"] -def replace_module(parent: torch.nn.Module, name: str, new_module: torch.nn.Module) -> None: - idx = name.find(".") - if idx == -1: - setattr(parent, name, new_module) - else: - parent = getattr(parent, name[:idx]) - name = name[idx + 1 :] - replace_module(parent, name, new_module) - - class _AutoGradReLU(torch.autograd.Function): @staticmethod def forward(ctx, x): @@ -130,34 +120,13 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> return total_gradients / self.n_samples -@contextmanager -def replace_modules( - model: torch.nn.Module, name_to_replace: str = "relu", replace_with: type[torch.nn.Module] = _GradReLU -): - # replace - to_replace = [] - try: - for name, module in model.named_modules(): - if name_to_replace in name: - to_replace.append((name, module)) - - for name, _ in to_replace: - replace_module(model, name, replace_with()) - - yield - finally: - # regardless of success or not, revert model - for name, module in to_replace: - replace_module(model, name, module) - - class GuidedBackpropGrad(VanillaGrad): def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor: - with replace_modules(self.model): + with replace_module_temp(self.model, "relu", _GradReLU(), strict_match=False): return super().__call__(x, index) class GuidedBackpropSmoothGrad(SmoothGrad): def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor: - with replace_modules(self.model): + with replace_module_temp(self.model, "relu", _GradReLU(), strict_match=False): return super().__call__(x, index) diff --git a/tests/test_replace_module.py b/tests/test_replace_module.py new file mode 100644 index 0000000000..648202b3bf --- /dev/null +++ b/tests/test_replace_module.py @@ -0,0 +1,97 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from typing import Optional + +import torch +from parameterized import parameterized + +from monai.networks.nets import DenseNet121 +from monai.networks.utils import replace_module, replace_module_temp +from tests.utils import TEST_DEVICES + +TESTS = [] +for device in TEST_DEVICES: + for match_device in (True, False): + # replace 1 + TESTS.append(("features.denseblock1.denselayer1.layers.relu1", True, match_device, *device)) + # replace 1 (but not strict) + TESTS.append(("features.denseblock1.denselayer1.layers.relu1", False, match_device, *device)) + # replace multiple + TESTS.append(("relu", False, match_device, *device)) + + +class TestReplaceModule(unittest.TestCase): + def setUp(self): + self.net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) + self.num_relus = self.get_num_modules(torch.nn.ReLU) + self.total = self.get_num_modules() + self.assertGreater(self.num_relus, 0) + + def get_num_modules(self, mod: Optional[torch.nn.Module] = None) -> int: + m = [m for _, m in self.net.named_modules()] + if mod is not None: + m = [_m for _m in m if isinstance(_m, mod)] + return len(m) + + def check_replaced_modules(self, name, match_device): + # total num modules should remain the same + self.assertEqual(self.total, self.get_num_modules()) + num_relus_mod = self.get_num_modules(torch.nn.ReLU) + num_softmax = self.get_num_modules(torch.nn.Softmax) + # remaining ReLUs plus replacement Softmax modules should equal the original ReLU count + self.assertEqual(self.num_relus, num_relus_mod + num_softmax) + if name == "relu": + # at least 2 softmaxes + self.assertGreaterEqual(num_softmax, 2) + else: + # one softmax + self.assertEqual(num_softmax, 1) + if match_device: + self.assertEqual(len(list({i.device for i in self.net.parameters()})), 1) + + @parameterized.expand(TESTS) + def test_replace(self, name, strict_match, match_device, device): + self.net.to(device) + # replace module(s) + replaced = replace_module(self.net, name, torch.nn.Softmax(), strict_match, match_device) + self.check_replaced_modules(name, match_device) + # number of returned modules should equal number of softmax modules + self.assertEqual(len(replaced), self.get_num_modules(torch.nn.Softmax)) + # all replaced modules should be ReLU + for r in replaced: + self.assertIsInstance(r[1], torch.nn.ReLU) + # if a specific module was named, check that the name matches exactly + if name == "features.denseblock1.denselayer1.layers.relu1": + self.assertEqual(replaced[0][0], name) + + @parameterized.expand(TESTS) + def test_replace_context_manager(self, name, strict_match, match_device, device): + self.net.to(device) + with replace_module_temp(self.net, name, torch.nn.Softmax(), strict_match, match_device): + self.check_replaced_modules(name, match_device) + # Check that model was correctly reverted + self.assertEqual(self.get_num_modules(), self.total) + self.assertEqual(self.get_num_modules(torch.nn.ReLU), self.num_relus) + self.assertEqual(self.get_num_modules(torch.nn.Softmax), 0) + + def test_raises(self): + # name doesn't exist in module + with self.assertRaises(AttributeError): + replace_module(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True) + with self.assertRaises(AttributeError): + with replace_module_temp(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True): + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vis_gradbased.py b/tests/test_vis_gradbased.py index 25dd530860..7655ca661e 100644 --- a/tests/test_vis_gradbased.py +++ 
b/tests/test_vis_gradbased.py @@ -37,7 +37,6 @@ class TestGradientClassActivationMap(unittest.TestCase): @parameterized.expand(TESTS) def test_shape(self, vis_type, model, shape, expected_shape): - print(shape) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) model.eval() From fcb5de5c003389618fbe2b7a4ce68b59ff0ba88e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 9 May 2022 16:25:10 +0100 Subject: [PATCH 08/12] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/visualize/gradient_based.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index 98f7095ad0..63c31266fc 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -16,7 +16,7 @@ import torch -from monai.networks.utils import replace_module_temp +from monai.networks.utils import replace_modules_temp from monai.utils.module import optional_import trange, has_trange = optional_import("tqdm", name="trange") @@ -122,11 +122,11 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> class GuidedBackpropGrad(VanillaGrad): def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor: - with replace_module_temp(self.model, "relu", _GradReLU(), strict_match=False): + with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False): return super().__call__(x, index) class GuidedBackpropSmoothGrad(SmoothGrad): def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor: - with replace_module_temp(self.model, "relu", _GradReLU(), strict_match=False): + with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False): return super().__call__(x, index) From c5f790fc764d48c31c8ddd40c923928a7ffe3a68 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 11 May 2022 14:23:38 +0100 Subject: [PATCH 09/12] reuse modelwithhooks Signed-off-by: Wenqi Li --- monai/visualize/class_activation_maps.py | 4 +-- monai/visualize/gradient_based.py | 36 +++++++++++++----------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 16fb64cb46..93a63aba4b 100644 --- a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -139,10 +139,10 @@ def __call__(self, x, class_idx=None, retain_graph=False): self.score.sum().backward(retain_graph=retain_graph) for layer in self.target_layers: if layer not in self.gradients: - raise RuntimeError( + warnings.warn( f"Backward hook for {layer} is not triggered; `requires_grad` of {layer} should be `True`." 
) - grad = tuple(self.gradients[layer] for layer in self.target_layers) + grad = tuple(self.gradients[layer] for layer in self.target_layers if layer in self.gradients) if train: self.model.train() return logits, acti, grad diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index 63c31266fc..bf08fbdd88 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -18,6 +18,7 @@ from monai.networks.utils import replace_modules_temp from monai.utils.module import optional_import +from monai.visualize.class_activation_maps import ModelWithHooks trange, has_trange = optional_import("tqdm", name="trange") @@ -45,35 +46,38 @@ def backward(ctx, grad_output): class _GradReLU(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: - out: torch.Tensor = _AutoGradReLU().apply(x) + out: torch.Tensor = _AutoGradReLU.apply(x) return out def backward(self, x: torch.Tensor) -> torch.Tensor: - out: torch.Tensor = _AutoGradReLU().backward(x) + out: torch.Tensor = _AutoGradReLU.backward(x) return out class VanillaGrad: def __init__(self, model: torch.nn.Module) -> None: - self.model = model + if not isinstance(model, ModelWithHooks): # Convert to model with hooks if necessary + self._model = ModelWithHooks(model, target_layer_names="no_layer", register_backward=True) + else: + self._model = model + + @property + def model(self): + return self._model.model + + @model.setter + def model(self, m): + if not isinstance(m, ModelWithHooks): # regular model as ModelWithHooks + self._model.model = m + else: + self._model = m # replace the ModelWithHooks - def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None) -> torch.Tensor: + def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph=True) -> torch.Tensor: if x.shape[0] != 1: raise ValueError("expect batch size of 1") x.requires_grad = True - output: torch.Tensor = self.model(x) - - if index is None: - index = output.argmax().detach() - - num_classes = output.shape[-1] - one_hot = torch.zeros((1, num_classes), dtype=torch.float32, device=x.device) - one_hot[0][index] = 1 - one_hot.requires_grad = True - one_hot = torch.sum(one_hot * output) - - one_hot.backward(retain_graph=True) + self._model(x, class_idx=index, retain_graph=retain_graph) grad: torch.Tensor = x.grad.detach() return grad From 92c300b3228c83338cc497685b88bc03f0b7cbcd Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 11 May 2022 14:44:29 +0100 Subject: [PATCH 10/12] fixes unit test Signed-off-by: Wenqi Li --- tests/test_vis_gradcam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py index acca06d405..755f4d49ae 100644 --- a/tests/test_vis_gradcam.py +++ b/tests/test_vis_gradcam.py @@ -85,7 +85,7 @@ def test_ill(self): x.requires_grad = False cam = GradCAM(nn_module=model, target_layers="class_layers.relu") image = torch.rand((2, 1, 48, 64)) - with self.assertRaises(RuntimeError): + with self.assertRaises(IndexError): cam(x=image) From ec3ba0ee08dd26c6423890a9a1366e5289b2f870 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 11 May 2022 15:38:23 +0100 Subject: [PATCH 11/12] support of empty target Signed-off-by: Wenqi Li --- monai/visualize/class_activation_maps.py | 2 +- monai/visualize/gradient_based.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py index 93a63aba4b..ba1f5d2589 100644 --- 
a/monai/visualize/class_activation_maps.py +++ b/monai/visualize/class_activation_maps.py @@ -89,7 +89,7 @@ def __init__( mod.register_backward_hook(self.backward_hook(name)) if self.register_forward: mod.register_forward_hook(self.forward_hook(name)) - if len(_registered) != len(self.target_layers): + if self.target_layers and (len(_registered) != len(self.target_layers)): warnings.warn(f"Not all target_layers exist in the network module: targets: {self.target_layers}.") def backward_hook(self, name): diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index bf08fbdd88..0ba68ed542 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -57,7 +57,7 @@ def backward(self, x: torch.Tensor) -> torch.Tensor: class VanillaGrad: def __init__(self, model: torch.nn.Module) -> None: if not isinstance(model, ModelWithHooks): # Convert to model with hooks if necessary - self._model = ModelWithHooks(model, target_layer_names="no_layer", register_backward=True) + self._model = ModelWithHooks(model, target_layer_names=(), register_backward=True) else: self._model = model From 6cac01f08402bc697e17324a23f8e3d02b65e0a8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 11 May 2022 16:43:35 +0100 Subject: [PATCH 12/12] remove module.backward, add docstring Signed-off-by: Wenqi Li --- monai/visualize/gradient_based.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index 0ba68ed542..32b8110b6d 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -49,10 +49,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: out: torch.Tensor = _AutoGradReLU.apply(x) return out - def backward(self, x: torch.Tensor) -> torch.Tensor: - out: torch.Tensor = _AutoGradReLU.backward(x) - return out - class VanillaGrad: def __init__(self, model: torch.nn.Module) -> None: @@ -86,6 +82,11 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> class SmoothGrad(VanillaGrad): + """ + See also: + - Smilkov et al. SmoothGrad: removing noise by adding noise https://arxiv.org/abs/1706.03825 + """ + def __init__( self, model: torch.nn.Module,
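The end-to-end behaviour of the new visualizers can be summarised with a short usage sketch, distilled from tests/test_vis_gradbased.py above; the DenseNet configuration and input size are illustrative only (any classifier should work, but VanillaGrad.get_grad requires a batch size of 1):

    import torch

    from monai.networks.nets import DenseNet121
    from monai.visualize import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad

    model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3).eval()

    for vis_type in (VanillaGrad, SmoothGrad, GuidedBackpropGrad, GuidedBackpropSmoothGrad):
        x = torch.rand(1, 1, 48, 64)  # fresh input each time so gradients do not accumulate
        saliency = vis_type(model)(x)  # index=None: gradient taken w.r.t. the argmax class
        assert saliency.shape == x.shape  # one saliency value per input element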
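Under the hood, the guided-backprop variants temporarily swap every ReLU for _GradReLU, whose autograd function only propagates a gradient where both the forward input and the incoming gradient are positive (guided backpropagation, Springenberg et al. 2014). A standalone restatement of that backward rule, using hypothetical names for illustration:

    import torch

    class GuidedReLU(torch.autograd.Function):
        # same masking rule as _AutoGradReLU in the patches
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x.clamp(min=0)

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            # zero the gradient wherever the input or the incoming gradient is non-positive
            return grad_output * (x > 0) * (grad_output > 0)

    x = torch.tensor([-1.0, 2.0, 3.0], requires_grad=True)
    GuidedReLU.apply(x).backward(torch.tensor([1.0, -1.0, 1.0]))
    print(x.grad)  # only the last element survives: its input and incoming gradient are both positive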
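The replace_module/replace_module_temp helpers from PATCH 07 are exported from monai.networks and are useful on their own; note that PATCH 08 refers to a pluralised replace_modules_temp, so the final public names may differ. A minimal sketch following the patterns in tests/test_replace_module.py (the Softmax stand-in is arbitrary):

    import torch

    from monai.networks.nets import DenseNet121
    from monai.networks.utils import replace_module, replace_module_temp

    net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)

    # temporarily swap a single fully-qualified module; it is restored when the block exits
    with replace_module_temp(net, "features.denseblock1.denselayer1.layers.relu1", torch.nn.Softmax()):
        pass  # inside this block the named ReLU is a Softmax

    # permanently swap every module whose name contains "relu"
    replaced = replace_module(net, "relu", torch.nn.Softmax(), strict_match=False)
    for name, old_module in replaced:  # each entry is (fully-qualified name, removed module)
        assert isinstance(old_module, torch.nn.ReLU)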