Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions monai/optimizers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torch

from monai.utils import ensure_tuple
from monai.utils import ensure_tuple, ensure_tuple_rep


def generate_param_groups(
Expand All @@ -35,9 +35,11 @@ def generate_param_groups(
match_types: a list of tags to identify the matching type corresponding to the `layer_matches` functions,
can be "select" or "filter".
lr_values: a list of LR values corresponding to the `layer_matches` functions.
include_others: whether to incude the rest layers as the last group, default to True.
include_others: whether to include the rest layers as the last group, default to True.

It's mainly used to set different init LR values for different network elements, for example::
It's mainly used to set different LR values for different network elements, for example:

.. code-block:: python

net = Unet(dimensions=3, in_channels=1, out_channels=3, channels=[2, 2, 2], strides=[1, 1, 1])
print(net) # print out network components to select expected items
Expand All @@ -48,27 +50,41 @@ def generate_param_groups(
match_types=["select", "filter"],
lr_values=[1e-2, 1e-3],
)
# the groups will be a list of dictionaries:
# [{'params': <generator object Module.parameters at 0x7f9090a70bf8>, 'lr': 0.01},
# {'params': <filter object at 0x7f9088fd0dd8>, 'lr': 0.001},
# {'params': <filter object at 0x7f9088fd0da0>}]
optimizer = torch.optim.Adam(params, 1e-4)

"""
layer_matches = ensure_tuple(layer_matches)
match_types = ensure_tuple(match_types)
lr_values = ensure_tuple(lr_values)
if len(layer_matches) != len(lr_values) or len(layer_matches) != len(match_types):
raise ValueError("length of layer_match callable functions, match types and LR values should be the same.")
match_types = ensure_tuple_rep(match_types, len(layer_matches))
lr_values = ensure_tuple_rep(lr_values, len(layer_matches))

def _get_select(f):
def _select():
return f(network).parameters()

return _select

def _get_filter(f):
def _filter():
return filter(f, network.named_parameters())

return _filter

params = list()
_layers = list()
for func, ty, lr in zip(layer_matches, match_types, lr_values):
if ty == "select":
layer_params = func(network).parameters()
elif ty == "filter":
layer_params = filter(func, network.named_parameters())
if ty.lower() == "select":
layer_params = _get_select(func)
elif ty.lower() == "filter":
layer_params = _get_filter(func)
else:
raise ValueError(f"unsuppoted layer match type: {ty}.")
raise ValueError(f"unsupported layer match type: {ty}.")

params.append({"params": layer_params, "lr": lr})
_layers.extend(list(map(id, layer_params)))
params.append({"params": layer_params(), "lr": lr})
_layers.extend(list(map(id, layer_params())))

if include_others:
params.append({"params": filter(lambda p: id(p) not in _layers, network.parameters())})
Expand Down
28 changes: 26 additions & 2 deletions tests/test_generate_param_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
TEST_CASE_1 = [
{
"layer_matches": [lambda x: x.model[-1]],
"match_types": ["select"],
"match_types": "select",
"lr_values": [1],
},
(1, 100),
Expand All @@ -30,7 +30,7 @@
TEST_CASE_2 = [
{
"layer_matches": [lambda x: x.model[-1], lambda x: x.model[-2], lambda x: x.model[-3]],
"match_types": ["select", "select", "select"],
"match_types": "select",
"lr_values": [1, 2, 3],
},
(1, 2, 3, 100),
Expand Down Expand Up @@ -84,6 +84,30 @@ def test_lr_values(self, input_param, expected_values):
for param_group, value in zip(optimizer.param_groups, ensure_tuple(expected_values)):
torch.testing.assert_allclose(param_group["lr"], value)

n = [len(p["params"]) for p in params]
assert sum(n) == 26 or all(n), "should have either full model or non-empty subsets."

def test_wrong(self):
"""overlapped"""
device = "cuda" if torch.cuda.is_available() else "cpu"
net = Unet(
dimensions=3,
in_channels=1,
out_channels=3,
channels=(16, 32, 64),
strides=(2, 2),
num_res_units=1,
).to(device)

params = generate_param_groups(
network=net,
layer_matches=[lambda x: x.model[-1], lambda x: x.model[-1]],
match_types="select",
lr_values=0.1,
)
with self.assertRaises(ValueError):
torch.optim.Adam(params, 100)


if __name__ == "__main__":
unittest.main()