diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py
index 73bfa401932f..ab5c393518e2 100644
--- a/src/diffusers/models/unet_3d_blocks.py
+++ b/src/diffusers/models/unet_3d_blocks.py
@@ -250,10 +250,11 @@ def forward(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
-            ).sample
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+            )[0]
             hidden_states = resnet(hidden_states, temb)
             hidden_states = temp_conv(hidden_states, num_frames=num_frames)
 
@@ -377,10 +378,11 @@ def forward(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
-            ).sample
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+            )[0]
 
             output_states += (hidden_states,)
 
@@ -590,10 +592,11 @@ def forward(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
-            ).sample
+                return_dict=False,
+            )[0]
             hidden_states = temp_attn(
-                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
-            ).sample
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
+            )[0]
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py
index 9bc89c571c52..ee4d0d7cab98 100644
--- a/src/diffusers/models/unet_3d_condition.py
+++ b/src/diffusers/models/unet_3d_condition.py
@@ -526,8 +526,11 @@ def forward(
         sample = self.conv_in(sample)
 
         sample = self.transformer_in(
-            sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
-        ).sample
+            sample,
+            num_frames=num_frames,
+            cross_attention_kwargs=cross_attention_kwargs,
+            return_dict=False,
+        )[0]
 
         # 3. down
         down_block_res_samples = (sample,)
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
index e30f183808a5..ecc330b5f504 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -648,7 +648,8 @@ def __call__(
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
index ce5109a58213..7a4b73cd3c35 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
@@ -723,7 +723,8 @@ def __call__(
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
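
The change is mechanical across all touched call sites: instead of reading the `.sample` attribute of the returned output object, each call passes `return_dict=False` and indexes the returned tuple, so the intermediate output dataclass is never constructed on the hot path (the tuple return path is also the one diffusers generally prefers for tracing/compilation-friendly code). Below is a minimal, self-contained sketch of the two calling conventions; `DummyOutput` and `DummyTransformer` are hypothetical stand-ins for illustration, not the actual diffusers classes.

```python
from dataclasses import dataclass

import torch


@dataclass
class DummyOutput:
    # Stand-in for an output dataclass with a `.sample` field.
    sample: torch.Tensor


class DummyTransformer(torch.nn.Module):
    def forward(self, hidden_states, return_dict=True):
        out = hidden_states * 2.0  # placeholder computation

        if not return_dict:
            # Tuple return path: no dataclass is allocated.
            return (out,)
        return DummyOutput(sample=out)


block = DummyTransformer()
x = torch.randn(1, 4, 8, 8)

# Old calling convention: build the output object, then read `.sample`.
y_old = block(x).sample

# New calling convention used in the diff: `return_dict=False` + tuple indexing.
y_new = block(x, return_dict=False)[0]

assert torch.equal(y_old, y_new)
```

Both spellings yield the same tensor; the diff only switches which return path is taken, so numerics of the 3D UNet blocks and text-to-video pipelines are unchanged.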