diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index c69e9732..2de6705d 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -911,6 +911,7 @@ class CrossAttnDownBlock(nn.Module):
         cross_attention_dim: number of context dimensions to use.
         upcast_attention: if True, upcast attention operations to full precision.
         use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
     """
 
     def __init__(
@@ -930,6 +931,7 @@ def __init__(
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         use_flash_attention: bool = False,
+        dropout_cattn: float = 0.0
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -962,6 +964,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim,
                     upcast_attention=upcast_attention,
                     use_flash_attention=use_flash_attention,
+                    dropout=dropout_cattn
                 )
             )
 
@@ -1100,6 +1103,7 @@ def __init__(
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         use_flash_attention: bool = False,
+        dropout_cattn: float = 0.0
     ) -> None:
         super().__init__()
         self.attention = None
@@ -1123,6 +1127,7 @@ def __init__(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout=dropout_cattn
         )
         self.resnet_2 = ResnetBlock(
             spatial_dims=spatial_dims,
@@ -1266,7 +1271,7 @@ def __init__(
         add_upsample: bool = True,
         resblock_updown: bool = False,
         num_head_channels: int = 1,
-        use_flash_attention: bool = False,
+        use_flash_attention: bool = False
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -1363,6 +1368,7 @@ class CrossAttnUpBlock(nn.Module):
         cross_attention_dim: number of context dimensions to use.
         upcast_attention: if True, upcast attention operations to full precision.
         use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
     """
 
     def __init__(
@@ -1382,6 +1388,7 @@ def __init__(
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
         use_flash_attention: bool = False,
+        dropout_cattn: float = 0.0
     ) -> None:
         super().__init__()
         self.resblock_updown = resblock_updown
@@ -1415,6 +1422,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim,
                     upcast_attention=upcast_attention,
                     use_flash_attention=use_flash_attention,
+                    dropout=dropout_cattn
                 )
             )
 
@@ -1478,6 +1486,7 @@ def get_down_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     use_flash_attention: bool = False,
+    dropout_cattn: float = 0.0
 ) -> nn.Module:
     if with_attn:
         return AttnDownBlock(
@@ -1509,6 +1518,7 @@ def get_down_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout_cattn=dropout_cattn
         )
     else:
         return DownBlock(
@@ -1536,6 +1546,7 @@ def get_mid_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     use_flash_attention: bool = False,
+    dropout_cattn: float = 0.0
 ) -> nn.Module:
     if with_conditioning:
         return CrossAttnMidBlock(
@@ -1549,6 +1560,7 @@ def get_mid_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout_cattn=dropout_cattn
         )
     else:
         return AttnMidBlock(
@@ -1580,6 +1592,7 @@ def get_up_block(
     cross_attention_dim: int | None,
     upcast_attention: bool = False,
     use_flash_attention: bool = False,
+    dropout_cattn: float = 0.0
 ) -> nn.Module:
     if with_attn:
         return AttnUpBlock(
@@ -1613,6 +1626,7 @@ def get_up_block(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout_cattn=dropout_cattn
         )
     else:
         return UpBlock(
@@ -1653,6 +1667,7 @@ class DiffusionModelUNet(nn.Module):
             classes.
         upcast_attention: if True, upcast attention operations to full precision.
         use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers
     """
 
     def __init__(
@@ -1673,6 +1688,7 @@ def __init__(
         num_class_embeds: int | None = None,
         upcast_attention: bool = False,
         use_flash_attention: bool = False,
+        dropout_cattn: float = 0.0
     ) -> None:
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
@@ -1684,6 +1700,10 @@ def __init__(
             raise ValueError(
                 "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
             )
+        if dropout_cattn > 1.0 or dropout_cattn < 0.0:
+            raise ValueError(
+                "Dropout cannot be negative or >1.0!"
+            )
 
         # All number of channels should be multiple of num_groups
         if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
@@ -1773,6 +1793,7 @@ def __init__(
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
                 use_flash_attention=use_flash_attention,
+                dropout_cattn=dropout_cattn
             )
 
             self.down_blocks.append(down_block)
@@ -1790,6 +1811,7 @@ def __init__(
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
             use_flash_attention=use_flash_attention,
+            dropout_cattn=dropout_cattn
         )
 
         # up
@@ -1824,6 +1846,7 @@ def __init__(
                 cross_attention_dim=cross_attention_dim,
                 upcast_attention=upcast_attention,
                 use_flash_attention=use_flash_attention,
+                dropout_cattn=dropout_cattn
             )
 
             self.up_blocks.append(up_block)
diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py
index b02c37b1..976e88d4 100644
--- a/tests/test_diffusion_model_unet.py
+++ b/tests/test_diffusion_model_unet.py
@@ -231,6 +231,59 @@
     ],
 ]
+
+DROPOUT_OK = [
+    [
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, False, True),
+            "num_head_channels": 4,
+            "norm_num_groups": 8,
+            "with_conditioning": True,
+            "transformer_num_layers": 1,
+            "cross_attention_dim": 3,
+            "dropout_cattn": 0.25
+        }
+    ],
+    [
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, False, True),
+            "num_head_channels": 4,
+            "norm_num_groups": 8,
+            "with_conditioning": True,
+            "transformer_num_layers": 1,
+            "cross_attention_dim": 3
+        }
+    ],
+]
+
+DROPOUT_WRONG = [
+    [
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, False, True),
+            "num_head_channels": 4,
+            "norm_num_groups": 8,
+            "with_conditioning": True,
+            "transformer_num_layers": 1,
+            "cross_attention_dim": 3,
+            "dropout_cattn": 3.0
+        }
+    ],
+]
 
 
 class TestDiffusionModelUNet2D(unittest.TestCase):
     @parameterized.expand(UNCOND_CASES_2D)
@@ -524,6 +577,17 @@ def test_script_conditioned_3d_models(self):
             net, torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))
         )
 
+    # Test dropout specification for cross-attention blocks
+    @parameterized.expand(DROPOUT_WRONG)
+    def test_wrong_dropout(self, input_param):
+        with self.assertRaises(ValueError):
+            _ = DiffusionModelUNet(**input_param)
+
+    @parameterized.expand(DROPOUT_OK)
+    def test_right_dropout(self, input_param):
+        _ = DiffusionModelUNet(**input_param)
+
 
 if __name__ == "__main__":
     unittest.main()
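Below is a minimal usage sketch (not part of the patch) showing how the new dropout_cattn argument might be exercised once this change is applied. The constructor arguments mirror the DROPOUT_OK test case above; the import path and forward call follow the repository layout and test conventions shown in this diff, and a 0.1 dropout value is an arbitrary illustrative choice.

# Sketch: constructing a conditioned DiffusionModelUNet with cross-attention dropout.
import torch

from generative.networks.nets import DiffusionModelUNet

# dropout_cattn only affects the cross-attention (SpatialTransformer) layers,
# which are created when with_conditioning=True and cross_attention_dim is set.
net = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(8, 8, 8),
    attention_levels=(False, False, True),
    num_head_channels=4,
    norm_num_groups=8,
    with_conditioning=True,
    transformer_num_layers=1,
    cross_attention_dim=3,
    dropout_cattn=0.1,  # values outside [0.0, 1.0] raise ValueError per the added check
)

# Forward pass with a dummy image, timestep, and conditioning context,
# matching the tensor shapes used in the tests above.
x = torch.rand(1, 1, 16, 16)
timesteps = torch.randint(0, 1000, (1,)).long()
context = torch.rand(1, 1, 3)
out = net(x, timesteps=timesteps, context=context)
print(out.shape)  # expected: torch.Size([1, 1, 16, 16])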