Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/networks/nets/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,14 @@ def _load_state_dict(model, model_url, progress):
<https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py>`_
"""
pattern = re.compile(
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)

state_dict = load_state_dict_from_url(model_url, progress=progress)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + res.group(2)
new_key = res.group(1) + ".layers" + res.group(2) + res.group(3)
state_dict[new_key] = state_dict[key]
del state_dict[key]

Expand Down
41 changes: 35 additions & 6 deletions tests/test_densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,25 @@
# limitations under the License.

import unittest
from typing import TYPE_CHECKING
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import densenet121, densenet169, densenet201, densenet264
from monai.utils import optional_import
from tests.utils import skip_if_quick, test_pretrained_networks, test_script_save

if TYPE_CHECKING:
import torchvision

has_torchvision = True
else:
torchvision, has_torchvision = optional_import("torchvision")


device = "cuda" if torch.cuda.is_available() else "cpu"

TEST_CASE_1 = [ # 4-channel 3D, batch 2
Expand Down Expand Up @@ -50,27 +61,45 @@
TEST_PRETRAINED_2D_CASE_1 = [ # 4-channel 2D, batch 2
densenet121,
{"pretrained": True, "progress": True, "spatial_dims": 2, "in_channels": 2, "out_channels": 3},
(2, 2, 32, 64),
(2, 3),
(1, 2, 32, 64),
(1, 3),
]

TEST_PRETRAINED_2D_CASE_2 = [ # 4-channel 2D, batch 2
densenet121,
{"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 3},
(2, 2, 32, 64),
(2, 3),
{"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 2, "out_channels": 1},
(1, 2, 32, 64),
(1, 1),
]

TEST_PRETRAINED_2D_CASE_3 = [
densenet121,
{"pretrained": True, "progress": False, "spatial_dims": 2, "in_channels": 3, "out_channels": 1},
(1, 3, 32, 32),
]


class TestPretrainedDENSENET(unittest.TestCase):
@parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2])
@skip_if_quick
def test_121_3d_shape_pretrain(self, model, input_param, input_shape, expected_shape):
def test_121_2d_shape_pretrain(self, model, input_param, input_shape, expected_shape):
net = test_pretrained_networks(model, input_param, device)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

@parameterized.expand([TEST_PRETRAINED_2D_CASE_3])
@skipUnless(has_torchvision, "Requires `torchvision` package.")
def test_pretrain_consistency(self, model, input_param, input_shape):
example = torch.randn(input_shape).to(device)
net = test_pretrained_networks(model, input_param, device)
with eval_mode(net):
result = net.features.forward(example)
torchvision_net = torchvision.models.densenet121(pretrained=True).to(device)
with eval_mode(torchvision_net):
expected_result = torchvision_net.features.forward(example)
self.assertTrue(torch.all(result == expected_result))


class TestDENSENET(unittest.TestCase):
@parameterized.expand(TEST_CASES)
Expand Down