From 902479ab7d1f30a12251f646028d11aea7d64ad9 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 9 Dec 2020 21:44:35 +0000 Subject: [PATCH 1/2] fixes params groups Signed-off-by: Wenqi Li --- monai/optimizers/utils.py | 28 +++++++++++++++++----------- tests/test_generate_param_groups.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/monai/optimizers/utils.py b/monai/optimizers/utils.py index fdd8bc072f..b5950074de 100644 --- a/monai/optimizers/utils.py +++ b/monai/optimizers/utils.py @@ -13,7 +13,7 @@ import torch -from monai.utils import ensure_tuple +from monai.utils import ensure_tuple, ensure_tuple_rep def generate_param_groups( @@ -35,9 +35,11 @@ def generate_param_groups( match_types: a list of tags to identify the matching type corresponding to the `layer_matches` functions, can be "select" or "filter". lr_values: a list of LR values corresponding to the `layer_matches` functions. - include_others: whether to incude the rest layers as the last group, default to True. + include_others: whether to include the rest layers as the last group, default to True. - It's mainly used to set different init LR values for different network elements, for example:: + It's mainly used to set different LR values for different network elements, for example: + + .. code-block:: python net = Unet(dimensions=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1]) print(net) # print out network components to select expected items @@ -48,27 +50,31 @@ def generate_param_groups( match_types=["select", "filter"], lr_values=[1e-2, 1e-3], ) + # the groups will be a list of dictionaries: + # [{'params': , 'lr': 0.01}, + # {'params': , 'lr': 0.001}, + # {'params': }] optimizer = torch.optim.Adam(params, 1e-4) """ layer_matches = ensure_tuple(layer_matches) - match_types = ensure_tuple(match_types) - lr_values = ensure_tuple(lr_values) - if len(layer_matches) != len(lr_values) or len(layer_matches) != len(match_types): - raise ValueError("length of layer_match callable functions, match types and LR values should be the same.") + match_types = ensure_tuple_rep(match_types, len(layer_matches)) + lr_values = ensure_tuple_rep(lr_values, len(layer_matches)) params = list() _layers = list() for func, ty, lr in zip(layer_matches, match_types, lr_values): - if ty == "select": + if ty.lower() == "select": layer_params = func(network).parameters() - elif ty == "filter": + layer_ids = func(network).parameters() + elif ty.lower() == "filter": layer_params = filter(func, network.named_parameters()) + layer_ids = filter(func, network.named_parameters()) else: - raise ValueError(f"unsuppoted layer match type: {ty}.") + raise ValueError(f"unsupported layer match type: {ty}.") params.append({"params": layer_params, "lr": lr}) - _layers.extend(list(map(id, layer_params))) + _layers.extend(list(map(id, layer_ids))) if include_others: params.append({"params": filter(lambda p: id(p) not in _layers, network.parameters())}) diff --git a/tests/test_generate_param_groups.py b/tests/test_generate_param_groups.py index 122b22dba4..2130234013 100644 --- a/tests/test_generate_param_groups.py +++ b/tests/test_generate_param_groups.py @@ -21,7 +21,7 @@ TEST_CASE_1 = [ { "layer_matches": [lambda x: x.model[-1]], - "match_types": ["select"], + "match_types": "select", "lr_values": [1], }, (1, 100), @@ -30,7 +30,7 @@ TEST_CASE_2 = [ { "layer_matches": [lambda x: x.model[-1], lambda x: x.model[-2], lambda x: x.model[-3]], - "match_types": ["select", "select", "select"], + "match_types": "select", "lr_values": [1, 2, 3], }, (1, 2, 3, 100), @@ -84,6 +84,30 @@ def test_lr_values(self, input_param, expected_values): for param_group, value in zip(optimizer.param_groups, ensure_tuple(expected_values)): torch.testing.assert_allclose(param_group["lr"], value) + n = [len(p["params"]) for p in params] + assert sum(n) == 26 or all(n), "should have either full model or non-empty subsets." + + def test_wrong(self): + """overlapped""" + device = "cuda" if torch.cuda.is_available() else "cpu" + net = Unet( + dimensions=3, + in_channels=1, + out_channels=3, + channels=(16, 32, 64), + strides=(2, 2), + num_res_units=1, + ).to(device) + + params = generate_param_groups( + network=net, + layer_matches=[lambda x: x.model[-1], lambda x: x.model[-1]], + match_types="select", + lr_values=0.1, + ) + with self.assertRaises(ValueError): + torch.optim.Adam(params, 100) + if __name__ == "__main__": unittest.main() From a5940c21953038f0bf11d2ba9d6636346d308933 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 10 Dec 2020 10:47:41 +0000 Subject: [PATCH 2/2] update based on comments Signed-off-by: Wenqi Li --- monai/optimizers/utils.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/monai/optimizers/utils.py b/monai/optimizers/utils.py index b5950074de..57c7528ba4 100644 --- a/monai/optimizers/utils.py +++ b/monai/optimizers/utils.py @@ -61,20 +61,30 @@ def generate_param_groups( match_types = ensure_tuple_rep(match_types, len(layer_matches)) lr_values = ensure_tuple_rep(lr_values, len(layer_matches)) + def _get_select(f): + def _select(): + return f(network).parameters() + + return _select + + def _get_filter(f): + def _filter(): + return filter(f, network.named_parameters()) + + return _filter + params = list() _layers = list() for func, ty, lr in zip(layer_matches, match_types, lr_values): if ty.lower() == "select": - layer_params = func(network).parameters() - layer_ids = func(network).parameters() + layer_params = _get_select(func) elif ty.lower() == "filter": - layer_params = filter(func, network.named_parameters()) - layer_ids = filter(func, network.named_parameters()) + layer_params = _get_filter(func) else: raise ValueError(f"unsupported layer match type: {ty}.") - params.append({"params": layer_params, "lr": lr}) - _layers.extend(list(map(id, layer_ids))) + params.append({"params": layer_params(), "lr": lr}) + _layers.extend(list(map(id, layer_params()))) if include_others: params.append({"params": filter(lambda p: id(p) not in _layers, network.parameters())})