From 0f1df93ab44122238a2b9de1e233f115f1aef07d Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Tue, 10 Jan 2023 11:48:10 +0000 Subject: [PATCH 1/4] Add residual blocks for down/upsampling Signed-off-by: Walter Hugo Lopez Pinaya --- .../networks/nets/diffusion_model_unet.py | 191 ++++++++++++++---- tests/test_diffusion_model_unet.py | 66 ++++++ 2 files changed, 221 insertions(+), 36 deletions(-) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 7aabfbba..059690f6 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -645,6 +645,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, add_downsample: bool = True, + resblock_updown: bool = False, downsample_padding: int = 1, ) -> None: """ @@ -659,9 +660,12 @@ def __init__( norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. """ super().__init__() + self.resblock_updown = resblock_updown + resnets = [] for i in range(num_res_blocks): @@ -680,13 +684,24 @@ def __init__( self.resnets = nn.ModuleList(resnets) if add_downsample: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) else: self.downsampler = None @@ -701,7 +716,10 @@ def forward( output_states.append(hidden_states) if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states) + if self.resblock_updown: + hidden_states = self.downsampler(hidden_states, temb) + else: + hidden_states = self.downsampler(hidden_states) output_states.append(hidden_states) return hidden_states, output_states @@ -718,6 +736,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, add_downsample: bool = True, + resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, ) -> None: @@ -733,10 +752,13 @@ def __init__( norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. 
""" super().__init__() + self.resblock_updown = resblock_updown + resnets = [] attentions = [] @@ -766,13 +788,24 @@ def __init__( self.resnets = nn.ModuleList(resnets) if add_downsample: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) else: self.downsampler = None @@ -788,7 +821,10 @@ def forward( output_states.append(hidden_states) if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states) + if self.resblock_updown: + hidden_states = self.downsampler(hidden_states, temb) + else: + hidden_states = self.downsampler(hidden_states) output_states.append(hidden_states) return hidden_states, output_states @@ -805,6 +841,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, add_downsample: bool = True, + resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, transformer_num_layers: int = 1, @@ -822,12 +859,15 @@ def __init__( norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. """ super().__init__() + self.resblock_updown = resblock_updown + resnets = [] attentions = [] @@ -861,13 +901,24 @@ def __init__( self.resnets = nn.ModuleList(resnets) if add_downsample: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) else: self.downsampler = None @@ -882,7 +933,10 @@ def forward( output_states.append(hidden_states) if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states) + if self.resblock_updown: + hidden_states = self.downsampler(hidden_states, temb) + else: + hidden_states = self.downsampler(hidden_states) output_states.append(hidden_states) return hidden_states, output_states @@ -1025,6 +1079,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, add_upsample: bool = True, + resblock_updown: bool = False, ) -> None: """ Unet's up block containing resnet and upsamplers blocks. @@ -1039,8 +1094,10 @@ def __init__( norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_upsample: if True add downsample block. 
+            resblock_updown: if True use residual blocks for upsampling.
        """
        super().__init__()
+        self.resblock_updown = resblock_updown

        resnets = []

        for i in range(num_res_blocks):
@@ -1061,9 +1118,20 @@ def __init__(
         self.resnets = nn.ModuleList(resnets)
 
         if add_upsample:
-            self.upsampler = Upsample(
-                spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
-            )
+            if resblock_updown:
+                self.upsampler = ResnetBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    up=True,
+                )
+            else:
+                self.upsampler = Upsample(
+                    spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
+                )
         else:
             self.upsampler = None
 
@@ -1084,7 +1152,10 @@ def forward(
             hidden_states = resnet(hidden_states, temb)
 
         if self.upsampler is not None:
-            hidden_states = self.upsampler(hidden_states)
+            if self.resblock_updown:
+                hidden_states = self.upsampler(hidden_states, temb)
+            else:
+                hidden_states = self.upsampler(hidden_states)
 
         return hidden_states
 
@@ -1101,6 +1172,7 @@ def __init__(
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
         add_upsample: bool = True,
+        resblock_updown: bool = False,
         num_head_channels: int = 1,
     ) -> None:
         """
@@ -1116,9 +1188,12 @@ def __init__(
            norm_num_groups: number of groups for the group normalization.
            norm_eps: epsilon for the group normalization.
            add_upsample: if True add upsample block.
+            resblock_updown: if True use residual blocks for upsampling.
            num_head_channels: number of channels in each attention head.
        """
        super().__init__()
+        self.resblock_updown = resblock_updown
+
        resnets = []
        attentions = []

@@ -1150,9 +1225,20 @@ def __init__(
         self.attentions = nn.ModuleList(attentions)
 
         if add_upsample:
-            self.upsampler = Upsample(
-                spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
-            )
+            if resblock_updown:
+                self.upsampler = ResnetBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    up=True,
+                )
+            else:
+                self.upsampler = Upsample(
+                    spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels
+                )
         else:
             self.upsampler = None
 
@@ -1174,7 +1260,10 @@ def forward(
             hidden_states = attn(hidden_states)
 
         if self.upsampler is not None:
-            hidden_states = self.upsampler(hidden_states)
+            if self.resblock_updown:
+                hidden_states = self.upsampler(hidden_states, temb)
+            else:
+                hidden_states = self.upsampler(hidden_states)
 
         return hidden_states
 
@@ -1191,6 +1280,7 @@ def __init__(
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
         add_upsample: bool = True,
+        resblock_updown: bool = False,
         num_head_channels: int = 1,
         transformer_num_layers: int = 1,
         cross_attention_dim: Optional[int] = None,
@@ -1208,11 +1298,14 @@ def __init__(
            norm_num_groups: number of groups for the group normalization.
            norm_eps: epsilon for the group normalization.
            add_upsample: if True add upsample block.
+            resblock_updown: if True use residual blocks for upsampling.
            num_head_channels: number of channels in each attention head.
            transformer_num_layers: number of layers of Transformer blocks to use.
            cross_attention_dim: number of context dimensions to use.
""" super().__init__() + self.resblock_updown = resblock_updown + resnets = [] attentions = [] @@ -1247,9 +1340,20 @@ def __init__( self.resnets = nn.ModuleList(resnets) if add_upsample: - self.upsampler = Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) else: self.upsampler = None @@ -1270,7 +1374,10 @@ def forward( hidden_states = attn(hidden_states, context=context) if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states) + if self.resblock_updown: + hidden_states = self.upsampler(hidden_states, temb) + else: + hidden_states = self.upsampler(hidden_states) return hidden_states @@ -1284,6 +1391,7 @@ def get_down_block( norm_num_groups: int, norm_eps: float, add_downsample: bool, + resblock_updown: bool, with_attn: bool, with_cross_attn: bool, num_head_channels: int, @@ -1300,6 +1408,7 @@ def get_down_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=add_downsample, + resblock_updown=resblock_updown, num_head_channels=num_head_channels, ) elif with_cross_attn: @@ -1312,6 +1421,7 @@ def get_down_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=add_downsample, + resblock_updown=resblock_updown, num_head_channels=num_head_channels, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, @@ -1326,6 +1436,7 @@ def get_down_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=add_downsample, + resblock_updown=resblock_updown, ) @@ -1372,6 +1483,7 @@ def get_up_block( norm_num_groups: int, norm_eps: float, add_upsample: bool, + resblock_updown: bool, with_attn: bool, with_cross_attn: bool, num_head_channels: int, @@ -1389,6 +1501,7 @@ def get_up_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_upsample=add_upsample, + resblock_updown=resblock_updown, num_head_channels=num_head_channels, ) elif with_cross_attn: @@ -1402,6 +1515,7 @@ def get_up_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_upsample=add_upsample, + resblock_updown=resblock_updown, num_head_channels=num_head_channels, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, @@ -1417,6 +1531,7 @@ def get_up_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_upsample=add_upsample, + resblock_updown=resblock_updown, ) @@ -1435,6 +1550,7 @@ class DiffusionModelUNet(nn.Module): attention_levels: list of levels to add attention. norm_num_groups: number of groups for the normalization. norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. num_head_channels: number of channels in each attention head. with_conditioning: if True add spatial transformers to perform conditioning. transformer_num_layers: number of layers of Transformer blocks to use. 
@@ -1453,6 +1569,7 @@ def __init__(
         attention_levels: Sequence[bool] = (False, False, True, True),
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
+        resblock_updown: bool = False,
         num_head_channels: int = 8,
         with_conditioning: bool = False,
         transformer_num_layers: int = 1,
@@ -1524,6 +1641,7 @@ def __init__(
                 norm_num_groups=norm_num_groups,
                 norm_eps=norm_eps,
                 add_downsample=not is_final_block,
+                resblock_updown=resblock_updown,
                 with_attn=(attention_levels[i] and not with_conditioning),
                 with_cross_attn=(attention_levels[i] and with_conditioning),
                 num_head_channels=num_head_channels,
@@ -1568,6 +1686,7 @@ def __init__(
                 norm_num_groups=norm_num_groups,
                 norm_eps=norm_eps,
                 add_upsample=not is_final_block,
+                resblock_updown=resblock_updown,
                 with_attn=(reversed_attention_levels[i] and not with_conditioning),
                 with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
                 num_head_channels=num_head_channels,
diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py
index f9925b96..e08467ef 100644
--- a/tests/test_diffusion_model_unet.py
+++ b/tests/test_diffusion_model_unet.py
@@ -30,6 +30,30 @@
             "norm_num_groups": 8,
         },
     ],
+    [
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, False, False),
+            "norm_num_groups": 8,
+            "resblock_updown": True,
+        },
+    ],
+    [
+        {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, False, True),
+            "num_head_channels": 8,
+            "norm_num_groups": 8,
+        },
+    ],
     [
         {
             "spatial_dims": 2,
@@ -40,6 +64,7 @@
             "attention_levels": (False, False, True),
             "num_head_channels": 8,
             "norm_num_groups": 8,
+            "resblock_updown": True,
         },
     ],
     [
@@ -68,6 +93,30 @@
             "norm_num_groups": 8,
         },
     ],
+    [
+        {
+            "spatial_dims": 3,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, False, False),
+            "norm_num_groups": 8,
+            "resblock_updown": True,
+        },
+    ],
+    [
+        {
+            "spatial_dims": 3,
+            "in_channels": 1,
+            "out_channels": 1,
+            "num_res_blocks": 1,
+            "num_channels": (8, 8, 8),
+            "attention_levels": (False, False, True),
+            "num_head_channels": 8,
+            "norm_num_groups": 8,
+        },
+    ],
     [
         {
             "spatial_dims": 3,
@@ -78,6 +127,7 @@
             "attention_levels": (False, False, True),
             "num_head_channels": 8,
             "norm_num_groups": 8,
+            "resblock_updown": True,
         },
     ],
     [
@@ -233,6 +283,22 @@ def test_script_conditioned_2d_models(self):
         )
         test_script_save(net, torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3)))
 
+    def test_script_conditioned_2d_models_with_resblock_updown(self):
+        net = DiffusionModelUNet(
+            spatial_dims=2,
+            in_channels=1,
+            out_channels=1,
+            num_res_blocks=1,
+            num_channels=(8, 8, 8),
+            attention_levels=(False, False, True),
+            norm_num_groups=8,
+            with_conditioning=True,
+            transformer_num_layers=1,
+            cross_attention_dim=3,
+            resblock_updown=True,
+        )
+        test_script_save(net, torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3)))
+
 
 class TestDiffusionModelUNet3D(unittest.TestCase):
     @parameterized.expand(UNCOND_CASES_3D)

From 374eccd7322728c8da5f761633d15dfa0f484021 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Fri, 13 Jan 2023 22:40:05 +0000
Subject: [PATCH 2/4] Add emb to Upsample and Downsample layers

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/networks/nets/diffusion_model_unet.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index 672c323f..d8a56e20 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -472,7 +472,8 @@ def __init__(
         assert self.num_channels == self.out_channels
         self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor:
+        del emb
         assert x.shape[1] == self.num_channels
         return self.op(x)
 
@@ -513,7 +514,8 @@ def __init__(
             conv_only=True,
         )
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor:
+        del emb
         assert x.shape[1] == self.num_channels
         x = F.interpolate(x, scale_factor=2.0, mode="nearest")
         if self.use_conv:
@@ -933,10 +935,7 @@ def forward(
             output_states.append(hidden_states)
 
         if self.downsampler is not None:
-            if self.resblock_updown:
-                hidden_states = self.downsampler(hidden_states, temb)
-            else:
-                hidden_states = self.downsampler(hidden_states)
+            hidden_states = self.downsampler(hidden_states, temb)
             output_states.append(hidden_states)
 
         return hidden_states, output_states
@@ -1152,10 +1151,7 @@ def forward(
             hidden_states = resnet(hidden_states, temb)
 
         if self.upsampler is not None:
-            if self.resblock_updown:
-                hidden_states = self.upsampler(hidden_states, temb)
-            else:
-                hidden_states = self.upsampler(hidden_states)
+            hidden_states = self.upsampler(hidden_states, temb)
 
         return hidden_states
 

From bdef518a0a00a17a527f48b331b3e9114177df05 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Fri, 13 Jan 2023 22:43:30 +0000
Subject: [PATCH 3/4] Change calls #176

Signed-off-by: Walter Hugo Lopez Pinaya
---
 .../networks/nets/diffusion_model_unet.py | 20 ++++---------------
 1 file changed, 4 insertions(+), 16 deletions(-)

diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index d8a56e20..a3784f4f 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -718,10 +718,7 @@ def forward(
             output_states.append(hidden_states)
 
         if self.downsampler is not None:
-            if self.resblock_updown:
-                hidden_states = self.downsampler(hidden_states, temb)
-            else:
-                hidden_states = self.downsampler(hidden_states)
+            hidden_states = self.downsampler(hidden_states, temb)
             output_states.append(hidden_states)
 
         return hidden_states, output_states
@@ -823,10 +820,7 @@ def forward(
             output_states.append(hidden_states)
 
         if self.downsampler is not None:
-            if self.resblock_updown:
-                hidden_states = self.downsampler(hidden_states, temb)
-            else:
-                hidden_states = self.downsampler(hidden_states)
+            hidden_states = self.downsampler(hidden_states, temb)
             output_states.append(hidden_states)
 
         return hidden_states, output_states
@@ -1256,10 +1250,7 @@ def forward(
             hidden_states = attn(hidden_states)
 
         if self.upsampler is not None:
-            if self.resblock_updown:
-                hidden_states = self.upsampler(hidden_states, temb)
-            else:
-                hidden_states = self.upsampler(hidden_states)
+            hidden_states = self.upsampler(hidden_states, temb)
 
         return hidden_states
 
@@ -1370,10 +1361,7 @@ def forward(
             hidden_states = attn(hidden_states, context=context)
 
         if self.upsampler is not None:
-            if self.resblock_updown:
-                hidden_states = self.upsampler(hidden_states, temb)
-            else:
-                hidden_states = self.upsampler(hidden_states)
+            hidden_states = self.upsampler(hidden_states, temb)
 
         return hidden_states
 

From 16989f66a46fbbff4616d80f61c9b95a8c73b114 Mon Sep 17 00:00:00 2001
From: Walter Hugo Lopez Pinaya
Date: Fri, 13 Jan 2023 23:06:43 +0000
Subject: [PATCH 4/4] Try to fix torchscript error

Signed-off-by: Walter Hugo Lopez Pinaya
---
 generative/networks/nets/diffusion_model_unet.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py
index a3784f4f..3812a7f9 100644
--- a/generative/networks/nets/diffusion_model_unet.py
+++ b/generative/networks/nets/diffusion_model_unet.py
@@ -513,6 +513,8 @@ def __init__(
                 padding=padding,
                 conv_only=True,
             )
+        else:
+            self.conv = None
 
     def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor:
         del emb
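
Taken together, the series exposes a single new constructor flag. Below is a minimal usage sketch of resblock_updown, mirroring the 2D test case added in patch 1; the generative.networks.nets import path follows this repository's package layout and the forward-call shape follows the existing tests, so treat both as assumptions if your checkout differs.

import torch

from generative.networks.nets import DiffusionModelUNet  # assumed export path

net = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=1,
    num_channels=(8, 8, 8),
    attention_levels=(False, False, True),
    num_head_channels=8,
    norm_num_groups=8,
    resblock_updown=True,  # new flag: residual blocks replace conv/pool resampling
)

x = torch.rand(1, 1, 16, 16)
timesteps = torch.randint(0, 1000, (1,)).long()
out = net(x, timesteps)  # output keeps the input shape: (1, 1, 16, 16)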
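
Patches 2-4 exist because TorchScript type-checks both arms of an if/else against the one concrete type of self.downsampler, so the two-signature branch introduced in patch 1 fails to script. Giving Downsample and Upsample the same (x, emb) signature as ResnetBlock, with emb simply ignored, removes the branch at every call site. A self-contained sketch of that pattern follows, using toy modules rather than the library's classes; every name in it is illustrative, not the repository's API.

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class PlainDownsample(nn.Module):
    # Accepts emb but intentionally ignores it, mirroring the patched Downsample.
    def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.avg_pool2d(x, kernel_size=2, stride=2)


class ResDownsample(nn.Module):
    # Toy residual downsampler that actually consumes the timestep embedding.
    def __init__(self, channels: int, emb_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(emb_dim, channels)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor:
        h = self.conv(x)
        if emb is not None:
            h = h + self.proj(emb).unsqueeze(-1).unsqueeze(-1)
        return h


class Block(nn.Module):
    def __init__(self, resblock_updown: bool) -> None:
        super().__init__()
        self.downsampler = ResDownsample(8, 16) if resblock_updown else PlainDownsample()

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        # Single call site: both variants share the (x, emb) signature, so there
        # is no branch whose dead arm TorchScript would still have to type-check.
        return self.downsampler(x, emb)


for resblock_updown in (False, True):
    scripted = torch.jit.script(Block(resblock_updown))
    y = scripted(torch.rand(1, 8, 16, 16), torch.rand(1, 16))
    assert y.shape == (1, 8, 8, 8)

Note that the plain variant still scripts cleanly with the flag off, precisely because its forward accepts the embedding it never uses; this is the same design choice the series applies to Downsample and Upsample.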