Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 71 additions & 42 deletions monai/networks/nets/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ class DenseNet(nn.Module):
bn_size: multiplicative factor for number of bottle neck layers.
(i.e. bn_size * k features in the bottleneck layer)
dropout_prob: dropout rate after each dense layer.
pretrained: whether to load ImageNet pretrained weights when `spatial_dims == 2`.
In order to load weights correctly, please ensure that the `block_config`
is consistent with the corresponding arch.
pretrained_arch: the arch name for pretrained weights.
progress: If True, displays a progress bar of the download to stderr.
"""

def __init__(
Expand All @@ -127,6 +132,9 @@ def __init__(
block_config: Sequence[int] = (6, 12, 24, 16),
bn_size: int = 4,
dropout_prob: float = 0.0,
pretrained: bool = False,
pretrained_arch: str = "densenet121",
progress: bool = True,
) -> None:

super(DenseNet, self).__init__()
Expand Down Expand Up @@ -190,43 +198,49 @@ def __init__(
elif isinstance(m, nn.Linear):
nn.init.constant_(torch.as_tensor(m.bias), 0)

if pretrained:
self._load_state_dict(pretrained_arch, progress)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
x = self.class_layers(x)
return x

def _load_state_dict(self, arch, progress):
"""
This function is used to load pretrained models.
Adapted from `PyTorch Hub 2D version
<https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py>`_
"""
model_urls = {
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
}
if arch in model_urls.keys():
model_url = model_urls[arch]
else:
raise ValueError(
"only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights."
)
pattern = re.compile(
r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)

model_urls = {
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
}


def _load_state_dict(model, model_url, progress):
"""
This function is used to load pretrained models.
Adapted from `PyTorch Hub 2D version
<https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py>`_
"""
pattern = re.compile(
r"^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)

state_dict = load_state_dict_from_url(model_url, progress=progress)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + ".layers" + res.group(2) + res.group(3)
state_dict[new_key] = state_dict[key]
del state_dict[key]
state_dict = load_state_dict_from_url(model_url, progress=progress)
for key in list(state_dict.keys()):
res = pattern.match(key)
if res:
new_key = res.group(1) + ".layers" + res.group(2) + res.group(3)
state_dict[new_key] = state_dict[key]
del state_dict[key]

model_dict = model.state_dict()
state_dict = {
k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
}
model_dict.update(state_dict)
model.load_state_dict(model_dict)
model_dict = self.state_dict()
state_dict = {
k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
}
model_dict.update(state_dict)
self.load_state_dict(model_dict)


def densenet121(pretrained: bool = False, progress: bool = True, **kwargs) -> DenseNet:
Expand All @@ -235,10 +249,15 @@ def densenet121(pretrained: bool = False, progress: bool = True, **kwargs) -> De
from `PyTorch Hub 2D version
<https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py>`_
"""
model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
if pretrained:
arch = "densenet121"
_load_state_dict(model, model_urls[arch], progress)
model = DenseNet(
init_features=64,
growth_rate=32,
block_config=(6, 12, 24, 16),
pretrained=pretrained,
pretrained_arch="densenet121",
progress=progress,
**kwargs,
)
return model


