37 changes: 30 additions & 7 deletions monai/networks/nets/unet.py
@@ -9,7 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence, Union
import warnings
from typing import Sequence, Tuple, Union

import torch
import torch.nn as nn
@@ -35,8 +36,8 @@ def __init__(
kernel_size: Union[Sequence[int], int] = 3,
up_kernel_size: Union[Sequence[int], int] = 3,
num_res_units: int = 0,
act=Act.PRELU,
norm=Norm.INSTANCE,
act: Union[Tuple, str] = Act.PRELU,
norm: Union[Tuple, str] = Norm.INSTANCE,
dropout=0.0,
) -> None:
"""
@@ -49,17 +50,39 @@
dimensions: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
channels: sequence of channels. Top block first.
strides: convolution stride.
kernel_size: convolution kernel size. Defaults to 3.
up_kernel_size: upsampling convolution kernel size. Defaults to 3.
channels: sequence of channels. Top block first. The length of `channels` should be no less than 2.
strides: sequence of convolution strides. The length of `strides` should equal `len(channels) - 1`.
kernel_size: convolution kernel size, the value(s) should be odd. If a sequence,
its length should equal `dimensions`. Defaults to 3.
up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If a sequence,
its length should equal `dimensions`. Defaults to 3.
num_res_units: number of residual units. Defaults to 0.
act: activation type and arguments. Defaults to PReLU.
norm: feature normalization type and arguments. Defaults to instance norm.
dropout: dropout ratio. Defaults to no dropout.

Note: The acceptable spatial size of input data depends on the parameters of the network;
to set an appropriate spatial size, please check the tutorial for more details:
https://github.com/Project-MONAI/tutorials/blob/master/modules/UNet_input_size_constrains.ipynb.
Typically, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data.

"""
super().__init__()

if len(channels) < 2:
raise ValueError("the length of `channels` should be no less than 2.")
delta = len(strides) - len(channels)
if delta < -1:
raise ValueError("the length of `strides` should equal `len(channels) - 1`.")
if delta >= 0:
warnings.warn(f"`len(strides) >= len(channels)`, the last {delta + 1} values of strides will not be used.")
if isinstance(kernel_size, Sequence):
if len(kernel_size) != dimensions:
raise ValueError("the length of `kernel_size` should equal `dimensions`.")
if isinstance(up_kernel_size, Sequence):
if len(up_kernel_size) != dimensions:
raise ValueError("the length of `up_kernel_size` should equal `dimensions`.")

self.dimensions = dimensions
self.in_channels = in_channels
self.out_channels = out_channels
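
For context while reviewing (a minimal sketch, not part of this diff): a construction that satisfies the constraints documented above, namely `len(channels) >= 2`, `len(strides) == len(channels) - 1`, and kernel-size sequences of length `dimensions`, fed with an input whose spatial size is compatible with the stride-2 levels.

# Sketch only, not part of the PR; assumes the `UNet` signature shown in the diff above.
import torch
from monai.networks.nets import UNet

net = UNet(
    dimensions=2,              # 2D network
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64),     # length 3 satisfies len(channels) >= 2
    strides=(2, 2),            # length equals len(channels) - 1
    kernel_size=3,             # scalar, or a sequence of length `dimensions`
    up_kernel_size=3,
    num_res_units=2,
)

# With two stride-2 levels, spatial sizes divisible by 4 avoid the skip-connection
# size mismatch described in the docstring note (96 -> 48 -> 24 -> 48 -> 96).
x = torch.randn(1, 1, 96, 96)
y = net(x)                     # shape: (1, 2, 96, 96)
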
62 changes: 61 additions & 1 deletion tests/test_unet.py
@@ -117,6 +117,49 @@

CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]

ILL_CASES = [
[
{ # len(channels) < 2
"dimensions": 2,
"in_channels": 1,
"out_channels": 3,
"channels": (16,),
"strides": (2, 2),
"num_res_units": 0,
}
],
[
{ # len(strides) < len(channels) - 1
"dimensions": 2,
"in_channels": 1,
"out_channels": 3,
"channels": (8, 8, 8),
"strides": (2,),
"num_res_units": 0,
}
],
[
{ # len(kernel_size) = 3, dimensions = 2
"dimensions": 2,
"in_channels": 1,
"out_channels": 3,
"channels": (8, 8, 8),
"strides": (2, 2),
"kernel_size": (3, 3, 3),
}
],
[
{ # len(up_kernel_size) = 2, dimensions = 3
"dimensions": 3,
"in_channels": 1,
"out_channels": 3,
"channels": (8, 8, 8),
"strides": (2, 2),
"up_kernel_size": (3, 3),
}
],
]


class TestUNET(unittest.TestCase):
@parameterized.expand(CASES)
@@ -141,9 +184,26 @@ def test_script_without_running_stats(self):
num_res_units=0,
norm=("batch", {"track_running_stats": False}),
)
test_data = torch.randn(16, 1, 16, 8)
test_data = torch.randn(16, 1, 16, 4)
test_script_save(net, test_data)

def test_ill_input_shape(self):
net = UNet(
dimensions=2,
in_channels=1,
out_channels=3,
channels=(16, 32, 64),
strides=(2, 2),
)
with eval_mode(net):
with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match"):
net.forward(torch.randn(2, 1, 16, 5))

@parameterized.expand(ILL_CASES)
def test_ill_input_hyper_params(self, input_param):
with self.assertRaises(ValueError):
net = UNet(**input_param)


if __name__ == "__main__":
unittest.main()
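
As a usage sketch (not part of this PR): the checks added above raise a ValueError for invalid hyperparameters, but only warn when extra stride values are supplied, as in the example below.

# Sketch only; assumes the validation added in monai/networks/nets/unet.py above.
import warnings

from monai.networks.nets import UNet

# Too few channels: rejected by the new `len(channels) < 2` check.
try:
    UNet(dimensions=2, in_channels=1, out_channels=3, channels=(16,), strides=(2, 2))
except ValueError as err:
    print(f"ValueError: {err}")

# Extra stride values are not an error; the new code warns that they will be unused.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    UNet(dimensions=2, in_channels=1, out_channels=3, channels=(8, 8, 8), strides=(2, 2, 2))
    print(caught[0].message if caught else "no warning raised")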