From b8658ff2a9979fdff3a45663bd072b18078c0699 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 9 May 2022 13:29:20 +0100 Subject: [PATCH 1/4] replace modules Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/__init__.py | 2 + monai/networks/utils.py | 104 ++++++++++++++++++++++++++++++++++- tests/test_replace_module.py | 97 ++++++++++++++++++++++++++++++++ 3 files changed, 202 insertions(+), 1 deletion(-) create mode 100644 tests/test_replace_module.py diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index 76223dfaef..b563a89d8c 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -20,6 +20,8 @@ one_hot, pixelshuffle, predict_segmentation, + replace_module, + replace_module_temp, save_state, slice_channels, to_norm_affine, diff --git a/monai/networks/utils.py b/monai/networks/utils.py index f22be31524..fb37c8bdbf 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -15,7 +15,8 @@ import warnings from collections import OrderedDict from contextlib import contextmanager -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union +from copy import deepcopy +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -41,6 +42,8 @@ "save_state", "convert_to_torchscript", "meshgrid_ij", + "replace_module", + "replace_module_temp", ] @@ -551,3 +554,102 @@ def meshgrid_ij(*tensors): if pytorch_after(1, 10): return torch.meshgrid(*tensors, indexing="ij") return torch.meshgrid(*tensors) + + +def _replace_module( + parent: torch.nn.Module, + name: str, + new_module: torch.nn.Module, + out: list, + strict_match: bool = True, + match_device: bool = True, +) -> None: + """ + Helper function for :py:class:`monai.networks.utils.replace_module`. + """ + if match_device: + devices = list({i.device for i in parent.parameters()}) + # if only one device for whole of model + if len(devices) == 1: + new_module.to(devices[0]) + idx = name.find(".") + # if there is "." in name, call recursively + if idx != -1: + parent_name = name[:idx] + parent = getattr(parent, parent_name) + name = name[idx + 1 :] + _out = [] + _replace_module(parent, name, new_module, _out) + # prepend the parent name + out += [(f"{parent_name}.{r[0]}", r[1]) for r in _out] + # no "." in module name, do the actual replacing + else: + if strict_match: + old_module = getattr(parent, name) + setattr(parent, name, new_module) + out += [(name, old_module)] + else: + for mod_name, _ in parent.named_modules(): + if name in mod_name: + _replace_module(parent, mod_name, deepcopy(new_module), out, strict_match=True) + + +def replace_module( + parent: torch.nn.Module, + name: str, + new_module: torch.nn.Module, + strict_match: bool = True, + match_device: bool = True, +) -> List[Tuple[str, torch.nn.Module]]: + """ + Replace sub-module(s) in a parent module. + + The name of the module to be replace can be nested e.g., + `features.denseblock1.denselayer1.layers.relu1`. If this is the case (there are "." + in the module name), then this function will recursively call itself. + + Args: + parent: module that contains the module to be replaced + name: name of module to be replaced. Can include ".". + new_module: `torch.nn.Module` to be placed at position `name` inside `parent`. This will + be deep copied if `strict_match == False` multiple instances are independent. + strict_match: if `True`, module name must `== name`. If false then + `name in named_modules()` will be used. `True` can be used to change just + one module, whereas `False` can be used to replace all modules with similar + name (e.g., `relu`). + match_device: if `True`, the device of the new module will match the model. Requires all + of `parent` to be on the same device. + + Returns: + List of tuples of replaced modules. Element 0 is module name, element 1 is the replaced module. + + Raises: + AttributeError: if `strict_match` is `True` and `name` is not a named module in `parent`. + """ + out = [] + _replace_module(parent, name, new_module, out, strict_match, match_device) + return out + + +@contextmanager +def replace_module_temp( + parent: torch.nn.Module, + name: str, + new_module: torch.nn.Module, + strict_match: bool = True, + match_device: bool = True, +) -> None: + """ + Temporarily replace sub-module(s) in a parent module (context manager). + + See :py:class:`monai.networks.utils.replace_module`. + """ + replaced = [] + try: + # replace + _replace_module(parent, name, new_module, replaced, strict_match, match_device) + yield + finally: + # revert + for name, module in replaced: + _replace_module(parent, name, module, [], strict_match=True, match_device=match_device) diff --git a/tests/test_replace_module.py b/tests/test_replace_module.py new file mode 100644 index 0000000000..648202b3bf --- /dev/null +++ b/tests/test_replace_module.py @@ -0,0 +1,97 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Optional + +import torch +from parameterized import parameterized + +from monai.networks.nets import DenseNet121 +from monai.networks.utils import replace_module, replace_module_temp +from tests.utils import TEST_DEVICES + +TESTS = [] +for device in TEST_DEVICES: + for match_device in (True, False): + # replace 1 + TESTS.append(("features.denseblock1.denselayer1.layers.relu1", True, match_device, *device)) + # replace 1 (but not strict) + TESTS.append(("features.denseblock1.denselayer1.layers.relu1", False, match_device, *device)) + # replace multiple + TESTS.append(("relu", False, match_device, *device)) + + +class TestReplaceModule(unittest.TestCase): + def setUp(self): + self.net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3) + self.num_relus = self.get_num_modules(torch.nn.ReLU) + self.total = self.get_num_modules() + self.assertGreater(self.num_relus, 0) + + def get_num_modules(self, mod: Optional[torch.nn.Module] = None) -> int: + m = [m for _, m in self.net.named_modules()] + if mod is not None: + m = [_m for _m in m if isinstance(_m, mod)] + return len(m) + + def check_replaced_modules(self, name, match_device): + # total num modules should remain the same + self.assertEqual(self.total, self.get_num_modules()) + num_relus_mod = self.get_num_modules(torch.nn.ReLU) + num_softmax = self.get_num_modules(torch.nn.Softmax) + # list of returned modules should be as long as number of softmax + self.assertEqual(self.num_relus, num_relus_mod + num_softmax) + if name == "relu": + # at least 2 softmaxes + self.assertGreaterEqual(num_softmax, 2) + else: + # one softmax + self.assertEqual(num_softmax, 1) + if match_device: + self.assertEqual(len(list({i.device for i in self.net.parameters()})), 1) + + @parameterized.expand(TESTS) + def test_replace(self, name, strict_match, match_device, device): + self.net.to(device) + # replace module(s) + replaced = replace_module(self.net, name, torch.nn.Softmax(), strict_match, match_device) + self.check_replaced_modules(name, match_device) + # number of returned modules should equal number of softmax modules + self.assertEqual(len(replaced), self.get_num_modules(torch.nn.Softmax)) + # all replaced modules should be ReLU + for r in replaced: + self.assertIsInstance(r[1], torch.nn.ReLU) + # if a specfic module was named, check that the name matches exactly + if name == "features.denseblock1.denselayer1.layers.relu1": + self.assertEqual(replaced[0][0], name) + + @parameterized.expand(TESTS) + def test_replace_context_manager(self, name, strict_match, match_device, device): + self.net.to(device) + with replace_module_temp(self.net, name, torch.nn.Softmax(), strict_match, match_device): + self.check_replaced_modules(name, match_device) + # Check that model was correctly reverted + self.assertEqual(self.get_num_modules(), self.total) + self.assertEqual(self.get_num_modules(torch.nn.ReLU), self.num_relus) + self.assertEqual(self.get_num_modules(torch.nn.Softmax), 0) + + def test_raises(self): + # name doesn't exist in module + with self.assertRaises(AttributeError): + replace_module(self.net, "non_existent_module", torch.nn.Softmax, strict_match=True) + with self.assertRaises(AttributeError): + with replace_module_temp(self.net, "non_existent_module", torch.nn.Softmax, strict_match=True): + pass + + +if __name__ == "__main__": + unittest.main() From b07f538dcf89ebf57979c0eded10e6077f37bad8 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 9 May 2022 13:34:14 +0100 Subject: [PATCH 2/4] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_replace_module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_replace_module.py b/tests/test_replace_module.py index 648202b3bf..eb3d592d3a 100644 --- a/tests/test_replace_module.py +++ b/tests/test_replace_module.py @@ -37,7 +37,7 @@ def setUp(self): self.total = self.get_num_modules() self.assertGreater(self.num_relus, 0) - def get_num_modules(self, mod: Optional[torch.nn.Module] = None) -> int: + def get_num_modules(self, mod: Optional[type[torch.nn.Module]] = None) -> int: m = [m for _, m in self.net.named_modules()] if mod is not None: m = [_m for _m in m if isinstance(_m, mod)] @@ -87,9 +87,9 @@ def test_replace_context_manager(self, name, strict_match, match_device, device) def test_raises(self): # name doesn't exist in module with self.assertRaises(AttributeError): - replace_module(self.net, "non_existent_module", torch.nn.Softmax, strict_match=True) + replace_module(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True) with self.assertRaises(AttributeError): - with replace_module_temp(self.net, "non_existent_module", torch.nn.Softmax, strict_match=True): + with replace_module_temp(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True): pass From d20b56929887e8b3c5cbbbe1e5596ca24eb3a996 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 9 May 2022 13:38:43 +0100 Subject: [PATCH 3/4] replace_module -> replace_modules Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/__init__.py | 4 ++-- monai/networks/utils.py | 24 ++++++++++++------------ tests/test_replace_module.py | 10 +++++----- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py index b563a89d8c..0543b11632 100644 --- a/monai/networks/__init__.py +++ b/monai/networks/__init__.py @@ -20,8 +20,8 @@ one_hot, pixelshuffle, predict_segmentation, - replace_module, - replace_module_temp, + replace_modules, + replace_modules_temp, save_state, slice_channels, to_norm_affine, diff --git a/monai/networks/utils.py b/monai/networks/utils.py index fb37c8bdbf..1dc71fa25b 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -42,8 +42,8 @@ "save_state", "convert_to_torchscript", "meshgrid_ij", - "replace_module", - "replace_module_temp", + "replace_modules", + "replace_modules_temp", ] @@ -556,7 +556,7 @@ def meshgrid_ij(*tensors): return torch.meshgrid(*tensors) -def _replace_module( +def _replace_modules( parent: torch.nn.Module, name: str, new_module: torch.nn.Module, @@ -565,7 +565,7 @@ def _replace_module( match_device: bool = True, ) -> None: """ - Helper function for :py:class:`monai.networks.utils.replace_module`. + Helper function for :py:class:`monai.networks.utils.replace_modules`. """ if match_device: devices = list({i.device for i in parent.parameters()}) @@ -579,7 +579,7 @@ def _replace_module( parent = getattr(parent, parent_name) name = name[idx + 1 :] _out = [] - _replace_module(parent, name, new_module, _out) + _replace_modules(parent, name, new_module, _out) # prepend the parent name out += [(f"{parent_name}.{r[0]}", r[1]) for r in _out] # no "." in module name, do the actual replacing @@ -591,10 +591,10 @@ def _replace_module( else: for mod_name, _ in parent.named_modules(): if name in mod_name: - _replace_module(parent, mod_name, deepcopy(new_module), out, strict_match=True) + _replace_modules(parent, mod_name, deepcopy(new_module), out, strict_match=True) -def replace_module( +def replace_modules( parent: torch.nn.Module, name: str, new_module: torch.nn.Module, @@ -627,12 +627,12 @@ def replace_module( AttributeError: if `strict_match` is `True` and `name` is not a named module in `parent`. """ out = [] - _replace_module(parent, name, new_module, out, strict_match, match_device) + _replace_modules(parent, name, new_module, out, strict_match, match_device) return out @contextmanager -def replace_module_temp( +def replace_modules_temp( parent: torch.nn.Module, name: str, new_module: torch.nn.Module, @@ -642,14 +642,14 @@ def replace_module_temp( """ Temporarily replace sub-module(s) in a parent module (context manager). - See :py:class:`monai.networks.utils.replace_module`. + See :py:class:`monai.networks.utils.replace_modules`. """ replaced = [] try: # replace - _replace_module(parent, name, new_module, replaced, strict_match, match_device) + _replace_modules(parent, name, new_module, replaced, strict_match, match_device) yield finally: # revert for name, module in replaced: - _replace_module(parent, name, module, [], strict_match=True, match_device=match_device) + _replace_modules(parent, name, module, [], strict_match=True, match_device=match_device) diff --git a/tests/test_replace_module.py b/tests/test_replace_module.py index eb3d592d3a..3a2be99ca7 100644 --- a/tests/test_replace_module.py +++ b/tests/test_replace_module.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.networks.nets import DenseNet121 -from monai.networks.utils import replace_module, replace_module_temp +from monai.networks.utils import replace_modules, replace_modules_temp from tests.utils import TEST_DEVICES TESTS = [] @@ -63,7 +63,7 @@ def check_replaced_modules(self, name, match_device): def test_replace(self, name, strict_match, match_device, device): self.net.to(device) # replace module(s) - replaced = replace_module(self.net, name, torch.nn.Softmax(), strict_match, match_device) + replaced = replace_modules(self.net, name, torch.nn.Softmax(), strict_match, match_device) self.check_replaced_modules(name, match_device) # number of returned modules should equal number of softmax modules self.assertEqual(len(replaced), self.get_num_modules(torch.nn.Softmax)) @@ -77,7 +77,7 @@ def test_replace(self, name, strict_match, match_device, device): @parameterized.expand(TESTS) def test_replace_context_manager(self, name, strict_match, match_device, device): self.net.to(device) - with replace_module_temp(self.net, name, torch.nn.Softmax(), strict_match, match_device): + with replace_modules_temp(self.net, name, torch.nn.Softmax(), strict_match, match_device): self.check_replaced_modules(name, match_device) # Check that model was correctly reverted self.assertEqual(self.get_num_modules(), self.total) @@ -87,9 +87,9 @@ def test_replace_context_manager(self, name, strict_match, match_device, device) def test_raises(self): # name doesn't exist in module with self.assertRaises(AttributeError): - replace_module(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True) + replace_modules(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True) with self.assertRaises(AttributeError): - with replace_module_temp(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True): + with replace_modules_temp(self.net, "non_existent_module", torch.nn.Softmax(), strict_match=True): pass From ecbe88275681d7ac0adfb48e1e004b794470b19b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 9 May 2022 13:59:29 +0100 Subject: [PATCH 4/4] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/networks/utils.py | 10 +++++----- tests/test_replace_module.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 1dc71fa25b..34ea4f716e 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -560,7 +560,7 @@ def _replace_modules( parent: torch.nn.Module, name: str, new_module: torch.nn.Module, - out: list, + out: List[Tuple[str, torch.nn.Module]], strict_match: bool = True, match_device: bool = True, ) -> None: @@ -578,7 +578,7 @@ def _replace_modules( parent_name = name[:idx] parent = getattr(parent, parent_name) name = name[idx + 1 :] - _out = [] + _out: List[Tuple[str, torch.nn.Module]] = [] _replace_modules(parent, name, new_module, _out) # prepend the parent name out += [(f"{parent_name}.{r[0]}", r[1]) for r in _out] @@ -626,7 +626,7 @@ def replace_modules( Raises: AttributeError: if `strict_match` is `True` and `name` is not a named module in `parent`. """ - out = [] + out: List[Tuple[str, torch.nn.Module]] = [] _replace_modules(parent, name, new_module, out, strict_match, match_device) return out @@ -638,13 +638,13 @@ def replace_modules_temp( new_module: torch.nn.Module, strict_match: bool = True, match_device: bool = True, -) -> None: +): """ Temporarily replace sub-module(s) in a parent module (context manager). See :py:class:`monai.networks.utils.replace_modules`. """ - replaced = [] + replaced: List[Tuple[str, torch.nn.Module]] = [] try: # replace _replace_modules(parent, name, new_module, replaced, strict_match, match_device) diff --git a/tests/test_replace_module.py b/tests/test_replace_module.py index 3a2be99ca7..4cb4443410 100644 --- a/tests/test_replace_module.py +++ b/tests/test_replace_module.py @@ -10,7 +10,7 @@ # limitations under the License. import unittest -from typing import Optional +from typing import Optional, Type import torch from parameterized import parameterized @@ -37,7 +37,7 @@ def setUp(self): self.total = self.get_num_modules() self.assertGreater(self.num_relus, 0) - def get_num_modules(self, mod: Optional[type[torch.nn.Module]] = None) -> int: + def get_num_modules(self, mod: Optional[Type[torch.nn.Module]] = None) -> int: m = [m for _, m in self.net.named_modules()] if mod is not None: m = [_m for _m in m if isinstance(_m, mod)]