This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
20 changes: 15 additions & 5 deletions generative/networks/nets/diffusion_model_unet.py
@@ -30,7 +30,7 @@
 # =========================================================================

 import math
-from typing import List, Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -1453,7 +1453,7 @@ def __init__(
         attention_levels: Sequence[bool] = (False, False, True, True),
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
-        num_head_channels: int = 8,
+        num_head_channels: Union[int, Sequence[int]] = 8,
         with_conditioning: bool = False,
         transformer_num_layers: int = 1,
         cross_attention_dim: Optional[int] = None,
@@ -1476,6 +1476,15 @@ def __init__(
         if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
             raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups")

+        if isinstance(num_head_channels, int):
+            num_head_channels = (num_head_channels,) * len(attention_levels)
+
+        if len(num_head_channels) != len(attention_levels):
+            raise ValueError(
+                "num_head_channels should have the same length as attention_levels. For levels without"
+                " attention, i.e. `attention_levels[i]=False`, num_head_channels[i] is ignored."
+            )
+
         self.in_channels = in_channels
         self.block_out_channels = num_channels
         self.out_channels = out_channels
@@ -1526,7 +1535,7 @@ def __init__(
                 add_downsample=not is_final_block,
                 with_attn=(attention_levels[i] and not with_conditioning),
                 with_cross_attn=(attention_levels[i] and with_conditioning),
-                num_head_channels=num_head_channels,
+                num_head_channels=num_head_channels[i],
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
             )
@@ -1541,7 +1550,7 @@ def __init__(
             norm_num_groups=norm_num_groups,
             norm_eps=norm_eps,
             with_conditioning=with_conditioning,
-            num_head_channels=num_head_channels,
+            num_head_channels=num_head_channels[-1],
             transformer_num_layers=transformer_num_layers,
             cross_attention_dim=cross_attention_dim,
         )
@@ -1550,6 +1559,7 @@ def __init__(
         self.up_blocks = nn.ModuleList([])
         reversed_block_out_channels = list(reversed(num_channels))
         reversed_attention_levels = list(reversed(attention_levels))
+        reversed_num_head_channels = list(reversed(num_head_channels))
         output_channel = reversed_block_out_channels[0]
         for i in range(len(reversed_block_out_channels)):
             prev_output_channel = output_channel
@@ -1570,7 +1580,7 @@ def __init__(
                 add_upsample=not is_final_block,
                 with_attn=(reversed_attention_levels[i] and not with_conditioning),
                 with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
-                num_head_channels=num_head_channels,
+                num_head_channels=reversed_num_head_channels[i],
                 transformer_num_layers=transformer_num_layers,
                 cross_attention_dim=cross_attention_dim,
             )
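With this change, `num_head_channels` accepts either a single int, which is broadcast to every level (preserving the previous behaviour), or one value per entry in `attention_levels`. A minimal usage sketch, mirroring the new 2D test case added below; the import path and the `net(x, timesteps)` forward call are taken from this repository's tests, and the first entry of `num_head_channels` is ignored because `attention_levels[0]` is False:

import torch

from generative.networks.nets import DiffusionModelUNet

# One head-channel value per level; level 0 has no attention, so its value (0) is ignored.
net = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(8, 8, 8),
    attention_levels=(False, True, True),
    num_head_channels=(0, 2, 4),
    norm_num_groups=8,
)

x = torch.rand(1, 1, 16, 16)
timesteps = torch.randint(0, 1000, (1,)).long()
out = net(x, timesteps=timesteps)  # output has the same shape as x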
37 changes: 37 additions & 0 deletions tests/test_diffusion_model_unet.py
@@ -54,6 +54,18 @@
             "norm_num_groups": 8,
         },
     ],
+    [
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, True, True),
+            "num_head_channels": (0, 2, 4),
+            "norm_num_groups": 8,
+        },
+    ],
 ]

 UNCOND_CASES_3D = [
@@ -92,6 +104,18 @@
             "norm_num_groups": 8,
         },
     ],
+    [
+        {
+            "spatial_dims": 3,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, False, True),
+            "num_head_channels": (0, 0, 4),
+            "norm_num_groups": 8,
+        },
+    ],
 ]

@@ -131,6 +155,19 @@ def test_model_channels_not_multiple_of_norm_num_group(self):
                 norm_num_groups=8,
             )

+    def test_attention_levels_with_different_length_num_head_channels(self):
+        with self.assertRaises(ValueError):
+            net = DiffusionModelUNet(
+                spatial_dims=2,
+                in_channels=1,
+                out_channels=1,
+                num_res_blocks=1,
+                num_channels=(8, 8, 8),
+                attention_levels=(False, False, False),
+                num_head_channels=(0, 2),
+                norm_num_groups=8,
+            )
+
     def test_shape_conditioned_models(self):
         net = DiffusionModelUNet(
             spatial_dims=2,
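For reference, the broadcast-and-validate step that this PR adds to `__init__` can be sketched standalone. `normalize_head_channels` is a hypothetical helper written for illustration, not part of the library:

from typing import Sequence, Tuple, Union

def normalize_head_channels(
    num_head_channels: Union[int, Sequence[int]],
    attention_levels: Sequence[bool],
) -> Tuple[int, ...]:
    # A single int is broadcast to every level, so existing configs keep working.
    if isinstance(num_head_channels, int):
        num_head_channels = (num_head_channels,) * len(attention_levels)
    # A sequence must supply one value per level, even for levels without attention.
    if len(num_head_channels) != len(attention_levels):
        raise ValueError("num_head_channels should have the same length as attention_levels.")
    return tuple(num_head_channels)

print(normalize_head_channels(8, (False, True, True)))  # (8, 8, 8)
print(normalize_head_channels((0, 2, 4), (False, True, True)))  # (0, 2, 4)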