Expand All @@ -248,10 +267,15 @@ def densenet169(pretrained: bool = False, progress: bool = True, **kwargs) -> De
from `PyTorch Hub 2D version
<https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py>`_
"""
model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
if pretrained:
arch = "densenet169"
_load_state_dict(model, model_urls[arch], progress)
model = DenseNet(
init_features=64,
growth_rate=32,
block_config=(6, 12, 32, 32),
pretrained=pretrained,
pretrained_arch="densenet169",
progress=progress,
**kwargs,
)
return model


Expand All @@ -261,10 +285,15 @@ def densenet201(pretrained: bool = False, progress: bool = True, **kwargs) -> De
from `PyTorch Hub 2D version
<https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py>`_
"""
model = DenseNet(init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
if pretrained:
arch = "densenet201"
_load_state_dict(model, model_urls[arch], progress)
model = DenseNet(
init_features=64,
growth_rate=32,
block_config=(6, 12, 48, 32),
pretrained=pretrained,
pretrained_arch="densenet201",
progress=progress,
**kwargs,
)
return model


Expand Down
153 changes: 84 additions & 69 deletions monai/networks/nets/senet.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ class SENet(nn.Module):
- For SE-ResNeXt models: False
num_classes: number of outputs in `last_linear` layer.
for all models: 1000

pretrained: whether to load ImageNet pretrained weights when `spatial_dims == 2`.
In order to load weights correctly, please ensure that the `block_config`
is consistent with the corresponding arch.
pretrained_arch: the arch name for pretrained weights.
progress: If True, displays a progress bar of the download to stderr.
"""

def __init__(
Expand All @@ -83,6 +87,9 @@ def __init__(
downsample_kernel_size: int = 3,
input_3x3: bool = True,
num_classes: int = 1000,
pretrained: bool = False,
pretrained_arch: str = "se_resnet50",
progress: bool = True,
) -> None:

super(SENet, self).__init__()
Expand Down Expand Up @@ -176,6 +183,64 @@ def __init__(
elif isinstance(m, nn.Linear):
nn.init.constant_(torch.as_tensor(m.bias), 0)

if pretrained:
self._load_state_dict(pretrained_arch, progress)

def _load_state_dict(self, arch, progress):
    """
    Load pretrained weights for ``arch`` into this network.

    The checkpoint is fetched with ``load_state_dict_from_url``, its keys are renamed to
    the layer names used by this implementation, and only the entries whose name and
    shape both match ``self.state_dict()`` are loaded. Entries that do not match are
    dropped without any warning.

    Args:
        arch: name of the pretrained architecture, one of the keys of ``model_urls`` below.
        progress: passed to ``load_state_dict_from_url`` as its ``progress`` argument.

    Raises:
        ValueError: when ``arch`` is not one of the supported architecture names.
    """
    # NOTE(review): these checkpoints are served over plain HTTP and are presumably
    # deserialized by ``load_state_dict_from_url`` -- confirm whether an HTTPS mirror exists.
    model_urls = {
        "senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth",
        "se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth",
        "se_resnet101": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth",
        "se_resnet152": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth",
        "se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth",
        "se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth",
    }
    model_url = model_urls.get(arch)
    if model_url is None:
        # Adjacent literals instead of a backslash continuation, so that the source
        # indentation does not leak into the message.
        raise ValueError(
            "only 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 'se_resnext50_32x4d', "
            "and 'se_resnext101_32x4d' are supported to load pretrained weights."
        )

    # Ordered rules: (pattern matching a checkpoint key, replacement producing the new key,
    # whether the tensor is squeezed before being stored under the new key).
    # The squeezed entries are the SE-module fc weights; presumably the checkpoint keeps them
    # with extra singleton dimensions that ``se_layer.fc`` does not have -- TODO confirm.
    key_rules = (
        (re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$"), r"\1conv.\2", False),
        (re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$"), r"\1conv\2adn.N.\3", False),
        (re.compile(r"^(layer[1-4]\.\d\.)(?:se_module\.fc1\.)(\w*)$"), r"\1se_layer.fc.0.\2", True),
        (re.compile(r"^(layer[1-4]\.\d\.)(?:se_module\.fc2\.)(\w*)$"), r"\1se_layer.fc.2.\2", True),
        (re.compile(r"^(layer[1-4]\.\d\.)(?:downsample\.0\.)(\w*)$"), r"\1project.conv.\2", False),
        (re.compile(r"^(layer[1-4]\.\d\.)(?:downsample\.1\.)(\w*)$"), r"\1project.adn.N.\2", False),
    )

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    # Iterate over a snapshot of the keys because entries are renamed while looping.
    for key in list(state_dict):
        for pattern, replacement, squeeze in key_rules:
            if pattern.match(key) is None:
                continue
            value = state_dict.pop(key)
            state_dict[pattern.sub(replacement, key)] = value.squeeze() if squeeze else value
            break  # at most one rule applies to a key

    # Keep only the pretrained entries that exist in this network with an identical shape.
    model_dict = self.state_dict()
    matched = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in model_dict and model_dict[name].shape == tensor.shape
    }
    model_dict.update(matched)
    self.load_state_dict(model_dict)

def _make_layer(
self,
block: Type[Union[SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck]],
Expand Down Expand Up @@ -248,56 +313,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


# Download URLs of the pretrained ``.pth`` checkpoints, keyed by architecture name.
# NOTE(review): all of these URLs use plain HTTP -- confirm whether an HTTPS mirror exists.
model_urls = {
"senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth",
"se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth",
"se_resnet101": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth",
"se_resnet152": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth",
"se_resnext50_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth",
"se_resnext101_32x4d": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth",
}


# Ordered rules: (pattern matching a checkpoint key, replacement producing the new key,
# whether the tensor is squeezed before being stored under the new key).
_SENET_KEY_RULES = (
    (re.compile(r"^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$"), r"\1conv.\2", False),
    (re.compile(r"^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$"), r"\1conv\2adn.N.\3", False),
    (re.compile(r"^(layer[1-4]\.\d\.)(?:se_module\.fc1\.)(\w*)$"), r"\1se_layer.fc.0.\2", True),
    (re.compile(r"^(layer[1-4]\.\d\.)(?:se_module\.fc2\.)(\w*)$"), r"\1se_layer.fc.2.\2", True),
    (re.compile(r"^(layer[1-4]\.\d\.)(?:downsample\.0\.)(\w*)$"), r"\1project.conv.\2", False),
    (re.compile(r"^(layer[1-4]\.\d\.)(?:downsample\.1\.)(\w*)$"), r"\1project.adn.N.\2", False),
)


def _remap_senet_key(key):
    """Return ``(new_key, squeeze)`` for a checkpoint key, or ``(None, False)`` when no rule applies."""
    for pattern, replacement, squeeze in _SENET_KEY_RULES:
        if pattern.match(key):
            return pattern.sub(replacement, key), squeeze
    return None, False


def _load_state_dict(model, model_url, progress):
    """
    Load the pretrained checkpoint at ``model_url`` into ``model``.

    Checkpoint keys are renamed to the layer names used by this implementation, and only
    the entries whose name and shape both match ``model.state_dict()`` are loaded. Entries
    that do not match are dropped without any warning.

    Args:
        model: network receiving the weights; its ``state_dict`` and ``load_state_dict`` are used.
        model_url: location of the checkpoint, passed to ``load_state_dict_from_url``.
        progress: passed to ``load_state_dict_from_url`` as its ``progress`` argument.
    """
    state_dict = load_state_dict_from_url(model_url, progress=progress)
    # Iterate over a snapshot of the keys because entries are renamed while looping.
    for key in list(state_dict):
        new_key, squeeze = _remap_senet_key(key)
        if new_key is None:
            continue
        value = state_dict.pop(key)
        state_dict[new_key] = value.squeeze() if squeeze else value

    # Keep only the pretrained entries that exist in the model with an identical shape.
    model_dict = model.state_dict()
    matched = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in model_dict and model_dict[name].shape == tensor.shape
    }
    model_dict.update(matched)
    model.load_state_dict(model_dict)


def senet154(
spatial_dims: int,
in_channels: int,
Expand All @@ -320,10 +335,10 @@ def senet154(
dropout_prob=0.2,
dropout_dim=1,
num_classes=num_classes,
pretrained=pretrained,
pretrained_arch="senet154",
progress=progress,
)
if pretrained:
arch = "senet154"
_load_state_dict(model, model_urls[arch], progress)
return model


Expand All @@ -347,10 +362,10 @@ def se_resnet50(
input_3x3=False,
downsample_kernel_size=1,
num_classes=num_classes,
pretrained=pretrained,
pretrained_arch="se_resnet50",
progress=progress,
)
if pretrained:
arch = "se_resnet50"
_load_state_dict(model, model_urls[arch], progress)
return model


Expand All @@ -375,10 +390,10 @@ def se_resnet101(
input_3x3=False,
downsample_kernel_size=1,
num_classes=num_classes,
pretrained=pretrained,
pretrained_arch="se_resnet101",
progress=progress,
)
if pretrained:
arch = "se_resnet101"
_load_state_dict(model, model_urls[arch], progress)
return model


Expand All @@ -403,10 +418,10 @@ def se_resnet152(
input_3x3=False,
downsample_kernel_size=1,
num_classes=num_classes,
pretrained=pretrained,
pretrained_arch="se_resnet152",
progress=progress,
)
if pretrained:
arch = "se_resnet152"
_load_state_dict(model, model_urls[arch], progress)
return model


Expand All @@ -430,10 +445,10 @@ def se_resnext50_32x4d(
input_3x3=False,
downsample_kernel_size=1,
num_classes=num_classes,
pretrained=pretrained,
pretrained_arch="se_resnext50_32x4d",
progress=progress,
)
if pretrained:
arch = "se_resnext50_32x4d"
_load_state_dict(model, model_urls[arch], progress)
return model


Expand All @@ -457,8 +472,8 @@ def se_resnext101_32x4d(
input_3x3=False,
downsample_kernel_size=1,
num_classes=num_classes,
pretrained=pretrained,
pretrained_arch="se_resnext101_32x4d",
progress=progress,
)
if pretrained:
arch = "se_resnext101_32x4d"
_load_state_dict(model, model_urls[arch], progress)
return model
Loading