diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index b5ec38d8..7eb80b8c 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -116,6 +116,7 @@ class CrossAttention(nn.Module):
         num_attention_heads: number of heads to use for multi-head attention.
         num_head_channels: number of channels in each head.
         dropout: dropout probability to use.
+        upcast_attention: if True, upcast attention operations to full precision.
     """

     def __init__(
@@ -125,6 +126,7 @@ def __init__(
         num_attention_heads: int = 8,
         num_head_channels: int = 64,
         dropout: float = 0.0,
+        upcast_attention: bool = False,
     ) -> None:
         super().__init__()
         inner_dim = num_head_channels * num_attention_heads
@@ -133,6 +135,8 @@ def __init__(
         self.scale = 1 / math.sqrt(num_head_channels)
         self.num_heads = num_attention_heads

+        self.upcast_attention = upcast_attention
+
         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
         self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
@@ -165,6 +169,11 @@ def _memory_efficient_attention_xformers(
         return x

     def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        dtype = query.dtype
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
         attention_scores = torch.baddbmm(
             torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
             query,
@@ -173,6 +182,8 @@ def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
             alpha=self.scale,
         )
         attention_probs = attention_scores.softmax(dim=-1)
+        attention_probs = attention_probs.to(dtype=dtype)
+
         x = torch.bmm(attention_probs, value)
         return x

@@ -208,6 +219,7 @@ class BasicTransformerBlock(nn.Module):
         num_head_channels: number of channels in each attention head.
         dropout: dropout probability to use.
         cross_attention_dim: size of the context vector for cross attention.
+        upcast_attention: if True, upcast attention operations to full precision.
     """

     def __init__(
@@ -217,6 +229,7 @@ def __init__(
         num_head_channels: int,
         dropout: float = 0.0,
         cross_attention_dim: Optional[int] = None,
+        upcast_attention: bool = False,
     ) -> None:
         super().__init__()
         self.attn1 = CrossAttention(
@@ -224,6 +237,7 @@ def __init__(
             num_attention_heads=num_attention_heads,
             num_head_channels=num_head_channels,
             dropout=dropout,
+            upcast_attention=upcast_attention,
         )  # is a self-attention
         self.ff = FeedForward(num_channels, dropout=dropout)
         self.attn2 = CrossAttention(
@@ -232,6 +246,7 @@ def __init__(
             num_attention_heads=num_attention_heads,
             num_head_channels=num_head_channels,
             dropout=dropout,
+            upcast_attention=upcast_attention,
         )  # is a self-attention if context is None
         self.norm1 = nn.LayerNorm(num_channels)
         self.norm2 = nn.LayerNorm(num_channels)
@@ -264,6 +279,7 @@ class SpatialTransformer(nn.Module):
         norm_num_groups: number of groups for the normalization.
         norm_eps: epsilon for the normalization.
         cross_attention_dim: number of context dimensions to use.
+        upcast_attention: if True, upcast attention operations to full precision.
""" def __init__( @@ -277,6 +293,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -303,6 +320,7 @@ def __init__( num_head_channels=num_head_channels, dropout=dropout, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) for _ in range(num_layers) ] @@ -708,6 +726,22 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: class DownBlock(nn.Module): + """ + Unet's down block containing resnet and downsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + """ + def __init__( self, spatial_dims: int, @@ -721,21 +755,6 @@ def __init__( resblock_updown: bool = False, downsample_padding: int = 1, ) -> None: - """ - Unet's down block containing resnet and downsamplers blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - """ super().__init__() self.resblock_updown = resblock_updown @@ -796,6 +815,23 @@ def forward( class AttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + """ + def __init__( self, spatial_dims: int, @@ -810,22 +846,6 @@ def __init__( downsample_padding: int = 1, num_head_channels: int = 1, ) -> None: - """ - Unet's down block containing resnet, downsamplers and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. 
-            num_head_channels: number of channels in each attention head.
-        """
         super().__init__()

         self.resblock_updown = resblock_updown
@@ -898,6 +918,26 @@ def forward(


 class CrossAttnDownBlock(nn.Module):
+    """
+    Unet's down block containing resnet, downsamplers and cross-attention blocks.
+
+    Args:
+        spatial_dims: number of spatial dimensions.
+        in_channels: number of input channels.
+        out_channels: number of output channels.
+        temb_channels: number of timestep embedding channels.
+        num_res_blocks: number of residual blocks.
+        norm_num_groups: number of groups for the group normalization.
+        norm_eps: epsilon for the group normalization.
+        add_downsample: if True add downsample block.
+        resblock_updown: if True use residual blocks for downsampling.
+        downsample_padding: padding used in the downsampling block.
+        num_head_channels: number of channels in each attention head.
+        transformer_num_layers: number of layers of Transformer blocks to use.
+        cross_attention_dim: number of context dimensions to use.
+        upcast_attention: if True, upcast attention operations to full precision.
+    """
+
     def __init__(
         self,
         spatial_dims: int,
@@ -913,25 +953,8 @@ def __init__(
         num_head_channels: int = 1,
         transformer_num_layers: int = 1,
         cross_attention_dim: Optional[int] = None,
+        upcast_attention: bool = False,
     ) -> None:
-        """
-        Unet's down block containing resnet, downsamplers and cross-attention blocks.
-
-        Args:
-            spatial_dims: number of spatial dimensions.
-            in_channels: number of input channels.
-            out_channels: number of output channels.
-            temb_channels: number of timestep embedding channels.
-            num_res_blocks: number of residual blocks.
-            norm_num_groups: number of groups for the group normalization.
-            norm_eps: epsilon for the group normalization.
-            add_downsample: if True add downsample block.
-            resblock_updown: if True use residual blocks for downsampling.
-            downsample_padding: padding used in the downsampling block.
-            num_head_channels: number of channels in each attention head.
-            transformer_num_layers: number of layers of Transformer blocks to use.
-            cross_attention_dim: number of context dimensions to use.
-        """
         super().__init__()

         self.resblock_updown = resblock_updown
@@ -961,6 +984,7 @@ def __init__(
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
                     cross_attention_dim=cross_attention_dim,
+                    upcast_attention=upcast_attention,
                 )
             )

@@ -1007,6 +1031,18 @@ def forward(


 class AttnMidBlock(nn.Module):
+    """
+    Unet's mid block containing resnet and self-attention blocks.
+
+    Args:
+        spatial_dims: The number of spatial dimensions.
+        in_channels: number of input channels.
+        temb_channels: number of timestep embedding channels.
+        norm_num_groups: number of groups for the group normalization.
+        norm_eps: epsilon for the group normalization.
+        num_head_channels: number of channels in each attention head.
+    """
+
     def __init__(
         self,
         spatial_dims: int,
@@ -1016,17 +1052,6 @@ def __init__(
         norm_eps: float = 1e-6,
         num_head_channels: int = 1,
     ) -> None:
-        """
-        Unet's mid block containing resnet and self-attention blocks.
-
-        Args:
-            spatial_dims: The number of spatial dimensions.
-            in_channels: number of input channels.
-            temb_channels: number of timestep embedding channels.
-            norm_num_groups: number of groups for the group normalization.
-            norm_eps: epsilon for the group normalization.
-            num_head_channels: number of channels in each attention head.
- """ super().__init__() self.attention = None @@ -1067,6 +1092,21 @@ def forward( class CrossAttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and cross-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + """ + def __init__( self, spatial_dims: int, @@ -1077,20 +1117,8 @@ def __init__( num_head_channels: int = 1, transformer_num_layers: int = 1, cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, ) -> None: - """ - Unet's mid block containing resnet and cross-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - """ super().__init__() self.attention = None @@ -1111,6 +1139,7 @@ def __init__( norm_num_groups=norm_num_groups, norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) self.resnet_2 = ResnetBlock( spatial_dims=spatial_dims, @@ -1132,6 +1161,22 @@ def forward( class UpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + """ + def __init__( self, spatial_dims: int, @@ -1145,21 +1190,6 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, ) -> None: - """ - Unet's up block containing resnet and upsamplers blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - """ super().__init__() self.resblock_updown = resblock_updown resnets = [] @@ -1222,6 +1252,23 @@ def forward( class AttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. 
+        prev_output_channel: number of channels from residual connection.
+        out_channels: number of output channels.
+        temb_channels: number of timestep embedding channels.
+        num_res_blocks: number of residual blocks.
+        norm_num_groups: number of groups for the group normalization.
+        norm_eps: epsilon for the group normalization.
+        add_upsample: if True add upsample block.
+        resblock_updown: if True use residual blocks for upsampling.
+        num_head_channels: number of channels in each attention head.
+    """
+
     def __init__(
         self,
         spatial_dims: int,
@@ -1236,22 +1283,6 @@ def __init__(
         resblock_updown: bool = False,
         num_head_channels: int = 1,
     ) -> None:
-        """
-        Unet's up block containing resnet, upsamplers, and self-attention blocks.
-
-        Args:
-            spatial_dims: The number of spatial dimensions.
-            in_channels: number of input channels.
-            prev_output_channel: number of channels from residual connection.
-            out_channels: number of output channels.
-            temb_channels: number of timestep embedding channels.
-            num_res_blocks: number of residual blocks.
-            norm_num_groups: number of groups for the group normalization.
-            norm_eps: epsilon for the group normalization.
-            add_upsample: if True add downsample block.
-            resblock_updown: if True use residual blocks for upsampling.
-            num_head_channels: number of channels in each attention head.
-        """
         super().__init__()

         self.resblock_updown = resblock_updown
@@ -1327,6 +1358,26 @@ def forward(


 class CrossAttnUpBlock(nn.Module):
+    """
+    Unet's up block containing resnet, upsamplers, and cross-attention blocks.
+
+    Args:
+        spatial_dims: The number of spatial dimensions.
+        in_channels: number of input channels.
+        prev_output_channel: number of channels from residual connection.
+        out_channels: number of output channels.
+        temb_channels: number of timestep embedding channels.
+        num_res_blocks: number of residual blocks.
+        norm_num_groups: number of groups for the group normalization.
+        norm_eps: epsilon for the group normalization.
+        add_upsample: if True add upsample block.
+        resblock_updown: if True use residual blocks for upsampling.
+        num_head_channels: number of channels in each attention head.
+        transformer_num_layers: number of layers of Transformer blocks to use.
+        cross_attention_dim: number of context dimensions to use.
+        upcast_attention: if True, upcast attention operations to full precision.
+    """
+
     def __init__(
         self,
         spatial_dims: int,
@@ -1342,25 +1393,8 @@ def __init__(
         num_head_channels: int = 1,
         transformer_num_layers: int = 1,
         cross_attention_dim: Optional[int] = None,
+        upcast_attention: bool = False,
     ) -> None:
-        """
-        Unet's up block containing resnet, upsamplers, and self-attention blocks.
-
-        Args:
-            spatial_dims: The number of spatial dimensions.
-            in_channels: number of input channels.
-            prev_output_channel: number of channels from residual connection.
-            out_channels: number of output channels.
-            temb_channels: number of timestep embedding channels.
-            num_res_blocks: number of residual blocks.
-            norm_num_groups: number of groups for the group normalization.
-            norm_eps: epsilon for the group normalization.
-            add_upsample: if True add downsample block.
-            resblock_updown: if True use residual blocks for upsampling.
-            num_head_channels: number of channels in each attention head.
-            transformer_num_layers: number of layers of Transformer blocks to use.
-            cross_attention_dim: number of context dimensions to use.
- """ super().__init__() self.resblock_updown = resblock_updown @@ -1391,6 +1425,7 @@ def __init__( norm_eps=norm_eps, num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) ) @@ -1452,6 +1487,7 @@ def get_down_block( num_head_channels: int, transformer_num_layers: int, cross_attention_dim: Optional[int], + upcast_attention: bool = False, ) -> nn.Module: if with_attn: return AttnDownBlock( @@ -1480,6 +1516,7 @@ def get_down_block( num_head_channels=num_head_channels, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) else: return DownBlock( @@ -1505,6 +1542,7 @@ def get_mid_block( num_head_channels: int, transformer_num_layers: int, cross_attention_dim: Optional[int], + upcast_attention: bool = False, ) -> nn.Module: if with_conditioning: return CrossAttnMidBlock( @@ -1516,6 +1554,7 @@ def get_mid_block( num_head_channels=num_head_channels, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) else: return AttnMidBlock( @@ -1544,6 +1583,7 @@ def get_up_block( num_head_channels: int, transformer_num_layers: int, cross_attention_dim: Optional[int], + upcast_attention: bool = False, ) -> nn.Module: if with_attn: return AttnUpBlock( @@ -1574,6 +1614,7 @@ def get_up_block( num_head_channels=num_head_channels, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) else: return UpBlock( @@ -1612,6 +1653,7 @@ class DiffusionModelUNet(nn.Module): cross_attention_dim: number of context dimensions to use. num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. + upcast_attention: if True, upcast attention operations to full precision. 
""" def __init__( @@ -1630,6 +1672,7 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1712,6 +1755,7 @@ def __init__( num_head_channels=num_head_channels[i], transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) self.down_blocks.append(down_block) @@ -1727,6 +1771,7 @@ def __init__( num_head_channels=num_head_channels[-1], transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) # up @@ -1758,6 +1803,7 @@ def __init__( num_head_channels=reversed_num_head_channels[i], transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) self.up_blocks.append(up_block) diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index 255e2fc9..a5f387d0 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -168,6 +168,56 @@ ], ] +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + }, + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + }, + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + }, + ], +] + class TestDiffusionModelUNet2D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D) @@ -340,21 +390,12 @@ def test_script_conditioned_2d_models(self): ) test_script_save(net, torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) - def test_script_conditioned_2d_models_with_resblock_updown(self): - net = DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - norm_num_groups=8, - with_conditioning=True, - transformer_num_layers=1, - cross_attention_dim=3, - resblock_updown=True, - ) - test_script_save(net, torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) + @parameterized.expand(COND_CASES_2D) + def test_conditioned_2d_models_shape(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) + self.assertEqual(result.shape, (1, 1, 16, 16)) class TestDiffusionModelUNet3D(unittest.TestCase):