Skip to content
38 changes: 38 additions & 0 deletions monai/networks/nets/swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
use_checkpoint: bool = False,
spatial_dims: int = 3,
downsample="merging",
use_v2=False,
) -> None:
"""
Args:
Expand All @@ -84,6 +85,7 @@ def __init__(
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
The default is currently `"merging"` (the original version defined in v0.9.0).
use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage.

Examples::

Expand Down Expand Up @@ -142,6 +144,7 @@ def __init__(
use_checkpoint=use_checkpoint,
spatial_dims=spatial_dims,
downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
use_v2=use_v2,
)

self.encoder1 = UnetrBasicBlock(
Expand Down Expand Up @@ -921,6 +924,7 @@ def __init__(
use_checkpoint: bool = False,
spatial_dims: int = 3,
downsample="merging",
use_v2=False,
) -> None:
"""
Args:
Expand All @@ -942,6 +946,7 @@ def __init__(
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
The default is currently `"merging"` (the original version defined in v0.9.0).
use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage.
"""

super().__init__()
Expand All @@ -959,10 +964,16 @@ def __init__(
)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
self.use_v2 = use_v2
self.layers1 = nn.ModuleList()
self.layers2 = nn.ModuleList()
self.layers3 = nn.ModuleList()
self.layers4 = nn.ModuleList()
if self.use_v2:
self.layers1c = nn.ModuleList()
self.layers2c = nn.ModuleList()
self.layers3c = nn.ModuleList()
self.layers4c = nn.ModuleList()
down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample
for i_layer in range(self.num_layers):
layer = BasicLayer(
Expand All @@ -987,6 +998,25 @@ def __init__(
self.layers3.append(layer)
elif i_layer == 3:
self.layers4.append(layer)
if self.use_v2:
    # SwinUNETR-V2: build one residual convolution block per stage. The block
    # keeps the channel count of its stage (embed_dim * 2**i_layer) unchanged,
    # so it can be applied to the stage input before the swin layer runs.
    layerc = UnetrBasicBlock(
        # Use the constructor's ``spatial_dims`` rather than a hard-coded 3 so
        # the convolution block matches the dimensionality the caller asked for.
        spatial_dims=spatial_dims,
        in_channels=embed_dim * 2**i_layer,
        out_channels=embed_dim * 2**i_layer,
        kernel_size=3,
        stride=1,
        norm_name="instance",
        res_block=True,
    )
    # File the block under the container that pairs with this stage's swin layer.
    if i_layer == 0:
        self.layers1c.append(layerc)
    elif i_layer == 1:
        self.layers2c.append(layerc)
    elif i_layer == 2:
        self.layers3c.append(layerc)
    elif i_layer == 3:
        self.layers4c.append(layerc)

self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

def proj_out(self, x, normalize=False):
Expand All @@ -1008,12 +1038,20 @@ def forward(self, x, normalize=True):
# Embed the input with ``patch_embed``, apply ``pos_drop`` (an nn.Dropout), and
# expose the projected embedding as the first output.
x0 = self.patch_embed(x)
x0 = self.pos_drop(x0)
x0_out = self.proj_out(x0, normalize)
# Stage 1: when ``use_v2`` is set, the matching convolution block from
# ``layers1c`` runs on the stage input before the ``layers1`` module.
if self.use_v2:
x0 = self.layers1c[0](x0.contiguous())
x1 = self.layers1[0](x0.contiguous())
x1_out = self.proj_out(x1, normalize)
# Stage 2: same pattern, using ``layers2c`` then ``layers2``.
if self.use_v2:
x1 = self.layers2c[0](x1.contiguous())
x2 = self.layers2[0](x1.contiguous())
x2_out = self.proj_out(x2, normalize)
# Stage 3: same pattern, using ``layers3c`` then ``layers3``.
if self.use_v2:
x2 = self.layers3c[0](x2.contiguous())
x3 = self.layers3[0](x2.contiguous())
x3_out = self.proj_out(x3, normalize)
# Stage 4: same pattern, using ``layers4c`` then ``layers4``.
if self.use_v2:
x3 = self.layers4c[0](x3.contiguous())
x4 = self.layers4[0](x3.contiguous())
x4_out = self.proj_out(x4, normalize)
# Return the projected embedding followed by the projected output of each of
# the four stages. The projections are taken before the next stage's
# convolution block is applied.
return [x0_out, x1_out, x2_out, x3_out, x4_out]