diff --git a/monai/networks/utils.py b/monai/networks/utils.py
index 83429a2837..e4cdfc6f9b 100644
--- a/monai/networks/utils.py
+++ b/monai/networks/utils.py
@@ -1111,3 +1111,47 @@ def replace_modules_temp(
     # revert
     for name, module in replaced:
         _replace_modules(parent, name, module, [], strict_match=True, match_device=match_device)
+
+
+def freeze_layers(model: nn.Module, freeze_vars: str | None = None, exclude_vars: str | None = None):
+    """
+    A utility function to help freeze specific layers.
+
+    Args:
+        model: a source PyTorch model whose layers are to be frozen.
+        freeze_vars: a regular expression to match the `model` variable names;
+            the matched variables will have `requires_grad` set to `False`.
+        exclude_vars: a regular expression to match the `model` variable names;
+            all variables except the matched ones will have `requires_grad` set to `False`.
+
+    Raises:
+        ValueError: when freeze_vars and exclude_vars are both specified.
+
+    """
+    if freeze_vars is not None and exclude_vars is not None:
+        raise ValueError("Incompatible values: freeze_vars and exclude_vars are both specified.")
+    src_dict = get_state_dict(model)
+
+    frozen_keys = []
+    if freeze_vars is not None:
+        to_freeze = {s_key for s_key in src_dict if re.compile(freeze_vars).search(s_key)}
+        for name, param in model.named_parameters():
+            if name in to_freeze:
+                param.requires_grad = False
+                frozen_keys.append(name)
+            elif not param.requires_grad:
+                param.requires_grad = True
+                warnings.warn(
+                    f"freeze_vars does not match {name}, but requires_grad is False; setting it to True."
+                )
+    if exclude_vars is not None:
+        to_exclude = {s_key for s_key in src_dict if re.compile(exclude_vars).search(s_key)}
+        for name, param in model.named_parameters():
+            if name not in to_exclude:
+                param.requires_grad = False
+                frozen_keys.append(name)
+            elif not param.requires_grad:
+                param.requires_grad = True
+                warnings.warn(f"exclude_vars matches {name}, but requires_grad is False; setting it to True.")
+
+    logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.")
diff --git a/tests/test_freeze_layers.py b/tests/test_freeze_layers.py
new file mode 100644
index 0000000000..29594ed98a
--- /dev/null
+++ b/tests/test_freeze_layers.py
@@ -0,0 +1,61 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
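+
+# Tests for the freeze_layers utility: the helper models imported from
+# test_copy_model_state each contain a "class_layer" submodule, matched by "class".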
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks.utils import freeze_layers
+from monai.utils import set_determinism
+from tests.test_copy_model_state import _TestModelOne, _TestModelTwo
+
+TEST_CASES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+
+class TestFreezeLayers(unittest.TestCase):
+    def tearDown(self):
+        set_determinism(None)
+
+    @parameterized.expand(TEST_CASES)
+    def test_freeze_vars(self, device):
+        set_determinism(0)
+        model = _TestModelOne(10, 20, 3)
+        model.to(device)
+        freeze_layers(model, "class")
+
+        for name, param in model.named_parameters():
+            if "class_layer" in name:
+                self.assertFalse(param.requires_grad)
+            else:
+                self.assertTrue(param.requires_grad)
+
+    @parameterized.expand(TEST_CASES)
+    def test_exclude_vars(self, device):
+        set_determinism(0)
+        model = _TestModelTwo(10, 20, 10, 4)
+        model.to(device)
+        freeze_layers(model, exclude_vars="class")
+
+        for name, param in model.named_parameters():
+            if "class_layer" in name:
+                self.assertTrue(param.requires_grad)
+            else:
+                self.assertFalse(param.requires_grad)
+
+
+if __name__ == "__main__":
+    unittest.main()
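
For reference, a minimal usage sketch of the new helper (not part of the patch; the toy model and the "encoder"/"head" names below are illustrative assumptions, and it presumes this patch is applied):

from collections import OrderedDict

import torch.nn as nn

from monai.networks.utils import freeze_layers

# a toy model with two named submodules, so parameter names are
# "encoder.weight", "encoder.bias", "head.weight", "head.bias"
model = nn.Sequential(OrderedDict(encoder=nn.Linear(16, 8), head=nn.Linear(8, 2)))

# freeze only the parameters whose names match the "encoder" regex
freeze_layers(model, freeze_vars="encoder")
assert not model.encoder.weight.requires_grad
assert model.head.weight.requires_grad

# equivalently, freeze everything except parameters matching "head"
freeze_layers(model, exclude_vars="head")
assert not model.encoder.weight.requires_grad
assert model.head.weight.requires_grad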