48 changes: 28 additions & 20 deletions monai/networks/blocks/squeeze_and_excitation.py
@@ -32,6 +32,7 @@ def __init__(
r: int = 2,
acti_type_1: Union[Tuple[str, Dict], str] = ("relu", {"inplace": True}),
acti_type_2: Union[Tuple[str, Dict], str] = "sigmoid",
add_residual: bool = False,
) -> None:
"""
Args:
@@ -51,6 +52,8 @@
"""
super(ChannelSELayer, self).__init__()

self.add_residual = add_residual

pool_type = Pool[Pool.ADAPTIVEAVG, spatial_dims]
self.avg_pool = pool_type(1) # spatial size (1, 1, ...)

@@ -74,8 +77,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
b, c = x.shape[:2]
y: torch.Tensor = self.avg_pool(x).view(b, c)
-         y = self.fc(y).view([b, c] + [1] * (x.ndimension() - 2))
-         return x * y
+         y = self.fc(y).view([b, c] + [1] * (x.ndim - 2))
+         result = x * y
+
+         # The residual connection is applied here rather than in an overridden
+         # forward in ResidualSELayer, since TorchScript has an issue with using super().
+         if self.add_residual:
+             result += x
+
+         return result


class ResidualSELayer(ChannelSELayer):
@@ -85,7 +95,6 @@ class ResidualSELayer(ChannelSELayer):
--+-- SE --o--
| |
+--------+

"""

def __init__(
@@ -105,21 +114,17 @@ def __init__(
acti_type_2: defaults to "relu".

See also:

:py:class:`monai.networks.blocks.ChannelSELayer`

"""
super().__init__(
-     spatial_dims=spatial_dims, in_channels=in_channels, r=r, acti_type_1=acti_type_1, acti_type_2=acti_type_2
+     spatial_dims=spatial_dims,
+     in_channels=in_channels,
+     r=r,
+     acti_type_1=acti_type_1,
+     acti_type_2=acti_type_2,
+     add_residual=True,
)

- def forward(self, x: torch.Tensor) -> torch.Tensor:
-     """
-     Args:
-         x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).
-     """
-     return x + super().forward(x)


class SEBlock(nn.Module):
"""
@@ -196,28 +201,31 @@ def __init__(
spatial_dims=spatial_dims, in_channels=n_chns_3, r=r, acti_type_1=acti_type_1, acti_type_2=acti_type_2
)

- self.project = project
- if self.project is None and in_channels != n_chns_3:
+ if project is None and in_channels != n_chns_3:
      self.project = Conv[Conv.CONV, spatial_dims](in_channels, n_chns_3, kernel_size=1)
+ elif project is None:
+     self.project = nn.Identity()
+ else:
+     self.project = project

- self.act = None
  if acti_type_final is not None:
      act_final, act_final_args = split_args(acti_type_final)
      self.act = Act[act_final](**act_final_args)
+ else:
+     self.act = nn.Identity()

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).
"""
- residual = x if self.project is None else self.project(x)
+ residual = self.project(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.se_layer(x)
x += residual
- if self.act is not None:
-     x = self.act(x)
+ x = self.act(x)
return x


@@ -358,7 +366,7 @@ def __init__(
conv_param_3 = {"strides": 1, "kernel_size": 1, "act": None, "norm": Norm.BATCH, "bias": False}
width = math.floor(planes * (base_width / 64)) * groups

- super(SEResNeXtBottleneck, self).__init__(
+ super().__init__(
spatial_dims=spatial_dims,
in_channels=inplanes,
n_chns_1=width,
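Reviewer note, not part of the diff: with the residual folded into `ChannelSELayer.forward`, `ResidualSELayer` no longer overrides `forward` via `super()`, which is the construct TorchScript rejects. A minimal smoke test of that expectation, assuming the defaults above:

```python
# Hypothetical check (not in this PR): scripting should now succeed because
# ResidualSELayer inherits forward instead of calling super() inside it.
import torch

from monai.networks.blocks.squeeze_and_excitation import ResidualSELayer

layer = ResidualSELayer(spatial_dims=2, in_channels=8)
scripted = torch.jit.script(layer)  # previously failed on the super() call

x = torch.rand(2, 8, 16, 16)
assert scripted(x).shape == x.shape  # SE-weighted features plus the identity shortcut
```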
8 changes: 2 additions & 6 deletions monai/networks/nets/autoencoder.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- from typing import Optional, Sequence, Tuple, Union
+ from typing import Any, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
@@ -194,11 +194,7 @@ def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, i

return decode

- def forward(
-     self, x: torch.Tensor
- ) -> Union[
-     torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
- ]:  # big tuple return necessary for VAE, which inherits
+ def forward(self, x: torch.Tensor) -> Any:
x = self.encode(x)
x = self.intermediate(x)
x = self.decode(x)
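Reviewer note on the `Any` return: `VarAutoEncoder` inherits this `forward` but returns a tuple, so the base annotation is loosened rather than enumerating every subclass return shape. A toy sketch of the pattern (class names here are illustrative, not MONAI code):

```python
from typing import Any, Tuple

import torch
import torch.nn as nn


class Base(nn.Module):
    def forward(self, x: torch.Tensor) -> Any:  # loose so subclasses may widen the return
        return x


class Variational(Base):
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return x, 2 * x  # e.g. a reconstruction plus a latent statistic


out = Variational()(torch.rand(1, 3))
assert isinstance(out, tuple)
```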
22 changes: 12 additions & 10 deletions monai/networks/nets/densenet.py
@@ -20,7 +20,7 @@
from monai.networks.layers.factories import Conv, Dropout, Norm, Pool


- class _DenseLayer(nn.Sequential):
+ class _DenseLayer(nn.Module):
def __init__(
self, spatial_dims: int, in_channels: int, growth_rate: int, bn_size: int, dropout_prob: float
) -> None:
Expand All @@ -38,21 +38,23 @@ def __init__(
out_channels = bn_size * growth_rate
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
dropout_type: Callable = Dropout[Dropout.DROPOUT, spatial_dims]

self.add_module("norm1", norm_type(in_channels))
self.add_module("relu1", nn.ReLU(inplace=True))
self.add_module("conv1", conv_type(in_channels, out_channels, kernel_size=1, bias=False))
self.layers = nn.Sequential()

self.add_module("norm2", norm_type(out_channels))
self.add_module("relu2", nn.ReLU(inplace=True))
self.add_module("conv2", conv_type(out_channels, growth_rate, kernel_size=3, padding=1, bias=False))
self.layers.add_module("norm1", norm_type(in_channels))
self.layers.add_module("relu1", nn.ReLU(inplace=True))
self.layers.add_module("conv1", conv_type(in_channels, out_channels, kernel_size=1, bias=False))

self.layers.add_module("norm2", norm_type(out_channels))
self.layers.add_module("relu2", nn.ReLU(inplace=True))
self.layers.add_module("conv2", conv_type(out_channels, growth_rate, kernel_size=3, padding=1, bias=False))

if dropout_prob > 0:
self.add_module("dropout", dropout_type(dropout_prob))
self.layers.add_module("dropout", dropout_type(dropout_prob))

def forward(self, x: torch.Tensor) -> torch.Tensor:
- new_features = super(_DenseLayer, self).forward(x)
+ new_features = self.layers(x)
return torch.cat([x, new_features], 1)


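Reviewer note on `_DenseLayer`: calling `super().forward(x)` on an `nn.Sequential` subclass is exactly what TorchScript cannot compile, so the sublayers move into an inner `nn.Sequential` invoked as a plain attribute. The same pattern in miniature (standalone toy module, not MONAI code):

```python
import torch
import torch.nn as nn


class DenseLayerPattern(nn.Module):
    def __init__(self, in_channels: int, growth_rate: int) -> None:
        super().__init__()
        # submodules live in an inner Sequential instead of on the class itself
        self.layers = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        new_features = self.layers(x)  # plain submodule call, scriptable
        return torch.cat([x, new_features], 1)


torch.jit.script(DenseLayerPattern(4, 2))  # compiles without any super() call
```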
11 changes: 9 additions & 2 deletions monai/networks/nets/dynunet.py
@@ -97,6 +97,7 @@ def check_kernel_stride(self):
kernels, strides = self.kernel_size, self.strides
error_msg = "length of kernel_size and strides should be the same, and no less than 3."
assert len(kernels) == len(strides) and len(kernels) >= 3, error_msg

for idx in range(len(kernels)):
kernel, stride = kernels[idx], strides[idx]
if not isinstance(kernel, int):
@@ -115,20 +116,26 @@ def check_deep_supr_num(self):
def forward(self, x):
out = self.input_block(x)
outputs = [out]

for downsample in self.downsamples:
out = downsample(out)
- outputs.append(out)
+ outputs.insert(0, out)

out = self.bottleneck(out)
upsample_outs = []

- for upsample, skip in zip(self.upsamples, reversed(outputs)):
+ for upsample, skip in zip(self.upsamples, outputs):
out = upsample(out, skip)
upsample_outs.append(out)

out = self.output_block(out)

if self.training and self.deep_supervision:
start_output_idx = len(upsample_outs) - 1 - self.deep_supr_num
upsample_outs = upsample_outs[start_output_idx:-1][::-1]
preds = [self.deep_supervision_heads[i](out) for i, out in enumerate(upsample_outs)]
return [out] + preds

return out

def get_input_block(self):
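Reviewer note on the `forward` change: `outputs.insert(0, out)` stores the skip connections deepest-first, so the decoder loop consumes the list front to back and the `reversed(outputs)` call (which, as far as I can tell, TorchScript does not support) disappears. A quick equivalence check of the two orderings:

```python
# Reviewer sketch: front-insertion reproduces the pairing order that the old
# append-then-reversed() scheme produced.
downsample_outs = ["skip0", "skip1", "skip2"]  # produced shallow to deep

appended, inserted = [], []
for out in downsample_outs:
    appended.append(out)     # old scheme, consumed via reversed(...)
    inserted.insert(0, out)  # new scheme, consumed front to back

assert inserted == list(reversed(appended))
```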
24 changes: 21 additions & 3 deletions monai/networks/nets/highresnet.py
@@ -86,6 +86,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.as_tensor(self.layers(x))


class ChannelPad(nn.Module):
def __init__(self, pad):
super().__init__()
self.pad = tuple(pad)

def forward(self, x):
return F.pad(x, self.pad)


class HighResBlock(nn.Module):
def __init__(
self,
@@ -124,21 +133,26 @@ def __init__(
norm_type = Normalisation(norm_type)
acti_type = Activation(acti_type)

- self.project, self.pad = None, None
+ self.project = None
+ self.pad = None

if in_channels != out_channels:
channel_matching = ChannelMatching(channel_matching)

if channel_matching == ChannelMatching.PROJECT:
self.project = conv_type(in_channels, out_channels, kernel_size=1)

if channel_matching == ChannelMatching.PAD:
if in_channels > out_channels:
raise ValueError('Incompatible values: channel_matching="pad" and in_channels > out_channels.')
pad_1 = (out_channels - in_channels) // 2
pad_2 = out_channels - in_channels - pad_1
pad = [0, 0] * spatial_dims + [pad_1, pad_2] + [0, 0]
- self.pad = lambda input: F.pad(input, pad)
+ self.pad = ChannelPad(pad)

layers = nn.ModuleList()
_in_chns, _out_chns = in_channels, out_channels

for kernel_size in kernels:
layers.append(SUPPORTED_NORM[norm_type](spatial_dims)(_in_chns))
layers.append(SUPPORTED_ACTI[acti_type](inplace=True))
@@ -148,14 +162,18 @@
)
)
_in_chns = _out_chns

self.layers = nn.Sequential(*layers)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_conv: torch.Tensor = self.layers(x)

if self.project is not None:
- return x_conv + torch.as_tensor(self.project(x))
+ return x_conv + torch.as_tensor(self.project(x))  # as_tensor used to get around mypy typing bug

if self.pad is not None:
return x_conv + torch.as_tensor(self.pad(x))

return x_conv + x


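Reviewer note on the padding change: a lambda stored in `self.pad` is not scriptable, so it becomes the small `ChannelPad` module added above; the pad list itself is unchanged. A quick check of its layout under `F.pad`'s last-dimension-first convention, assuming `spatial_dims=2` and a 4→6 channel match (values here are illustrative):

```python
import torch
import torch.nn.functional as F

pad_1, pad_2 = 1, 1  # (out_channels - in_channels) // 2 and the remainder
# entries are (W_left, W_right, H_left, H_right, C_before, C_after, N_before, N_after)
pad = [0, 0] * 2 + [pad_1, pad_2] + [0, 0]

x = torch.rand(1, 4, 8, 8)
assert F.pad(x, pad).shape == (1, 6, 8, 8)  # channels padded, batch and spatial untouched
```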