12 changes: 8 additions & 4 deletions monai/networks/blocks/upsample.py
@@ -60,21 +60,21 @@ def __init__(
thus if size is defined, `scale_factor` will not be used.
Defaults to None.
mode: {``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``}. Defaults to ``"deconv"``.
pre_conv: a conv block applied before upsampling. Defaults to None.
pre_conv: a conv block applied before upsampling. Defaults to "default".
When ``pre_conv`` is ``"default"``, one reserved conv layer will be utilized when ``in_channels != out_channels``.
Only used in the "nontrainable" or "pixelshuffle" mode.
interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
Only used when ``mode`` is ``UpsampleMode.NONTRAINABLE``.
Only used in the "nontrainable" mode.
If it ends with ``"linear"``, ``spatial dims`` will be used to determine the correct interpolation.
This corresponds to linear, bilinear, trilinear for 1D, 2D, and 3D respectively.
The interpolation mode. Defaults to ``"linear"``.
See also: https://pytorch.org/docs/stable/nn.html#upsample
align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True.
Only used in the nontrainable mode.
Only used in the "nontrainable" mode.
bias: whether to have a bias term in the default preconv and deconv layers. Defaults to True.
apply_pad_pool: if True, the upsampled tensor is padded, then average pooling is applied with a kernel the
size of `scale_factor` and a stride of 1. See also: :py:class:`monai.networks.blocks.SubpixelUpsample`.
Only used in the pixelshuffle mode.
Only used in the "pixelshuffle" mode.
"""
super().__init__()
scale_factor_ = ensure_tuple_rep(scale_factor, dimensions)
@@ -104,6 +104,10 @@ def __init__(
)
elif pre_conv is not None and pre_conv != "default":
self.add_module("preconv", pre_conv) # type: ignore
elif pre_conv is None and (out_channels != in_channels):
raise ValueError(
"in the nontrainable mode, if not setting pre_conv, out_channels should equal to in_channels."
)

interp_mode = InterpolateMode(interp_mode)
linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR]
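A quick illustrative sketch of the new ``pre_conv`` handling in the nontrainable mode (a sketch only, assuming the constructor's first positional argument is ``dimensions``, as in the surrounding code, and the keyword names from the docstring above):

import torch
from monai.networks.blocks import UpSample

# pre_conv=None with matching channels: pure interpolation, no conv layer.
up = UpSample(2, in_channels=16, out_channels=16, scale_factor=2,
              mode="nontrainable", pre_conv=None, interp_mode="linear")
print(up(torch.rand(1, 16, 8, 8)).shape)  # torch.Size([1, 16, 16, 16])

# pre_conv=None with mismatched channels now raises the ValueError added above,
# instead of silently producing a channel mismatch downstream.
UpSample(2, in_channels=16, out_channels=8, scale_factor=2,
         mode="nontrainable", pre_conv=None)  # raises ValueError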
29 changes: 25 additions & 4 deletions monai/networks/nets/basic_unet.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence, Union
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
@@ -92,6 +92,9 @@ def __init__(
norm: Union[str, tuple],
dropout: Union[float, tuple] = 0.0,
upsample: str = "deconv",
pre_conv: Optional[Union[nn.Module, str]] = "default",
interp_mode: str = "linear",
align_corners: Optional[bool] = True,
halves: bool = True,
):
"""
@@ -105,12 +108,30 @@
dropout: dropout ratio. Defaults to no dropout.
upsample: upsampling mode, available options are
``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
pre_conv: a conv block applied before upsampling.
Only used in the "nontrainable" or "pixelshuffle" mode.
interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
Only used in the "nontrainable" mode.
align_corners: set the align_corners parameter for upsample. Defaults to True.
Only used in the "nontrainable" mode.
halves: whether to halve the number of channels during upsampling.
This parameter has no effect in the ``nontrainable`` mode when ``pre_conv`` is ``None``.
"""
super().__init__()

up_chns = in_chns // 2 if halves else in_chns
self.upsample = UpSample(dim, in_chns, up_chns, 2, mode=upsample)
if upsample == "nontrainable" and pre_conv is None:
up_chns = in_chns
else:
up_chns = in_chns // 2 if halves else in_chns
self.upsample = UpSample(
dim,
in_chns,
up_chns,
2,
mode=upsample,
pre_conv=pre_conv,
interp_mode=interp_mode,
align_corners=align_corners,
)
self.convs = TwoConv(dim, cat_chns + up_chns, out_chns, act, norm, dropout)

def forward(self, x: torch.Tensor, x_e: torch.Tensor):
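For the decoder block, a hedged usage sketch of the updated ``UpCat`` (the leading parameters ``dim``, ``in_chns``, ``cat_chns``, ``out_chns``, ``act`` are assumed from the unchanged part of the signature):

import torch
from monai.networks.nets.basic_unet import UpCat

# nontrainable upsampling without a pre_conv: channels are preserved, so
# up_chns == in_chns and TwoConv receives cat_chns + in_chns input channels.
upcat = UpCat(
    2, in_chns=64, cat_chns=32, out_chns=32, act="relu", norm="instance",
    upsample="nontrainable", pre_conv=None, interp_mode="linear",
)
x = torch.rand(1, 64, 8, 8)      # decoder feature map
x_e = torch.rand(1, 32, 16, 16)  # encoder skip connection
print(upcat(x, x_e).shape)       # torch.Size([1, 32, 16, 16])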
2 changes: 1 addition & 1 deletion monai/networks/nets/dynunet.py
@@ -26,7 +26,7 @@ class DynUNetSkipLayer(nn.Module):
Defines a layer in the UNet topology which combines the downsample and upsample pathways with the skip connection.
The member `next_layer` may refer to instances of this class or the final bottleneck layer at the bottom of the UNet
structure. The purpose of using a recursive class like this is to get around the Torchscript restrictions on
looping over lists of layers and accumulating lists of output tensors which much be indexed. The `heads` list is
looping over lists of layers and accumulating lists of output tensors which must be indexed. The `heads` list is
shared amongst all the instances of this class and is used to store the output from the supervision heads during
forward passes of the network.
"""
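The recursion this docstring describes can be pictured with a simplified, illustrative sketch (assumptions for exposition only; this is not the actual DynUNetSkipLayer code, which also feeds the downsampled tensor into the upsample path as the skip connection):

from typing import List, Optional

import torch
import torch.nn as nn

class SkipLayer(nn.Module):
    def __init__(self, index: int, heads: List[torch.Tensor], downsample: nn.Module,
                 upsample: nn.Module, next_layer: nn.Module,
                 super_head: Optional[nn.Module] = None):
        super().__init__()
        self.index = index
        self.heads = heads            # one list shared by every SkipLayer instance
        self.downsample = downsample
        self.upsample = upsample
        self.next_layer = next_layer  # a deeper SkipLayer or the bottleneck module
        self.super_head = super_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # plain attribute calls instead of a loop over a list of layers,
        # which keeps the module TorchScript-compatible
        down = self.downsample(x)
        up = self.upsample(self.next_layer(down))
        if self.super_head is not None:
            self.heads[self.index] = self.super_head(up)  # deep-supervision output
        return up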