diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index f3742d05b5..9e3538c2b3 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -9,7 +9,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Sequence, Union
+import warnings
+from typing import Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -35,8 +36,8 @@ def __init__(
         kernel_size: Union[Sequence[int], int] = 3,
         up_kernel_size: Union[Sequence[int], int] = 3,
         num_res_units: int = 0,
-        act=Act.PRELU,
-        norm=Norm.INSTANCE,
+        act: Union[Tuple, str] = Act.PRELU,
+        norm: Union[Tuple, str] = Norm.INSTANCE,
         dropout=0.0,
     ) -> None:
         """
@@ -49,17 +50,39 @@ def __init__(
             dimensions: number of spatial dimensions.
             in_channels: number of input channels.
             out_channels: number of output channels.
-            channels: sequence of channels. Top block first.
-            strides: convolution stride.
-            kernel_size: convolution kernel size. Defaults to 3.
-            up_kernel_size: upsampling convolution kernel size. Defaults to 3.
+            channels: sequence of channels. Top block first. The length of `channels` should be no less than 2.
+            strides: sequence of convolution strides. The length of `strides` should equal `len(channels) - 1`.
+            kernel_size: convolution kernel size, the value(s) should be odd. If a sequence,
+                its length should equal `dimensions`. Defaults to 3.
+            up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If a sequence,
+                its length should equal `dimensions`. Defaults to 3.
             num_res_units: number of residual units. Defaults to 0.
             act: activation type and arguments. Defaults to PReLU.
             norm: feature normalization type and arguments. Defaults to instance norm.
             dropout: dropout ratio. Defaults to no dropout.
+
+        Note: The acceptable spatial size of input data depends on the parameters of the network.
+            To set an appropriate spatial size, please check the tutorial for more details:
+            https://github.com/Project-MONAI/tutorials/blob/master/modules/UNet_input_size_constrains.ipynb.
+            Typically, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data.
+
         """
         super().__init__()
 
+        if len(channels) < 2:
+            raise ValueError("the length of `channels` should be no less than 2.")
+        delta = len(strides) - len(channels)
+        if delta < -1:
+            raise ValueError("the length of `strides` should equal `len(channels) - 1`.")
+        if delta >= 0:
+            warnings.warn(f"`len(strides) >= len(channels)`, the last {delta + 1} values of strides will not be used.")
+        if isinstance(kernel_size, Sequence):
+            if len(kernel_size) != dimensions:
+                raise ValueError("the length of `kernel_size` should equal `dimensions`.")
+        if isinstance(up_kernel_size, Sequence):
+            if len(up_kernel_size) != dimensions:
+                raise ValueError("the length of `up_kernel_size` should equal `dimensions`.")
+
         self.dimensions = dimensions
         self.in_channels = in_channels
         self.out_channels = out_channels
diff --git a/tests/test_unet.py b/tests/test_unet.py
index 7bf2c0c920..4091c4e9d7 100644
--- a/tests/test_unet.py
+++ b/tests/test_unet.py
@@ -117,6 +117,49 @@
 
 CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]
 
+ILL_CASES = [
+    [
+        {  # len(channels) < 2
+            "dimensions": 2,
+            "in_channels": 1,
+            "out_channels": 3,
+            "channels": (16,),
+            "strides": (2, 2),
+            "num_res_units": 0,
+        }
+    ],
+    [
+        {  # len(strides) < len(channels) - 1
+            "dimensions": 2,
+            "in_channels": 1,
+            "out_channels": 3,
+            "channels": (8, 8, 8),
+            "strides": (2,),
+            "num_res_units": 0,
+        }
+    ],
+    [
+        {  # len(kernel_size) = 3, dimensions = 2
+            "dimensions": 2,
+            "in_channels": 1,
+            "out_channels": 3,
+            "channels": (8, 8, 8),
+            "strides": (2, 2),
+            "kernel_size": (3, 3, 3),
+        }
+    ],
+    [
+        {  # len(up_kernel_size) = 2, dimensions = 3
+            "dimensions": 3,
+            "in_channels": 1,
+            "out_channels": 3,
+            "channels": (8, 8, 8),
+            "strides": (2, 2),
+            "up_kernel_size": (3, 3),
+        }
+    ],
+]
+
 
 class TestUNET(unittest.TestCase):
     @parameterized.expand(CASES)
@@ -141,9 +184,26 @@ def test_script_without_running_stats(self):
             num_res_units=0,
             norm=("batch", {"track_running_stats": False}),
         )
-        test_data = torch.randn(16, 1, 16, 8)
+        test_data = torch.randn(16, 1, 16, 4)
         test_script_save(net, test_data)
 
+    def test_ill_input_shape(self):
+        net = UNet(
+            dimensions=2,
+            in_channels=1,
+            out_channels=3,
+            channels=(16, 32, 64),
+            strides=(2, 2),
+        )
+        with eval_mode(net):
+            with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match"):
+                net.forward(torch.randn(2, 1, 16, 5))
+
+    @parameterized.expand(ILL_CASES)
+    def test_ill_input_hyper_params(self, input_param):
+        with self.assertRaises(ValueError):
+            net = UNet(**input_param)
+
 
 if __name__ == "__main__":
     unittest.main()