diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index 7aabfbba..5974594b 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -30,7 +30,7 @@
 # =========================================================================

 import math
-from typing import List, Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -1453,7 +1453,7 @@ def __init__(
         attention_levels: Sequence[bool] = (False, False, True, True),
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
-        num_head_channels: int = 8,
+        num_head_channels: Union[int, Sequence[int]] = 8,
         with_conditioning: bool = False,
         transformer_num_layers: int = 1,
         cross_attention_dim: Optional[int] = None,
@@ -1476,6 +1476,15 @@ def __init__(
         if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
             raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups")

+        if isinstance(num_head_channels, int):
+            num_head_channels = (num_head_channels,) * len(attention_levels)
+
+        if len(num_head_channels) != len(attention_levels):
+            raise ValueError(
+                "num_head_channels should have the same length as attention_levels. For levels without attention,"
+                " i.e. `attention_levels[i]=False`, num_head_channels[i] is ignored."
+            )
+
         self.in_channels = in_channels
         self.block_out_channels = num_channels
         self.out_channels = out_channels
@@ -1526,7 +1535,7 @@ def __init__(
                 add_downsample=not is_final_block,
                 with_attn=(attention_levels[i] and not with_conditioning),
                 with_cross_attn=(attention_levels[i] and with_conditioning),
-                num_head_channels=num_head_channels,
+                num_head_channels=num_head_channels[i],
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
             )
@@ -1541,7 +1550,7 @@ def __init__(
             norm_num_groups=norm_num_groups,
             norm_eps=norm_eps,
             with_conditioning=with_conditioning,
-            num_head_channels=num_head_channels,
+            num_head_channels=num_head_channels[-1],
             transformer_num_layers=transformer_num_layers,
             cross_attention_dim=cross_attention_dim,
         )
@@ -1550,6 +1559,7 @@ def __init__(
         self.up_blocks = nn.ModuleList([])
         reversed_block_out_channels = list(reversed(num_channels))
         reversed_attention_levels = list(reversed(attention_levels))
+        reversed_num_head_channels = list(reversed(num_head_channels))
         output_channel = reversed_block_out_channels[0]
         for i in range(len(reversed_block_out_channels)):
             prev_output_channel = output_channel
@@ -1570,7 +1580,7 @@ def __init__(
                 add_upsample=not is_final_block,
                 with_attn=(reversed_attention_levels[i] and not with_conditioning),
                 with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
-                num_head_channels=num_head_channels,
+                num_head_channels=reversed_num_head_channels[i],
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
             )
diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py
index f9925b96..07f5f4f8 100644
--- a/tests/test_diffusion_model_unet.py
+++ b/tests/test_diffusion_model_unet.py
@@ -54,6 +54,18 @@
             "norm_num_groups": 8,
         },
     ],
+    [
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, True, True),
+            "num_head_channels": (0, 2, 4),
+            "norm_num_groups": 8,
+        },
+    ],
 ]

 UNCOND_CASES_3D = [
@@ -92,6 +104,18 @@
             "norm_num_groups": 8,
         },
     ],
+    [
+        {
+            "spatial_dims": 3,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, False, True),
+            "num_head_channels": (0, 0, 4),
+            "norm_num_groups": 8,
+        },
+    ],
 ]


@@ -131,6 +155,19 @@ def test_model_channels_not_multiple_of_norm_num_group(self):
                 norm_num_groups=8,
             )

+    def test_attention_levels_with_different_length_num_head_channels(self):
+        with self.assertRaises(ValueError):
+            net = DiffusionModelUNet(
+                spatial_dims=2,
+                in_channels=1,
+                out_channels=1,
+                num_res_blocks=1,
+                num_channels=(8, 8, 8),
+                attention_levels=(False, False, False),
+                num_head_channels=(0, 2),
+                norm_num_groups=8,
+            )
+
     def test_shape_conditioned_models(self):
         net = DiffusionModelUNet(
             spatial_dims=2,