25 changes: 24 additions & 1 deletion generative/networks/nets/diffusion_model_unet.py
@@ -911,6 +911,7 @@ class CrossAttnDownBlock(nn.Module):
cross_attention_dim: number of context dimensions to use.
upcast_attention: if True, upcast attention operations to full precision.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
"""

def __init__(
@@ -930,6 +931,7 @@ def __init__(
cross_attention_dim: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
) -> None:
super().__init__()
self.resblock_updown = resblock_updown
@@ -962,6 +964,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout=dropout_cattn
)
)

@@ -1100,6 +1103,7 @@ def __init__(
cross_attention_dim: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
) -> None:
super().__init__()
self.attention = None
@@ -1123,6 +1127,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout=dropout_cattn
)
self.resnet_2 = ResnetBlock(
spatial_dims=spatial_dims,
@@ -1266,7 +1271,7 @@ def __init__(
add_upsample: bool = True,
resblock_updown: bool = False,
num_head_channels: int = 1,
use_flash_attention: bool = False,
use_flash_attention: bool = False
) -> None:
super().__init__()
self.resblock_updown = resblock_updown
@@ -1363,6 +1368,7 @@ class CrossAttnUpBlock(nn.Module):
cross_attention_dim: number of context dimensions to use.
upcast_attention: if True, upcast attention operations to full precision.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
"""

def __init__(
@@ -1382,6 +1388,7 @@ def __init__(
cross_attention_dim: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
) -> None:
super().__init__()
self.resblock_updown = resblock_updown
@@ -1415,6 +1422,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout=dropout_cattn
)
)

@@ -1478,6 +1486,7 @@ def get_down_block(
cross_attention_dim: int | None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
) -> nn.Module:
if with_attn:
return AttnDownBlock(
@@ -1509,6 +1518,7 @@ def get_down_block(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
)
else:
return DownBlock(
@@ -1536,6 +1546,7 @@ def get_mid_block(
cross_attention_dim: int | None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
) -> nn.Module:
if with_conditioning:
return CrossAttnMidBlock(
@@ -1549,6 +1560,7 @@ def get_mid_block(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
)
else:
return AttnMidBlock(
@@ -1580,6 +1592,7 @@ def get_up_block(
cross_attention_dim: int | None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
) -> nn.Module:
if with_attn:
return AttnUpBlock(
@@ -1613,6 +1626,7 @@ def get_up_block(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
)
else:
return UpBlock(
@@ -1653,6 +1667,7 @@ class DiffusionModelUNet(nn.Module):
classes.
upcast_attention: if True, upcast attention operations to full precision.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
"""

def __init__(
@@ -1673,6 +1688,7 @@ def __init__(
num_class_embeds: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
dropout_cattn: float = 0.0
) -> None:
super().__init__()
if with_conditioning is True and cross_attention_dim is None:
@@ -1684,6 +1700,10 @@ def __init__(
raise ValueError(
"DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
)
if dropout_cattn > 1.0 or dropout_cattn < 0.0:
raise ValueError(
"Dropout cannot be negative or >1.0!"
)

# All number of channels should be multiple of num_groups
if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
@@ -1773,6 +1793,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
)

self.down_blocks.append(down_block)
@@ -1790,6 +1811,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
)

# up
@@ -1824,6 +1846,7 @@ def __init__(
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
use_flash_attention=use_flash_attention,
dropout_cattn=dropout_cattn
)

self.up_blocks.append(up_block)
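For orientation, a minimal usage sketch (not part of the diff): it constructs a conditioned 2D DiffusionModelUNet with the new dropout_cattn argument, reusing the parameter values from the DROPOUT_OK test case below, and runs one forward pass with the positional (x, timesteps, context) call pattern used by the scripting tests in this PR. The import path mirrors the file path shown above.

import torch

from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet

# Conditioned 2D UNet with 25% dropout inside the cross-attention layers;
# parameter values mirror the DROPOUT_OK test case.
net = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(8, 8, 8),
    attention_levels=(False, False, True),
    num_head_channels=4,
    norm_num_groups=8,
    with_conditioning=True,
    transformer_num_layers=1,
    cross_attention_dim=3,
    dropout_cattn=0.25,
)

# One forward pass on dummy data: image batch, integer timesteps, conditioning context.
out = net(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3)))
print(out.shape)  # expected to match the input spatial shape: torch.Size([1, 1, 16, 16])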
64 changes: 64 additions & 0 deletions tests/test_diffusion_model_unet.py
@@ -231,6 +231,59 @@
],
]

DROPOUT_OK = [
[
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_res_blocks": 1,
"num_channels": (8, 8, 8),
"attention_levels": (False, False, True),
"num_head_channels": 4,
"norm_num_groups": 8,
"with_conditioning": True,
"transformer_num_layers": 1,
"cross_attention_dim": 3,
"dropout_cattn": 0.25
}
],
[
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_res_blocks": 1,
"num_channels": (8, 8, 8),
"attention_levels": (False, False, True),
"num_head_channels": 4,
"norm_num_groups": 8,
"with_conditioning": True,
"transformer_num_layers": 1,
"cross_attention_dim": 3
}
],
]

DROPOUT_WRONG = [
[
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 1,
"num_res_blocks": 1,
"num_channels": (8, 8, 8),
"attention_levels": (False, False, True),
"num_head_channels": 4,
"norm_num_groups": 8,
"with_conditioning": True,
"transformer_num_layers": 1,
"cross_attention_dim": 3,
"dropout_cattn": 3.0
}
],
]


class TestDiffusionModelUNet2D(unittest.TestCase):
@parameterized.expand(UNCOND_CASES_2D)
@@ -524,6 +577,17 @@ def test_script_conditioned_3d_models(self):
net, torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))
)

# Test dropout specification for cross-attention blocks
@parameterized.expand(DROPOUT_WRONG)
def test_wrong_dropout(self, input_param):
with self.assertRaises(ValueError):
_ = DiffusionModelUNet(**input_param)

@parameterized.expand(DROPOUT_OK)
def test_right_dropout(self, input_param):
_ = DiffusionModelUNet(**input_param)



if __name__ == "__main__":
unittest.main()
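As a purely illustrative extension (not part of this PR), one way to sanity-check that dropout_cattn actually reaches the attention layers is to inspect the dropout submodules of a constructed network. This assumes the cross-attention dropout ends up as torch.nn.Dropout modules, which the dropout=dropout_cattn keyword forwarded to the transformer blocks in the diff suggests but does not guarantee.

import torch

from generative.networks.nets.diffusion_model_unet import DiffusionModelUNet

net = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(8, 8, 8),
    attention_levels=(False, False, True),
    num_head_channels=4,
    norm_num_groups=8,
    with_conditioning=True,
    transformer_num_layers=1,
    cross_attention_dim=3,
    dropout_cattn=0.25,
)

# Collect the dropout rates of every nn.Dropout in the network and check that
# at least one submodule carries the requested cross-attention rate.
dropout_rates = {m.p for m in net.modules() if isinstance(m, torch.nn.Dropout)}
print(dropout_rates)          # e.g. {0.0, 0.25}; exact contents depend on the other blocks' defaults
assert 0.25 in dropout_rates  # assumes the cross-attention dropout is realized as nn.Dropout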