From fc4d6c840c3236fce1ac3e74420f5daf3a85eb66 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 14 Jan 2023 22:30:09 +0000 Subject: [PATCH 1/3] Add upcast_attention option for the crossattention Signed-off-by: Walter Hugo Lopez Pinaya --- .../networks/nets/diffusion_model_unet.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index b5ec38d8..45c30c03 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -116,6 +116,7 @@ class CrossAttention(nn.Module): num_attention_heads: number of heads to use for multi-head attention. num_head_channels: number of channels in each head. dropout: dropout probability to use. + upcast_attention: if True, upcast attention operations to full precision. """ def __init__( @@ -125,6 +126,7 @@ def __init__( num_attention_heads: int = 8, num_head_channels: int = 64, dropout: float = 0.0, + upcast_attention: bool = False, ) -> None: super().__init__() inner_dim = num_head_channels * num_attention_heads @@ -133,6 +135,8 @@ def __init__( self.scale = 1 / math.sqrt(num_head_channels) self.num_heads = num_attention_heads + self.upcast_attention = upcast_attention + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) @@ -165,6 +169,11 @@ def _memory_efficient_attention_xformers( return x def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + attention_scores = torch.baddbmm( torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), query, @@ -173,6 +182,8 @@ def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor alpha=self.scale, ) attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype=dtype) + x = torch.bmm(attention_probs, value) return x @@ -208,6 +219,7 @@ class BasicTransformerBlock(nn.Module): num_head_channels: number of channels in each attention head. dropout: dropout probability to use. cross_attention_dim: size of the context vector for cross attention. + upcast_attention: if True, upcast attention operations to full precision. """ def __init__( @@ -217,6 +229,7 @@ def __init__( num_head_channels: int, dropout: float = 0.0, cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, ) -> None: super().__init__() self.attn1 = CrossAttention( @@ -224,6 +237,7 @@ def __init__( num_attention_heads=num_attention_heads, num_head_channels=num_head_channels, dropout=dropout, + upcast_attention=upcast_attention, ) # is a self-attention self.ff = FeedForward(num_channels, dropout=dropout) self.attn2 = CrossAttention( @@ -232,6 +246,7 @@ def __init__( num_attention_heads=num_attention_heads, num_head_channels=num_head_channels, dropout=dropout, + upcast_attention=upcast_attention, ) # is a self-attention if context is None self.norm1 = nn.LayerNorm(num_channels) self.norm2 = nn.LayerNorm(num_channels) @@ -264,6 +279,7 @@ class SpatialTransformer(nn.Module): norm_num_groups: number of groups for the normalization. norm_eps: epsilon for the normalization. cross_attention_dim: number of context dimensions to use. 
+ upcast_attention: if True, upcast attention operations to full precision. """ def __init__( @@ -277,6 +293,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -303,6 +320,7 @@ def __init__( num_head_channels=num_head_channels, dropout=dropout, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) for _ in range(num_layers) ] @@ -913,6 +931,7 @@ def __init__( num_head_channels: int = 1, transformer_num_layers: int = 1, cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, ) -> None: """ Unet's down block containing resnet, downsamplers and cross-attention blocks. @@ -931,6 +950,7 @@ def __init__( num_head_channels: number of channels in each attention head. transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. """ super().__init__() self.resblock_updown = resblock_updown @@ -961,6 +981,7 @@ def __init__( norm_num_groups=norm_num_groups, norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) ) @@ -1077,6 +1098,7 @@ def __init__( num_head_channels: int = 1, transformer_num_layers: int = 1, cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, ) -> None: """ Unet's mid block containing resnet and cross-attention blocks. @@ -1090,6 +1112,7 @@ def __init__( num_head_channels: number of channels in each attention head. transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. """ super().__init__() self.attention = None @@ -1111,6 +1134,7 @@ def __init__( norm_num_groups=norm_num_groups, norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) self.resnet_2 = ResnetBlock( spatial_dims=spatial_dims, @@ -1342,6 +1366,7 @@ def __init__( num_head_channels: int = 1, transformer_num_layers: int = 1, cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, ) -> None: """ Unet's up block containing resnet, upsamplers, and self-attention blocks. @@ -1360,6 +1385,7 @@ def __init__( num_head_channels: number of channels in each attention head. transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. 
""" super().__init__() self.resblock_updown = resblock_updown @@ -1391,6 +1417,7 @@ def __init__( norm_eps=norm_eps, num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) ) @@ -1452,6 +1479,7 @@ def get_down_block( num_head_channels: int, transformer_num_layers: int, cross_attention_dim: Optional[int], + upcast_attention: bool = False, ) -> nn.Module: if with_attn: return AttnDownBlock( @@ -1480,6 +1508,7 @@ def get_down_block( num_head_channels=num_head_channels, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) else: return DownBlock( @@ -1505,6 +1534,7 @@ def get_mid_block( num_head_channels: int, transformer_num_layers: int, cross_attention_dim: Optional[int], + upcast_attention: bool = False, ) -> nn.Module: if with_conditioning: return CrossAttnMidBlock( @@ -1516,6 +1546,7 @@ def get_mid_block( num_head_channels=num_head_channels, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) else: return AttnMidBlock( @@ -1544,6 +1575,7 @@ def get_up_block( num_head_channels: int, transformer_num_layers: int, cross_attention_dim: Optional[int], + upcast_attention: bool = False, ) -> nn.Module: if with_attn: return AttnUpBlock( @@ -1574,6 +1606,7 @@ def get_up_block( num_head_channels=num_head_channels, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) else: return UpBlock( @@ -1612,6 +1645,7 @@ class DiffusionModelUNet(nn.Module): cross_attention_dim: number of context dimensions to use. num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. + upcast_attention: if True, upcast attention operations to full precision. 
""" def __init__( @@ -1630,6 +1664,7 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1712,6 +1747,7 @@ def __init__( num_head_channels=num_head_channels[i], transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) self.down_blocks.append(down_block) @@ -1727,6 +1763,7 @@ def __init__( num_head_channels=num_head_channels[-1], transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) # up @@ -1758,6 +1795,7 @@ def __init__( num_head_channels=reversed_num_head_channels[i], transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, ) self.up_blocks.append(up_block) From 312d4ab9a2db448381ba19d78447dcdfd69fb3a5 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 14 Jan 2023 22:34:12 +0000 Subject: [PATCH 2/3] Reposition docstring Signed-off-by: Walter Hugo Lopez Pinaya --- .../networks/nets/diffusion_model_unet.py | 258 +++++++++--------- 1 file changed, 133 insertions(+), 125 deletions(-) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 45c30c03..7eb80b8c 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -726,6 +726,22 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: class DownBlock(nn.Module): + """ + Unet's down block containing resnet and downsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + """ + def __init__( self, spatial_dims: int, @@ -739,21 +755,6 @@ def __init__( resblock_updown: bool = False, downsample_padding: int = 1, ) -> None: - """ - Unet's down block containing resnet and downsamplers blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - """ super().__init__() self.resblock_updown = resblock_updown @@ -814,6 +815,23 @@ def forward( class AttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. 
+ norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + """ + def __init__( self, spatial_dims: int, @@ -828,22 +846,6 @@ def __init__( downsample_padding: int = 1, num_head_channels: int = 1, ) -> None: - """ - Unet's down block containing resnet, downsamplers and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - num_head_channels: number of channels in each attention head. - """ super().__init__() self.resblock_updown = resblock_updown @@ -916,6 +918,26 @@ def forward( class CrossAttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and cross-attention blocks. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + """ + def __init__( self, spatial_dims: int, @@ -933,25 +955,6 @@ def __init__( cross_attention_dim: Optional[int] = None, upcast_attention: bool = False, ) -> None: - """ - Unet's down block containing resnet, downsamplers and cross-attention blocks. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_downsample: if True add downsample block. - resblock_updown: if True use residual blocks for downsampling. - downsample_padding: padding used in the downsampling block. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - """ super().__init__() self.resblock_updown = resblock_updown @@ -1028,6 +1031,18 @@ def forward( class AttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. 
+ in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + """ + def __init__( self, spatial_dims: int, @@ -1037,17 +1052,6 @@ def __init__( norm_eps: float = 1e-6, num_head_channels: int = 1, ) -> None: - """ - Unet's mid block containing resnet and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - num_head_channels: number of channels in each attention head. - """ super().__init__() self.attention = None @@ -1088,6 +1092,21 @@ def forward( class CrossAttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and cross-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + """ + def __init__( self, spatial_dims: int, @@ -1100,20 +1119,6 @@ def __init__( cross_attention_dim: Optional[int] = None, upcast_attention: bool = False, ) -> None: - """ - Unet's mid block containing resnet and cross-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - temb_channels: number of timestep embedding channels - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - """ super().__init__() self.attention = None @@ -1156,6 +1161,22 @@ def forward( class UpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + """ + def __init__( self, spatial_dims: int, @@ -1169,21 +1190,6 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, ) -> None: - """ - Unet's up block containing resnet and upsamplers blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. 
- num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - """ super().__init__() self.resblock_updown = resblock_updown resnets = [] @@ -1246,6 +1252,23 @@ def forward( class AttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + """ + def __init__( self, spatial_dims: int, @@ -1260,22 +1283,6 @@ def __init__( resblock_updown: bool = False, num_head_channels: int = 1, ) -> None: - """ - Unet's up block containing resnet, upsamplers, and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. - norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - num_head_channels: number of channels in each attention head. - """ super().__init__() self.resblock_updown = resblock_updown @@ -1351,6 +1358,26 @@ def forward( class CrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + """ + def __init__( self, spatial_dims: int, @@ -1368,25 +1395,6 @@ def __init__( cross_attention_dim: Optional[int] = None, upcast_attention: bool = False, ) -> None: - """ - Unet's up block containing resnet, upsamplers, and self-attention blocks. - - Args: - spatial_dims: The number of spatial dimensions. - in_channels: number of input channels. - prev_output_channel: number of channels from residual connection. - out_channels: number of output channels. - temb_channels: number of timestep embedding channels. - num_res_blocks: number of residual blocks. 
- norm_num_groups: number of groups for the group normalization. - norm_eps: epsilon for the group normalization. - add_upsample: if True add downsample block. - resblock_updown: if True use residual blocks for upsampling. - num_head_channels: number of channels in each attention head. - transformer_num_layers: number of layers of Transformer blocks to use. - cross_attention_dim: number of context dimensions to use. - upcast_attention: if True, upcast attention operations to full precision. - """ super().__init__() self.resblock_updown = resblock_updown From 62deae0fe621a2271bfc08da5921116e85233d06 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Sat, 14 Jan 2023 22:43:54 +0000 Subject: [PATCH 3/3] Add tests Signed-off-by: Walter Hugo Lopez Pinaya --- tests/test_diffusion_model_unet.py | 71 +++++++++++++++++++++++------- 1 file changed, 56 insertions(+), 15 deletions(-) diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index 255e2fc9..a5f387d0 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -168,6 +168,56 @@ ], ] +COND_CASES_2D = [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + }, + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "resblock_updown": True, + }, + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 4, + "norm_num_groups": 8, + "with_conditioning": True, + "transformer_num_layers": 1, + "cross_attention_dim": 3, + "upcast_attention": True, + }, + ], +] + class TestDiffusionModelUNet2D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D) @@ -340,21 +390,12 @@ def test_script_conditioned_2d_models(self): ) test_script_save(net, torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) - def test_script_conditioned_2d_models_with_resblock_updown(self): - net = DiffusionModelUNet( - spatial_dims=2, - in_channels=1, - out_channels=1, - num_res_blocks=1, - num_channels=(8, 8, 8), - attention_levels=(False, False, True), - norm_num_groups=8, - with_conditioning=True, - transformer_num_layers=1, - cross_attention_dim=3, - resblock_updown=True, - ) - test_script_save(net, torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) + @parameterized.expand(COND_CASES_2D) + def test_conditioned_2d_models_shape(self, input_param): + net = DiffusionModelUNet(**input_param) + with eval_mode(net): + result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) + self.assertEqual(result.shape, (1, 1, 16, 16)) class TestDiffusionModelUNet3D(unittest.TestCase):
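Reviewer note on the change itself: under mixed precision (e.g. `torch.autocast` with float16), the query-key product and the softmax are the attention steps most prone to overflow, so `upcast_attention=True` performs them in float32 and casts the probabilities back to the working dtype before multiplying by the values. The sketch below paraphrases the patched `_attention` outside the class, for illustration only: the standalone function name, the plain `bmm` in place of the fused `baddbmm`, and the `(batch * heads, seq_len, head_channels)` layout are assumptions for the sketch, not library API.

```python
import math

import torch


def upcast_attention_sketch(
    query: torch.Tensor,  # (batch * heads, q_len, head_channels)
    key: torch.Tensor,    # (batch * heads, kv_len, head_channels)
    value: torch.Tensor,  # (batch * heads, kv_len, head_channels)
    upcast_attention: bool = True,
) -> torch.Tensor:
    """Illustrative stand-in for the patched CrossAttention._attention."""
    scale = 1 / math.sqrt(query.shape[-1])
    dtype = query.dtype  # remember the working precision (float16 under autocast)

    if upcast_attention:
        # compute the score matrix and softmax in full precision
        query = query.float()
        key = key.float()

    attention_scores = torch.bmm(query, key.transpose(-1, -2)) * scale
    attention_probs = attention_scores.softmax(dim=-1)

    # cast back so the value multiplication and downstream layers keep the
    # working precision
    attention_probs = attention_probs.to(dtype=dtype)
    return torch.bmm(attention_probs, value)


if __name__ == "__main__":
    # use half precision only where it is reliably supported
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    q = torch.randn(2, 16, 64, dtype=dtype, device=device)
    k = torch.randn(2, 16, 64, dtype=dtype, device=device)
    v = torch.randn(2, 16, 64, dtype=dtype, device=device)
    out = upcast_attention_sketch(q, k, v)
    print(out.dtype, out.shape)  # same dtype and shape as the inputs
```

At the network level the flag is only threaded through the block constructors; the new `COND_CASES_2D` entry in the tests shows the intended usage, e.g. `DiffusionModelUNet(spatial_dims=2, in_channels=1, out_channels=1, num_res_blocks=1, num_channels=(8, 8, 8), attention_levels=(False, False, True), num_head_channels=4, norm_num_groups=8, with_conditioning=True, transformer_num_layers=1, cross_attention_dim=3, upcast_attention=True)`.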