diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 76223dfaef..0543b11632 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -20,6 +20,8 @@ one_hot, pixelshuffle, predict_segmentation, + replace_modules, + replace_modules_temp, save_state, slice_channels, to_norm_affine, diff --git a/monai/networks/utils.py b/monai/networks/utils.py index f22be31524..34ea4f716e 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -15,7 +15,8 @@ import warnings from collections import OrderedDict from contextlib import contextmanager -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union +from copy import deepcopy +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -41,6 +42,8 @@ "save_state", "convert_to_torchscript", "meshgrid_ij", + "replace_modules", + "replace_modules_temp", ] @@ -551,3 +554,102 @@ def meshgrid_ij(*tensors): if pytorch_after(1, 10): return torch.meshgrid(*tensors, indexing="ij") return torch.meshgrid(*tensors) + + +def _replace_modules( + parent: torch.nn.Module, + name: str, + new_module: torch.nn.Module, + out: List[Tuple[str, torch.nn.Module]], + strict_match: bool = True, + match_device: bool = True, +) -> None: + """ + Helper function for :py:class:`monai.networks.utils.replace_modules`. + """ + if match_device: + devices = list({i.device for i in parent.parameters()}) + # if only one device for whole of model + if len(devices) == 1: + new_module.to(devices[0]) + idx = name.find(".") + # if there is "." in name, call recursively + if idx != -1: + parent_name = name[:idx] + parent = getattr(parent, parent_name) + name = name[idx + 1 :] + _out: List[Tuple[str, torch.nn.Module]] = [] + _replace_modules(parent, name, new_module, _out) + # prepend the parent name + out += [(f"{parent_name}.{r[0]}", r[1]) for r in _out] + # no "." in module name, do the actual replacing + else: + if strict_match: + old_module = getattr(parent, name) + setattr(parent, name, new_module) + out += [(name, old_module)] + else: + for mod_name, _ in parent.named_modules(): + if name in mod_name: + _replace_modules(parent, mod_name, deepcopy(new_module), out, strict_match=True) + + +def replace_modules( + parent: torch.nn.Module, + name: str, + new_module: torch.nn.Module, + strict_match: bool = True, + match_device: bool = True, +) -> List[Tuple[str, torch.nn.Module]]: + """ + Replace sub-module(s) in a parent module. + + The name of the module to be replace can be nested e.g., + `features.denseblock1.denselayer1.layers.relu1`. If this is the case (there are "." + in the module name), then this function will recursively call itself. + + Args: + parent: module that contains the module to be replaced + name: name of module to be replaced. Can include ".". + new_module: `torch.nn.Module` to be placed at position `name` inside `parent`. This will + be deep copied if `strict_match == False` multiple instances are independent. + strict_match: if `True`, module name must `== name`. If false then + `name in named_modules()` will be used. `True` can be used to change just + one module, whereas `False` can be used to replace all modules with similar + name (e.g., `relu`). + match_device: if `True`, the device of the new module will match the model. Requires all + of `parent` to be on the same device. + + Returns: + List of tuples of replaced modules. Element 0 is module name, element 1 is the replaced module. + + Raises: + AttributeError: if `strict_match` is `True` and `name` is not a named module in `parent`. + """ + out: List[Tuple[str, torch.nn.Module]] = [] + _replace_modules(parent, name, new_module, out, strict_match, match_device) + return out + + +@contextmanager +def replace_modules_temp( + parent: torch.nn.Module, + name: str, + new_module: torch.nn.Module, + strict_match: bool = True, + match_device: bool = True, +): + """ + Temporarily replace sub-module(s) in a parent module (context manager). + + See :py:class:`monai.networks.utils.replace_modules`. + """ + replaced: List[Tuple[str, torch.nn.Module]] = [] + try: + # replace + _replace_modules(parent, name, new_module, replaced, strict_match, match_device) + yield + finally: + # revert + for name, module in replaced: + _replace_modules(parent, name, module, [], strict_match=True, match_device=match_device) diff --git a/tests/test_replace_module.py b/tests/test_replace_module.py new file mode 100644 index 0000000000..4cb4443410 --- /dev/null +++ b/tests/test_replace_module.py @@ -0,0 +1,97 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Optional, Type + +import torch +from parameterized import parameterized + +from monai.networks.nets import DenseNet121 +from monai.networks.utils import replace_modules, replace_modules_temp +from tests.utils import TEST_DEVICES + +TESTS = [] +for device in TEST_DEVICES: + for match_device in (True, False): + # replace 1 + TESTS.append(("features.denseblock1.denselayer1.layers.relu1", True, match_device, *device)) + # replace 1 (but not strict) + TESTS.append(("features.denseblock1.denselayer1.layers.relu1", False, match_device, *device)) + # replace multiple + TESTS.append(("relu", False, match_device, *device)) + + +class TestReplaceModule(unittest.TestCase): + def setUp(self): + self.net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) + self.num_relus = self.get_num_modules(torch.nn.ReLU) + self.total = self.get_num_modules() + self.assertGreater(self.num_relus, 0) + + def get_num_modules(self, mod: Optional[Type[torch.nn.Module]] = None) -> int: + m = [m for _, m in self.net.named_modules()] + if mod is not None: + m = [_m for _m in m if isinstance(_m, mod)] + return len(m) + + def check_replaced_modules(self, name, match_device): + # total num modules should remain the same + self.assertEqual(self.total, self.get_num_modules()) + num_relus_mod = self.get_num_modules(torch.nn.ReLU) + num_softmax = self.get_num_modules(torch.nn.Softmax) + # list of returned modules should be as long as number of softmax + self.assertEqual(self.num_relus, num_relus_mod + num_softmax) + if name == "relu": + # at least 2 softmaxes + self.assertGreaterEqual(num_softmax, 2) + else: + # one softmax + self.assertEqual(num_softmax, 1) + if match_device: + self.assertEqual(len(list({i.device for i in self.net.parameters()})), 1) + + @parameterized.expand(TESTS) + def test_replace(self, name, strict_match, match_device, device): + self.net.to(device) + # replace module(s) + replaced = replace_modules(self.net, name, torch.nn.Softmax(), strict_match, match_device) + self.check_replaced_modules(name, match_device) + # number of returned modules should equal number of softmax modules + self.assertEqual(len(replaced), self.get_num_modules(torch.nn.Softmax)) + # all replaced modules should be ReLU + for r in replaced: + self.assertIsInstance(r[1], torch.nn.ReLU) + # if a specfic module was named, check that the name matches exactly + if name == "features.denseblock1.denselayer1.layers.relu1": + self.assertEqual(replaced[0][0], name) + + @parameterized.expand(TESTS) + def test_replace_context_manager(self, name, strict_match, match_device, device): + self.net.to(device) + with replace_modules_temp(self.net, name, torch.nn.Softmax(), strict_match, match_device): + self.check_replaced_modules(name, match_device) + # Check that model was correctly reverted + self.assertEqual(self.get_num_modules(), self.total) + self.assertEqual(self.get_num_modules(torch.nn.ReLU), self.num_relus) + self.assertEqual(self.get_num_modules(torch.nn.Softmax), 0) + + def test_raises(self): + # name doesn't exist in module + with self.assertRaises(AttributeError): + replace_modules(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True) + with self.assertRaises(AttributeError): + with replace_modules_temp(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True): + pass + + +if __name__ == "__main__": + unittest.main()