ResNet implementation: set bias=False for downsample-B #5477
Signed-off-by: Geevarghese George <thatgeeman@users.noreply.github.com>
|
We hadn't implemented downloading pre-trained weights so that part isn't affected by this, but other saved instances of this network like here won't load correctly. To maintain backwards compatibility we should add a |
|
Yes, that makes sense, I'll make the additions in the coming days. Making this PR a draft for now. |
|
Made the requisite changes to accept an additional kwarg as discussed. How does this look @ericspod? |
|
Looks good to me, we should add a test case to |
|
Agreed! What would the test case look like: would it be more of a sanity check to see if the expected shapes are returned with |
|
Yes I don't think much more is needed than that, if we had a version of an existing unit test with |
745b5af to e5dafbe
|
Just added to the test_resnet case as discussed. ^ |
wyli left a comment
Thanks, it looks good to me.
|
/build |
…#5477)

Signed-off-by: Geevarghese George <thatgeeman@users.noreply.github.com>

### Description

This is a simple fix following Project-MONAI#5465.

The downsampling layer is not expected to have a bias term. The previous implementation did not explicitly set `bias=False`, so it fell back to the PyTorch `Conv3d`/`Conv2d` default of `bias=True`. With this change, the correct number of parameter groups (62 for resnet18) is returned:

```python
from torchvision import models
from monai.networks import nets

d2net_torch = models.resnet18()
d2net_monai = nets.resnet18(spatial_dims=2)
d3net_monai = nets.resnet18(spatial_dims=3)
len(list(d2net_torch.parameters())), len(list(d2net_monai.parameters())), len(list(d3net_monai.parameters()))
# 62, 62, 62
# before: 62 65 65
```

Other deeper 2D ResNet architectures are also comparable to the PyTorch implementation; the `pretrained`/`weights` parameter can be allowed for these networks. Currently, it raises a `NotImplementedError` with `pretrained=True` even for 2D ResNets.

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Geevarghese George <thatgeeman@users.noreply.github.com>
Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com>
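As a rough cross-check of the 62 vs. 65 figures in the description, the parameter tensors of a ResNet-18 can be tallied by hand. The sketch below is plain Python with no PyTorch dependency. The function name `count_resnet18_param_tensors` and the layer breakdown (one stem conv and norm, eight basic blocks, three projection shortcuts, one linear head) are assumptions for illustration based on the standard ResNet-18 layout; they are not taken from the MONAI source.

```python
def count_resnet18_param_tensors(downsample_bias: bool) -> int:
    """Tally parameter tensors in a standard ResNet-18 layout (illustrative only)."""
    stem = 1 + 2                      # conv weight + norm (weight, bias)
    blocks = 8 * (2 * 1 + 2 * 2)      # 8 basic blocks: 2 conv weights + 2 norms each
    n_shortcuts = 3                   # projection shortcuts opening stages 2, 3 and 4
    # Each shortcut is conv weight + norm (weight, bias), plus an optional conv bias.
    per_shortcut = 1 + 2 + (1 if downsample_bias else 0)
    head = 2                          # linear weight + bias
    return stem + blocks + n_shortcuts * per_shortcut + head


print(count_resnet18_param_tensors(downsample_bias=False))  # 62
print(count_resnet18_param_tensors(downsample_bias=True))   # 65
```

Under these assumptions, the three extra tensors in the "before" count correspond to one bias per projection shortcut.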
|
Currently, the When manually loading MedicalNet weights, the downsample bias terms raise errors as they are not present in the loaded weights. It is also not possible to remove |
|
Hi @acerdur @wyli, AFAIU the error you are facing comes from the shortcut layer. To correctly load the pretrained weights of MedNet, you should initialize the model with the correct architecture, here `shortcut_type='A'`:

```python
from collections import OrderedDict

import torch
from monai.networks import nets

device = torch.device("cpu")  # assumption: change to your target device

# MONAI ResNet18
net = nets.resnet18(pretrained=False, spatial_dims=3, n_input_channels=1, num_classes=2, shortcut_type='A')
wt_path = 'resnet_18.pth'  # path to weights from Google Drive of Tencent
pretrained_weights = torch.load(f=wt_path, map_location=device)
# match the keys: drop the 'module.' prefix from the checkpoint keys
weights = OrderedDict()
for k, v in pretrained_weights['state_dict'].items():
    weights.update({k.replace('module.', ''): v})
net.load_state_dict(weights, strict=False)  # _IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])
```

The only incompatible keys are those of the last linear layer, which is expected for finetuning, as they are not provided/inferable in the MedNet weights. Related: #6811 |
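The key-matching step in the comment above can be illustrated without any model or checkpoint file. The following is a minimal sketch using plain dictionaries: the helpers `strip_prefix` and `diff_keys` and the toy key names are hypothetical, introduced only to show how removing a `module.` prefix lines the checkpoint keys up with the model keys, and why only the final linear layer ends up reported as missing.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


def diff_keys(model_keys, checkpoint_keys):
    """Mimic the missing/unexpected report of a non-strict load."""
    missing = [k for k in model_keys if k not in checkpoint_keys]
    unexpected = [k for k in checkpoint_keys if k not in model_keys]
    return missing, unexpected


# Toy stand-ins for real tensors; the key names are made up for illustration.
checkpoint = OrderedDict([("module.conv1.weight", 0), ("module.bn1.weight", 1)])
model_keys = ["conv1.weight", "bn1.weight", "fc.weight", "fc.bias"]

cleaned = strip_prefix(checkpoint)
missing, unexpected = diff_keys(model_keys, cleaned)
print(missing)     # ['fc.weight', 'fc.bias']
print(unexpected)  # []
```

Unlike `k.replace('module.', '')`, this sketch strips the prefix only at the start of a key, which avoids touching a `module.` substring that might appear later in a name.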
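Signed-off-by: Geevarghese George <thatgeeman@users.noreply.github.com>

Description

This is a simple fix following #5465.

The downsampling layer is not expected to have a bias term. The previous implementation did not explicitly set `bias=False`, so it fell back to the PyTorch `Conv3d`/`Conv2d` default of `bias=True`. With this change, the correct number of parameter groups (62 for resnet18) is returned.

Other deeper 2D ResNet architectures are also comparable to the PyTorch implementation; the `pretrained`/`weights` parameter can be allowed for these networks. Currently, it raises a `NotImplementedError` with `pretrained=True` even for 2D ResNets.

Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.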