Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions monai/networks/nets/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class AutoEncoder(nn.Module):
bias: whether to have a bias term in convolution blocks. Defaults to True.
According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
if a conv layer is directly followed by a batch norm layer, bias should be False.
padding: controls the amount of implicit zero-padding applied on both sides, adding ``padding`` points
for each spatial dimension in the convolution blocks. Defaults to None.

Examples::

Expand Down Expand Up @@ -104,6 +106,7 @@ def __init__(
norm: tuple | str = Norm.INSTANCE,
dropout: tuple | str | float | None = None,
bias: bool = True,
padding: Sequence[int] | int | None = None,
) -> None:
super().__init__()
self.dimensions = spatial_dims
Expand All @@ -118,6 +121,7 @@ def __init__(
self.norm = norm
self.dropout = dropout
self.bias = bias
self.padding = padding
self.num_inter_units = num_inter_units
self.inter_channels = inter_channels if inter_channels is not None else []
self.inter_dilations = list(inter_dilations or [1] * len(self.inter_channels))
Expand Down Expand Up @@ -178,6 +182,7 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> tu
dropout=self.dropout,
dilation=di,
bias=self.bias,
padding=self.padding,
)
else:
unit = Convolution(
Expand All @@ -191,6 +196,7 @@ def _get_intermediate_module(self, in_channels: int, num_inter_units: int) -> tu
dropout=self.dropout,
dilation=di,
bias=self.bias,
padding=self.padding,
)

intermediate.add_module("inter_%i" % i, unit)
Expand Down Expand Up @@ -231,6 +237,7 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i
norm=self.norm,
dropout=self.dropout,
bias=self.bias,
padding=self.padding,
last_conv_only=is_last,
)
return mod
Expand All @@ -244,6 +251,7 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i
norm=self.norm,
dropout=self.dropout,
bias=self.bias,
padding=self.padding,
conv_only=is_last,
)
return mod
Expand All @@ -264,6 +272,7 @@ def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, i
norm=self.norm,
dropout=self.dropout,
bias=self.bias,
padding=self.padding,
conv_only=is_last and self.num_res_units == 0,
is_transposed=True,
)
Expand All @@ -282,6 +291,7 @@ def _get_decode_layer(self, in_channels: int, out_channels: int, strides: int, i
norm=self.norm,
dropout=self.dropout,
bias=self.bias,
padding=self.padding,
last_conv_only=is_last,
)

Expand Down