From 6d8190b9b26944e9a9ff985f20addc665d6e7e34 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 4 Feb 2023 14:01:51 +0000 Subject: [PATCH] Add verification code Signed-off-by: Walter Hugo Lopez Pinaya --- generative/networks/nets/diffusion_model_unet.py | 3 +++ tests/test_diffusion_model_unet.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 8d13b76f..76caa08b 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -1636,6 +1636,9 @@ def __init__( if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + if len(num_channels) != len(attention_levels): + raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels") + if isinstance(num_head_channels, int): num_head_channels = (num_head_channels,) * len(attention_levels) diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index 7e116ea6..f34b4a3f 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -360,6 +360,20 @@ def test_conditioned_models_no_class_labels(self): ) net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) + def test_model_num_channels_not_same_size_of_attention_levels(self): + with self.assertRaises(ValueError): + DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False), + norm_num_groups=8, + num_head_channels=8, + num_class_embeds=2, + ) + def test_script_unconditioned_2d_models(self): net = DiffusionModelUNet( spatial_dims=2,