From c115b775544fcb580741322a1cd2696b326d1d0a Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 12 Sep 2023 14:59:51 +0800 Subject: [PATCH 1/6] add `freeze_layers` Signed-off-by: KumoLiu --- monai/networks/utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 83429a2837..7bedc99c28 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1111,3 +1111,12 @@ def replace_modules_temp( # revert for name, module in replaced: _replace_modules(parent, name, module, [], strict_match=True, match_device=match_device) + + +def freeze_layers(model: nn.Module, freeze_var=None): + src_dict = get_state_dict(model) + + to_freeze = {s_key for s_key in src_dict if freeze_var and re.compile(freeze_var).search(s_key)} + for name, param in model.named_parameters(): + if name in to_freeze: + param.requires_grad = False From bd87230c6eb1e400ee12150afcf4bf71ab26ce4f Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 12 Sep 2023 15:42:28 +0800 Subject: [PATCH 2/6] add unittests Signed-off-by: KumoLiu --- monai/networks/utils.py | 11 +++++++++ tests/test_freeze_layers.py | 49 +++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 tests/test_freeze_layers.py diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 7bedc99c28..943d513661 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1114,9 +1114,20 @@ def replace_modules_temp( def freeze_layers(model: nn.Module, freeze_var=None): + """ + A utilty function to help freeze specific layers. + + Args: + model: a source PyTorch model to freeze layer. + freeze_vars: a regular expression to match the `model` variable names, + so that their `requires_grad` will set to `False`. + """ src_dict = get_state_dict(model) to_freeze = {s_key for s_key in src_dict if freeze_var and re.compile(freeze_var).search(s_key)} + frozen_keys = list() for name, param in model.named_parameters(): if name in to_freeze: param.requires_grad = False + frozen_keys.append(name) + logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.") diff --git a/tests/test_freeze_layers.py b/tests/test_freeze_layers.py new file mode 100644 index 0000000000..ad21e3ddbe --- /dev/null +++ b/tests/test_freeze_layers.py @@ -0,0 +1,49 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.utils import freeze_layers +from monai.utils import set_determinism +from tests.test_copy_model_state import _TestModelOne + + +TEST_CASES = [] +__devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",) +for _x in __devices: + TEST_CASES.append((_x)) + + +class TestModuleState(unittest.TestCase): + def tearDown(self): + set_determinism(None) + + @parameterized.expand(TEST_CASES) + def test_set_state(self, device): + set_determinism(0) + model = _TestModelOne(10, 20, 3) + model.to(device) + freeze_layers(model, "class") + + for name, param in model.named_parameters(): + if "class_layer" in name: + self.assertEqual(param.requires_grad, False) + else: + self.assertEqual(param.requires_grad, True) + + +if __name__ == "__main__": + unittest.main() From e4419c75924540576822f10dcc76408365dd7699 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Sep 2023 07:44:45 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_freeze_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_freeze_layers.py b/tests/test_freeze_layers.py index ad21e3ddbe..2bbf2d0eb7 100644 --- a/tests/test_freeze_layers.py +++ b/tests/test_freeze_layers.py @@ -24,7 +24,7 @@ TEST_CASES = [] __devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",) for _x in __devices: - TEST_CASES.append((_x)) + TEST_CASES.append(_x) class TestModuleState(unittest.TestCase): From ea2f3c2bc33141aca14d48a427c50ccbf20c6df1 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 12 Sep 2023 16:26:48 +0800 Subject: [PATCH 4/6] add `exclude_vars` Signed-off-by: KumoLiu --- monai/networks/utils.py | 28 ++++++++++++++++++++++------ tests/test_freeze_layers.py | 18 +++++++++++++++--- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 943d513661..f6a677bc70 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1113,7 +1113,7 @@ def replace_modules_temp( _replace_modules(parent, name, module, [], strict_match=True, match_device=match_device) -def freeze_layers(model: nn.Module, freeze_var=None): +def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None): """ A utilty function to help freeze specific layers. @@ -1121,13 +1121,29 @@ def freeze_layers(model: nn.Module, freeze_var=None): model: a source PyTorch model to freeze layer. freeze_vars: a regular expression to match the `model` variable names, so that their `requires_grad` will set to `False`. + exclude_vars: a regular expression to match the `model` variable names, + so that their `requires_grad` will set to `True`. + + Raises: + ValueError: when freeze_vars and exclude_vars are both specified. + """ + if freeze_vars is not None and exclude_vars is not None: + raise ValueError("Incompatible values: freeze_vars and exclude_vars are both specified.") src_dict = get_state_dict(model) - to_freeze = {s_key for s_key in src_dict if freeze_var and re.compile(freeze_var).search(s_key)} frozen_keys = list() - for name, param in model.named_parameters(): - if name in to_freeze: - param.requires_grad = False - frozen_keys.append(name) + if freeze_vars is not None: + to_freeze = {s_key for s_key in src_dict if freeze_vars and re.compile(freeze_vars).search(s_key)} + for name, param in model.named_parameters(): + if name in to_freeze: + param.requires_grad = False + frozen_keys.append(name) + if exclude_vars is not None: + to_exclude = {s_key for s_key in src_dict if exclude_vars and re.compile(exclude_vars).search(s_key)} + for name, param in model.named_parameters(): + if name not in to_exclude: + param.requires_grad = False + frozen_keys.append(name) + logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.") diff --git a/tests/test_freeze_layers.py b/tests/test_freeze_layers.py index 2bbf2d0eb7..29594ed98a 100644 --- a/tests/test_freeze_layers.py +++ b/tests/test_freeze_layers.py @@ -18,8 +18,7 @@ from monai.networks.utils import freeze_layers from monai.utils import set_determinism -from tests.test_copy_model_state import _TestModelOne - +from tests.test_copy_model_state import _TestModelOne, _TestModelTwo TEST_CASES = [] __devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",) @@ -32,7 +31,7 @@ def tearDown(self): set_determinism(None) @parameterized.expand(TEST_CASES) - def test_set_state(self, device): + def test_freeze_vars(self, device): set_determinism(0) model = _TestModelOne(10, 20, 3) model.to(device) @@ -44,6 +43,19 @@ def test_set_state(self, device): else: self.assertEqual(param.requires_grad, True) + @parameterized.expand(TEST_CASES) + def test_exclude_vars(self, device): + set_determinism(0) + model = _TestModelTwo(10, 20, 10, 4) + model.to(device) + freeze_layers(model, exclude_vars="class") + + for name, param in model.named_parameters(): + if "class_layer" in name: + self.assertEqual(param.requires_grad, True) + else: + self.assertEqual(param.requires_grad, False) + if __name__ == "__main__": unittest.main() From d1ec80220f57860df93f306450e6a73203c7dcb5 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 12 Sep 2023 18:40:42 +0800 Subject: [PATCH 5/6] address comments Signed-off-by: KumoLiu --- monai/networks/utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index f6a677bc70..ac62b093c9 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1122,7 +1122,7 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None): freeze_vars: a regular expression to match the `model` variable names, so that their `requires_grad` will set to `False`. exclude_vars: a regular expression to match the `model` variable names, - so that their `requires_grad` will set to `True`. + except for matched variable names, other `requires_grad` will set to `False`. Raises: ValueError: when freeze_vars and exclude_vars are both specified. @@ -1139,11 +1139,17 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None): if name in to_freeze: param.requires_grad = False frozen_keys.append(name) + elif not param.requires_grad: + param.requires_grad = True + warnings.warn(f"The freeze_vars does not include {param}, but requires_grad is False, change it to True.") if exclude_vars is not None: to_exclude = {s_key for s_key in src_dict if exclude_vars and re.compile(exclude_vars).search(s_key)} for name, param in model.named_parameters(): if name not in to_exclude: param.requires_grad = False frozen_keys.append(name) + elif not param.requires_grad: + param.requires_grad = True + warnings.warn(f"The exclude_vars includes {param}, but requires_grad is False, change it to True.") logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.") From c6afba310b8572221f483b12439cfa0f0d7fbc94 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Tue, 12 Sep 2023 18:40:55 +0800 Subject: [PATCH 6/6] fix flake8 Signed-off-by: KumoLiu --- monai/networks/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index ac62b093c9..e4cdfc6f9b 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -1141,7 +1141,9 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None): frozen_keys.append(name) elif not param.requires_grad: param.requires_grad = True - warnings.warn(f"The freeze_vars does not include {param}, but requires_grad is False, change it to True.") + warnings.warn( + f"The freeze_vars does not include {param}, but requires_grad is False, change it to True." + ) if exclude_vars is not None: to_exclude = {s_key for s_key in src_dict if exclude_vars and re.compile(exclude_vars).search(s_key)} for name, param in model.named_parameters():