From 40c3a24e79f9a476fc0638ce86911eb254bea1e6 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 22 Jul 2021 11:05:18 +0800 Subject: [PATCH 1/7] [DLMED] enhance doc-string for spatial shape Signed-off-by: Nic Ma --- monai/networks/nets/unet.py | 6 ++++++ tests/test_unet.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index f3742d05b5..76a5348e1c 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -57,6 +57,12 @@ def __init__( act: activation type and arguments. Defaults to PReLU. norm: feature normalization type and arguments. Defaults to instance norm. dropout: dropout ratio. Defaults to no dropout. + + Note: Usually, UNet will decrease the spatial size of feature maps in every down block based on + the `strides` arg, please ensure the spatial size of input data can be divided by the product of `strides`, + for example, `strides=[2, 2, 2]` and input data shape can be `[16, 1, 16, 8]`. Typically, applying + `resize`, `pad` or `crop` transforms to adjust the spatial size. + """ super().__init__() diff --git a/tests/test_unet.py b/tests/test_unet.py index 7bf2c0c920..461a91b58b 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -141,7 +141,7 @@ def test_script_without_running_stats(self): num_res_units=0, norm=("batch", {"track_running_stats": False}), ) - test_data = torch.randn(16, 1, 16, 8) + test_data = torch.randn(16, 1, 16, 4) test_script_save(net, test_data) From 990870517a0d5ba4517cf855ebc3098611dd53cc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 22 Jul 2021 14:53:00 +0800 Subject: [PATCH 2/7] [DLMED] add ill test case Signed-off-by: Nic Ma --- tests/test_unet.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_unet.py b/tests/test_unet.py index 461a91b58b..a0ba06e513 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -144,6 +144,18 @@ def test_script_without_running_stats(self): test_data = torch.randn(16, 1, 16, 4) test_script_save(net, test_data) + def test_ill_input_shape(self): + net = UNet( + dimensions=2, + in_channels=1, + out_channels=3, + channels=(16, 32, 64), + strides=(2, 2), + ) + with eval_mode(net): + with self.assertRaisesRegex(RuntimeError, r"torch.cat\(\): Sizes of tensors must match"): + net.forward(torch.randn(2, 1, 16, 5)) + if __name__ == "__main__": unittest.main() From 521867dec42edce1ca1216673bc859733d451d64 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 22 Jul 2021 16:28:23 +0800 Subject: [PATCH 3/7] [DLMED] update test error Signed-off-by: Nic Ma --- tests/test_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_unet.py b/tests/test_unet.py index a0ba06e513..d7029043ac 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -153,7 +153,7 @@ def test_ill_input_shape(self): strides=(2, 2), ) with eval_mode(net): - with self.assertRaisesRegex(RuntimeError, r"torch.cat\(\): Sizes of tensors must match"): + with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match"): net.forward(torch.randn(2, 1, 16, 5)) From a9361ff34960e47480387b6494aad8e46e85323a Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 26 Jul 2021 14:30:56 +0800 Subject: [PATCH 4/7] update docstring and add more test cases Signed-off-by: Yiheng Wang --- monai/networks/nets/unet.py | 33 +++++++++++++++++++------ tests/test_unet.py | 48 +++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 7 deletions(-) diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 76a5348e1c..07a3a7c5ff 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -9,7 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Union +import warnings +from typing import Sequence, Tuple, Union import torch import torch.nn as nn @@ -35,8 +36,8 @@ def __init__( kernel_size: Union[Sequence[int], int] = 3, up_kernel_size: Union[Sequence[int], int] = 3, num_res_units: int = 0, - act=Act.PRELU, - norm=Norm.INSTANCE, + act: Union[Tuple, str] = Act.PRELU, + norm: Union[Tuple, str] = Norm.INSTANCE, dropout=0.0, ) -> None: """ @@ -49,10 +50,12 @@ def __init__( dimensions: number of spatial dimensions. in_channels: number of input channels. out_channels: number of output channels. - channels: sequence of channels. Top block first. - strides: convolution stride. - kernel_size: convolution kernel size. Defaults to 3. - up_kernel_size: upsampling convolution kernel size. Defaults to 3. + channels: sequence of channels. Top block first. The length of `channels` should be no less than 2. + strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`. + kernel_size: convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. + up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence, + its length should equal to dimensions. Defaults to 3. num_res_units: number of residual units. Defaults to 0. act: activation type and arguments. Defaults to PReLU. norm: feature normalization type and arguments. Defaults to instance norm. @@ -66,6 +69,22 @@ def __init__( """ super().__init__() + if len(channels) < 2: + raise ValueError("the length of `channels` should be no less than 2.") + if len(strides) < len(channels) - 1: + raise ValueError("the length of `strides` should equal to `len(channels) - 1`.") + if len(strides) >= len(channels): + warn_msg = "`len(strides) >= len(channels)`, the stride values {} in {} will not be used.".format( + strides[len(channels) - 1 :], strides + ) + warnings.warn(warn_msg) + if isinstance(kernel_size, Sequence): + if len(kernel_size) != dimensions: + raise ValueError("the length of `kernel_size` should equal to `dimensions`.") + if isinstance(up_kernel_size, Sequence): + if len(up_kernel_size) != dimensions: + raise ValueError("the length of `up_kernel_size` should equal to `dimensions`.") + self.dimensions = dimensions self.in_channels = in_channels self.out_channels = out_channels diff --git a/tests/test_unet.py b/tests/test_unet.py index d7029043ac..4091c4e9d7 100644 --- a/tests/test_unet.py +++ b/tests/test_unet.py @@ -117,6 +117,49 @@ CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6] +ILL_CASES = [ + [ + { # len(channels) < 2 + "dimensions": 2, + "in_channels": 1, + "out_channels": 3, + "channels": (16,), + "strides": (2, 2), + "num_res_units": 0, + } + ], + [ + { # len(strides) < len(channels) - 1 + "dimensions": 2, + "in_channels": 1, + "out_channels": 3, + "channels": (8, 8, 8), + "strides": (2,), + "num_res_units": 0, + } + ], + [ + { # len(kernel_size) = 3, dimensions = 2 + "dimensions": 2, + "in_channels": 1, + "out_channels": 3, + "channels": (8, 8, 8), + "strides": (2, 2), + "kernel_size": (3, 3, 3), + } + ], + [ + { # len(up_kernel_size) = 2, dimensions = 3 + "dimensions": 3, + "in_channels": 1, + "out_channels": 3, + "channels": (8, 8, 8), + "strides": (2, 2), + "up_kernel_size": (3, 3), + } + ], +] + class TestUNET(unittest.TestCase): @parameterized.expand(CASES) @@ -156,6 +199,11 @@ def test_ill_input_shape(self): with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match"): net.forward(torch.randn(2, 1, 16, 5)) + @parameterized.expand(ILL_CASES) + def test_ill_input_hyper_params(self, input_param): + with self.assertRaises(ValueError): + net = UNet(**input_param) + if __name__ == "__main__": unittest.main() From 292138ee5e38005bafd5c19fa89335af96239d61 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 26 Jul 2021 23:38:27 +0800 Subject: [PATCH 5/7] [DLMED] add link to doc-string Signed-off-by: Nic Ma --- monai/networks/nets/unet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 07a3a7c5ff..3803c55c57 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -61,10 +61,10 @@ def __init__( norm: feature normalization type and arguments. Defaults to instance norm. dropout: dropout ratio. Defaults to no dropout. - Note: Usually, UNet will decrease the spatial size of feature maps in every down block based on - the `strides` arg, please ensure the spatial size of input data can be divided by the product of `strides`, - for example, `strides=[2, 2, 2]` and input data shape can be `[16, 1, 16, 8]`. Typically, applying - `resize`, `pad` or `crop` transforms to adjust the spatial size. + Note: The acceptable spatial size of input data depends on the parameters of the network, + to set appropriate spatial size, please check the tutorial for more details: + https://github.com/Project-MONAI/tutorials/blob/master/modules/UNet_input_size_constrains.ipynb. + Typically, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data. """ super().__init__() From 217b9c1c24e621e4cba7ed02f39ebfa88937a416 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 26 Jul 2021 23:47:34 +0800 Subject: [PATCH 6/7] [DLMED] enhance doc-string Signed-off-by: Nic Ma --- monai/networks/nets/unet.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 3803c55c57..96245dd186 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -71,13 +71,11 @@ def __init__( if len(channels) < 2: raise ValueError("the length of `channels` should be no less than 2.") - if len(strides) < len(channels) - 1: + delta = len(strides) - len(channels) + if delta < - 1: raise ValueError("the length of `strides` should equal to `len(channels) - 1`.") - if len(strides) >= len(channels): - warn_msg = "`len(strides) >= len(channels)`, the stride values {} in {} will not be used.".format( - strides[len(channels) - 1 :], strides - ) - warnings.warn(warn_msg) + if delta >= 0: + warnings.warn(f"`len(strides) >= len(channels)`, the last {delta + 1} values of strides will not be used.") if isinstance(kernel_size, Sequence): if len(kernel_size) != dimensions: raise ValueError("the length of `kernel_size` should equal to `dimensions`.") From 6a00bce71586d068f4fb78f859be5720db2dba0e Mon Sep 17 00:00:00 2001 From: monai-bot Date: Mon, 26 Jul 2021 15:52:27 +0000 Subject: [PATCH 7/7] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/networks/nets/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py index 96245dd186..9e3538c2b3 100644 --- a/monai/networks/nets/unet.py +++ b/monai/networks/nets/unet.py @@ -72,7 +72,7 @@ def __init__( if len(channels) < 2: raise ValueError("the length of `channels` should be no less than 2.") delta = len(strides) - len(channels) - if delta < - 1: + if delta < -1: raise ValueError("the length of `strides` should equal to `len(channels) - 1`.") if delta >= 0: warnings.warn(f"`len(strides) >= len(channels)`, the last {delta + 1} values of strides will not be used.")