From 25cfcb1b003e6a389242586980ad4c46ccdef0f9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 5 Jul 2023 08:55:51 +0530 Subject: [PATCH 1/5] use sample directly instead of the dataclass. --- src/diffusers/models/unet_3d_condition.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index 9bc89c571c52..ee4d0d7cab98 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -526,8 +526,11 @@ def forward( sample = self.conv_in(sample) sample = self.transformer_in( - sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs - ).sample + sample, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] # 3. down down_block_res_samples = (sample,) From 2cd1460437ff76fb38cf852f62fd25acb93906c7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 5 Jul 2023 09:06:28 +0530 Subject: [PATCH 2/5] more usage of directly samples instead of dataclasses --- src/diffusers/models/unet_3d_blocks.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 73bfa401932f..2ad049f5b4e3 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -250,10 +250,11 @@ def forward( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] hidden_states = temp_attn( - hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs - ).sample + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False + )[0] hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) @@ -377,10 +378,11 @@ def forward( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] hidden_states = temp_attn( - hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs - ).sample + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False + )[0] output_states += (hidden_states,) From a63443b9f0d6722611e859a5af6dadf92b38c55d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 5 Jul 2023 09:12:53 +0530 Subject: [PATCH 3/5] more usage of directly samples instead of dataclasses --- src/diffusers/models/unet_3d_blocks.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index 2ad049f5b4e3..e6fd8b4869f3 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -592,10 +592,11 @@ def forward( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False + )[0] hidden_states = temp_attn( - hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs - ).sample + hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False + )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: From 5958b7ff4dfc68d74f737e1f9bba794868f3db3b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 5 Jul 2023 09:15:50 +0530 Subject: [PATCH 4/5] use direct sample in the pipeline. --- .../text_to_video_synthesis/pipeline_text_to_video_synth.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index e30f183808a5..315e3c95b403 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -648,7 +648,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False + )[0] # perform guidance if do_classifier_free_guidance: From dd5d0e495af90c9e9cb61df17a73c5ce988e209f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 5 Jul 2023 09:16:48 +0530 Subject: [PATCH 5/5] direct usage of sample in the img2img case. --- src/diffusers/models/unet_3d_blocks.py | 2 +- .../text_to_video_synthesis/pipeline_text_to_video_synth.py | 2 +- .../pipeline_text_to_video_synth_img2img.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index e6fd8b4869f3..ab5c393518e2 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -592,7 +592,7 @@ def forward( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - return_dict=False + return_dict=False, )[0] hidden_states = temp_attn( hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 315e3c95b403..ecc330b5f504 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -648,7 +648,7 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - return_dict=False + return_dict=False, )[0] # perform guidance diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index ce5109a58213..7a4b73cd3c35 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -723,7 +723,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: