diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 9e3538c2b3..158b154042 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -64,18 +64,21 @@ def __init__( Note: The acceptable spatial size of input data depends on the parameters of the network, to set appropriate spatial size, please check the tutorial for more details: https://github.com/Project-MONAI/tutorials/blob/master/modules/UNet_input_size_constrains.ipynb. - Typically, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data. + Typically, when using a stride of 2 in down / up sampling, the output dimensions are either half of the + input when downsampling, or twice when upsampling. In this case with N numbers of layers in the network, + the inputs must have spatial dimensions that are all multiples of 2^N. + Usually, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data. """ super().__init__() if len(channels) < 2: raise ValueError("the length of `channels` should be no less than 2.") - delta = len(strides) - len(channels) - if delta < -1: + delta = len(strides) - (len(channels) - 1) + if delta < 0: raise ValueError("the length of `strides` should equal to `len(channels) - 1`.") - if delta >= 0: - warnings.warn(f"`len(strides) >= len(channels)`, the last {delta + 1} values of strides will not be used.") + if delta > 0: + warnings.warn(f"`len(strides) > len(channels) - 1`, the last {delta} values of strides will not be used.") if isinstance(kernel_size, Sequence): if len(kernel_size) != dimensions: raise ValueError("the length of `kernel_size` should equal to `dimensions`.")