From 34cf224d66e63fb092055dcd709c2d7d42c78de3 Mon Sep 17 00:00:00 2001 From: QishengL <89773749+QishengL@users.noreply.github.com> Date: Tue, 21 Mar 2023 19:26:46 +0400 Subject: [PATCH 1/4] Update unet_2d_blocks.py AttnUpBlock2D add upsample_size parameter --- src/diffusers/models/unet_2d_blocks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 3070351279b8..3a3d72d747d5 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1673,7 +1673,7 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None,,upsample_size=None): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1685,7 +1685,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states,upsample_size) return hidden_states From 3faa2696fb819400c9cbab34f2fd4367d89c83ed Mon Sep 17 00:00:00 2001 From: QishengL <89773749+QishengL@users.noreply.github.com> Date: Tue, 21 Mar 2023 19:31:06 +0400 Subject: [PATCH 2/4] Update unet_2d_blocks.py AttnUpBlock2D add upsample_size parameter --- src/diffusers/models/unet_2d_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 3a3d72d747d5..38d586ecdf04 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1673,7 +1673,7 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None,,upsample_size=None): + def forward(self, hidden_states, 
res_hidden_states_tuple, temb=None,upsample_size=None): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] From deddda4addfb7be0a392c3b8ec7f9298e02b4931 Mon Sep 17 00:00:00 2001 From: QishengL <89773749+QishengL@users.noreply.github.com> Date: Tue, 21 Mar 2023 21:43:45 +0400 Subject: [PATCH 3/4] Update unet_2d.py Add upsample_size --- src/diffusers/models/unet_2d.py | 56 ++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 2df6e60d88c9..6779660b3e79 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -181,17 +181,30 @@ def __init__( resnet_groups=norm_num_groups, add_attention=add_attention, ) - + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - is_final_block = i == len(block_out_channels) - 1 - + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( up_block_type, num_layers=layers_per_block + 1, @@ -199,7 +212,7 @@ def __init__( out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=time_embed_dim, - add_upsample=not is_final_block, + add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, @@ -235,6 +248,26 @@ def forward( [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, 
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ + + + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + + + # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -288,14 +321,21 @@ def forward( # 5. up skip_sample = None - for upsample_block in self.up_blocks: + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + if hasattr(upsample_block, "skip_conv"): - sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample) + sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample,upsample_size=upsample_size) else: - sample = upsample_block(sample, res_samples, emb) + sample = upsample_block(sample, res_samples, emb,upsample_size=upsample_size) # 6. 
post-process sample = self.conv_norm_out(sample) From 1b326b35bfae907287714f02620ce0f85a4fa9d8 Mon Sep 17 00:00:00 2001 From: QishengL <89773749+QishengL@users.noreply.github.com> Date: Tue, 21 Mar 2023 21:50:48 +0400 Subject: [PATCH 4/4] Update unet_2d.py --- src/diffusers/models/unet_2d.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 6779660b3e79..cc514d26781e 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -262,7 +262,6 @@ def forward( upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): - logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True